From fbb979537409b9154f0f2f99443eb325fe30cddc Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 11 Jan 2021 14:53:12 -0800 Subject: [PATCH] [client] Report number of currently active clients on connect (#13326) * wip * update * update * reset worker * fix conn * fix * disable pycodestyle --- .flake8 | 2 ++ python/ray/tests/test_client.py | 26 +++++++++++++++++++ python/ray/util/client/__init__.py | 18 ++++++++++--- python/ray/util/client/dataclient.py | 7 +++++ python/ray/util/client/server/dataservicer.py | 13 ++++++++++ python/ray/util/client/worker.py | 7 +++++ python/ray/util/client_connect.py | 5 ++-- src/ray/protobuf/ray_client.proto | 10 +++++++ 8 files changed, 81 insertions(+), 7 deletions(-) diff --git a/.flake8 b/.flake8 index 7edb1b7d5..782615692 100644 --- a/.flake8 +++ b/.flake8 @@ -20,4 +20,6 @@ ignore = W503 W504 W605 + I + N avoid-escape = no diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index 645640b45..0b5c34e99 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -3,10 +3,36 @@ import time import sys import logging +import ray.util.client.server.server as ray_client_server +from ray.util.client import RayAPIStub from ray.util.client.common import ClientObjectRef from ray.util.client.ray_client_helpers import ray_start_client_server +def test_num_clients(shutdown_only): + # Tests num clients reporting; useful if you want to build an app that + # load balances clients between Ray client servers. + server = ray_client_server.serve("localhost:50051") + try: + api1 = RayAPIStub() + info1 = api1.connect("localhost:50051") + assert info1["num_clients"] == 1, info1 + api2 = RayAPIStub() + info2 = api2.connect("localhost:50051") + assert info2["num_clients"] == 2, info2 + + # Disconnect the first two clients. + api1.disconnect() + api2.disconnect() + time.sleep(1) + + api3 = RayAPIStub() + info3 = api3.connect("localhost:50051") + assert info3["num_clients"] == 1, info3 + finally: + server.stop(0) + + def test_real_ray_fallback(ray_start_regular_shared): with ray_start_client_server() as ray: diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py index 3ad5bd639..22cf0af5b 100644 --- a/python/ray/util/client/__init__.py +++ b/python/ray/util/client/__init__.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, Dict, Any import logging @@ -24,13 +24,16 @@ class RayAPIStub: def connect(self, conn_str: str, secure: bool = False, - metadata: List[Tuple[str, str]] = None) -> None: + metadata: List[Tuple[str, str]] = None) -> Dict[str, Any]: """Connect the Ray Client to a server. Args: conn_str: Connection string, in the form "[host]:port" secure: Whether to use a TLS secured gRPC channel metadata: gRPC metadata to send on connect + + Returns: + Dictionary of connection info, e.g., {"num_clients": 1}. """ # Delay imports until connect to avoid circular imports. from ray.util.client.worker import Worker @@ -44,8 +47,15 @@ class RayAPIStub: # If we're calling a client connect specifically and we're not # currently in client mode, ensure we are. ray._private.client_mode_hook._explicitly_enable_client_mode() - self.client_worker = Worker(conn_str, secure=secure, metadata=metadata) - self.api.worker = self.client_worker + + try: + self.client_worker = Worker( + conn_str, secure=secure, metadata=metadata) + self.api.worker = self.client_worker + return self.client_worker.connection_info() + except Exception: + self.disconnect() + raise def disconnect(self): """Disconnect the Ray Client. diff --git a/python/ray/util/client/dataclient.py b/python/ray/util/client/dataclient.py index bba9f53f3..6e29ea927 100644 --- a/python/ray/util/client/dataclient.py +++ b/python/ray/util/client/dataclient.py @@ -93,6 +93,13 @@ class DataClient: del self.ready_data[req_id] return data + def ConnectionInfo(self, + context=None) -> ray_client_pb2.ConnectionInfoResponse: + datareq = ray_client_pb2.DataRequest( + connection_info=ray_client_pb2.ConnectionInfoRequest()) + resp = self._blocking_send(datareq) + return resp.connection_info + def GetObject(self, request: ray_client_pb2.GetRequest, context=None) -> ray_client_pb2.GetResponse: datareq = ray_client_pb2.DataRequest(get=request, ) diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py index 1a014de6b..f80ef957a 100644 --- a/python/ray/util/client/server/dataservicer.py +++ b/python/ray/util/client/server/dataservicer.py @@ -2,6 +2,7 @@ import logging import grpc from typing import TYPE_CHECKING +from threading import Lock import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc @@ -15,6 +16,8 @@ logger = logging.getLogger(__name__) class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): def __init__(self, basic_service: "RayletServicer"): self.basic_service = basic_service + self._clients_lock = Lock() + self._num_clients = 0 # guarded by self._clients_lock def Datapath(self, request_iterator, context): metadata = {k: v for k, v in context.invocation_metadata()} @@ -24,6 +27,8 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): return logger.info(f"New data connection from client {client_id}") try: + with self._clients_lock: + self._num_clients += 1 for req in request_iterator: resp = None req_type = req.WhichOneof("type") @@ -42,6 +47,12 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): released.append(rel) resp = ray_client_pb2.DataResponse( release=ray_client_pb2.ReleaseResponse(ok=released)) + elif req_type == "connection_info": + with self._clients_lock: + cur_num_clients = self._num_clients + info = ray_client_pb2.ConnectionInfoResponse( + num_clients=cur_num_clients) + resp = ray_client_pb2.DataResponse(connection_info=info) else: raise Exception(f"Unreachable code: Request type " f"{req_type} not handled in Datapath") @@ -52,3 +63,5 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): finally: logger.info(f"Lost data connection from client {client_id}") self.basic_service.release_all(client_id) + with self._clients_lock: + self._num_clients -= 1 diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 683b59082..303d39c8d 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -59,6 +59,13 @@ class Worker: self.log_client.set_logstream_level(logging.INFO) self.closed = False + def connection_info(self): + try: + data = self.data_client.ConnectionInfo() + except grpc.RpcError as e: + raise e.details() + return {"num_clients": data.num_clients} + def get(self, vals, *, timeout: Optional[float] = None) -> Any: to_get = [] single = False diff --git a/python/ray/util/client_connect.py b/python/ray/util/client_connect.py index e8525e17d..3f33933c4 100644 --- a/python/ray/util/client_connect.py +++ b/python/ray/util/client_connect.py @@ -3,13 +3,12 @@ from ray.util.client import ray from ray._private.client_mode_hook import _enable_client_hook from ray._private.client_mode_hook import _explicitly_enable_client_mode -from typing import List -from typing import Tuple +from typing import List, Tuple, Dict, Any def connect(conn_str: str, secure: bool = False, - metadata: List[Tuple[str, str]] = None) -> None: + metadata: List[Tuple[str, str]] = None) -> Dict[str, Any]: if ray.is_connected(): raise RuntimeError("Ray Client is already connected. " "Maybe you called ray.util.connect twice by " diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index 3dd3128b2..49ec8475e 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -204,6 +204,14 @@ message ReleaseResponse { repeated bool ok = 2; } +message ConnectionInfoRequest { +} + +message ConnectionInfoResponse { + // The number of data clients connected to the server, including the caller. + int32 num_clients = 1; +} + message DataRequest { // An incrementing counter of request IDs on the Datapath, // to match requests with responses asynchronously. @@ -212,6 +220,7 @@ message DataRequest { GetRequest get = 2; PutRequest put = 3; ReleaseRequest release = 4; + ConnectionInfoRequest connection_info = 5; } } @@ -222,6 +231,7 @@ message DataResponse { GetResponse get = 2; PutResponse put = 3; ReleaseResponse release = 4; + ConnectionInfoResponse connection_info = 5; } }