[client] Report number of currently active clients on connect (#13326)

* wip

* update

* update

* reset worker

* fix conn

* fix

* disable pycodestyle
This commit is contained in:
Eric Liang
2021-01-11 14:53:12 -08:00
committed by GitHub
parent e2b2abb88b
commit fbb9795374
8 changed files with 81 additions and 7 deletions
+2
View File
@@ -20,4 +20,6 @@ ignore =
W503
W504
W605
I
N
avoid-escape = no
+26
View File
@@ -3,10 +3,36 @@ import time
import sys
import logging
import ray.util.client.server.server as ray_client_server
from ray.util.client import RayAPIStub
from ray.util.client.common import ClientObjectRef
from ray.util.client.ray_client_helpers import ray_start_client_server
def test_num_clients(shutdown_only):
    """Check the ``num_clients`` value reported by each ``connect()`` call.

    The reported count is useful if you want to build an app that load
    balances clients between Ray client servers.
    """
    address = "localhost:50051"
    server = ray_client_server.serve(address)
    try:
        first_stub = RayAPIStub()
        first_info = first_stub.connect(address)
        assert first_info["num_clients"] == 1, first_info

        second_stub = RayAPIStub()
        second_info = second_stub.connect(address)
        assert second_info["num_clients"] == 2, second_info

        # Disconnect the first two clients.
        for stub in (first_stub, second_stub):
            stub.disconnect()
        # NOTE(review): presumably gives the server time to notice the
        # dropped connections before the next client connects -- confirm.
        time.sleep(1)

        third_stub = RayAPIStub()
        third_info = third_stub.connect(address)
        assert third_info["num_clients"] == 1, third_info
    finally:
        server.stop(0)
def test_real_ray_fallback(ray_start_regular_shared):
with ray_start_client_server() as ray:
+13 -3
View File
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Dict, Any
import logging
@@ -24,13 +24,16 @@ class RayAPIStub:
def connect(self,
conn_str: str,
secure: bool = False,
metadata: List[Tuple[str, str]] = None) -> None:
metadata: List[Tuple[str, str]] = None) -> Dict[str, Any]:
"""Connect the Ray Client to a server.
Args:
conn_str: Connection string, in the form "[host]:port"
secure: Whether to use a TLS secured gRPC channel
metadata: gRPC metadata to send on connect
Returns:
Dictionary of connection info, e.g., {"num_clients": 1}.
"""
# Delay imports until connect to avoid circular imports.
from ray.util.client.worker import Worker
@@ -44,8 +47,15 @@ class RayAPIStub:
# If we're calling a client connect specifically and we're not
# currently in client mode, ensure we are.
ray._private.client_mode_hook._explicitly_enable_client_mode()
self.client_worker = Worker(conn_str, secure=secure, metadata=metadata)
try:
self.client_worker = Worker(
conn_str, secure=secure, metadata=metadata)
self.api.worker = self.client_worker
return self.client_worker.connection_info()
except Exception:
self.disconnect()
raise
def disconnect(self):
"""Disconnect the Ray Client.
+7
View File
@@ -93,6 +93,13 @@ class DataClient:
del self.ready_data[req_id]
return data
def ConnectionInfo(self,
                   context=None) -> ray_client_pb2.ConnectionInfoResponse:
    """Request connection information over the data channel.

    Wraps an empty ``ConnectionInfoRequest`` in a ``DataRequest``, passes
    it to ``self._blocking_send`` and returns the ``connection_info``
    field of the response.

    Args:
        context: Not used by this method.

    Returns:
        The ``ConnectionInfoResponse`` carried by the data response.
    """
    info_request = ray_client_pb2.ConnectionInfoRequest()
    wrapped = ray_client_pb2.DataRequest(connection_info=info_request)
    return self._blocking_send(wrapped).connection_info
def GetObject(self, request: ray_client_pb2.GetRequest,
context=None) -> ray_client_pb2.GetResponse:
datareq = ray_client_pb2.DataRequest(get=request, )
@@ -2,6 +2,7 @@ import logging
import grpc
from typing import TYPE_CHECKING
from threading import Lock
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
@@ -15,6 +16,8 @@ logger = logging.getLogger(__name__)
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, basic_service: "RayletServicer"):
    """Keep a handle to the basic servicer and start the client count at 0.

    Args:
        basic_service: Servicer stored on ``self.basic_service``.
    """
    # Count of connected data clients; guarded by self._clients_lock.
    self._num_clients = 0
    self._clients_lock = Lock()
    self.basic_service = basic_service
def Datapath(self, request_iterator, context):
metadata = {k: v for k, v in context.invocation_metadata()}
@@ -24,6 +27,8 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
return
logger.info(f"New data connection from client {client_id}")
try:
with self._clients_lock:
self._num_clients += 1
for req in request_iterator:
resp = None
req_type = req.WhichOneof("type")
@@ -42,6 +47,12 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
released.append(rel)
resp = ray_client_pb2.DataResponse(
release=ray_client_pb2.ReleaseResponse(ok=released))
elif req_type == "connection_info":
with self._clients_lock:
cur_num_clients = self._num_clients
info = ray_client_pb2.ConnectionInfoResponse(
num_clients=cur_num_clients)
resp = ray_client_pb2.DataResponse(connection_info=info)
else:
raise Exception(f"Unreachable code: Request type "
f"{req_type} not handled in Datapath")
@@ -52,3 +63,5 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
finally:
logger.info(f"Lost data connection from client {client_id}")
self.basic_service.release_all(client_id)
with self._clients_lock:
self._num_clients -= 1
+7
View File
@@ -59,6 +59,13 @@ class Worker:
self.log_client.set_logstream_level(logging.INFO)
self.closed = False
def connection_info(self):
    """Fetch connection information from the server.

    Returns:
        A dict with a single key, ``"num_clients"``, taken from the
        server's connection-info response.

    Raises:
        ConnectionError: If the request fails with a ``grpc.RpcError``.
            The message is the error's ``details()`` and the original
            error is chained as ``__cause__``.
    """
    try:
        data = self.data_client.ConnectionInfo()
    except grpc.RpcError as e:
        # ``raise e.details()`` would try to raise a non-exception object,
        # which itself fails with ``TypeError: exceptions must derive from
        # BaseException`` and hides the real failure. Wrap it instead.
        raise ConnectionError(e.details()) from e
    return {"num_clients": data.num_clients}
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
to_get = []
single = False
+2 -3
View File
@@ -3,13 +3,12 @@ from ray.util.client import ray
from ray._private.client_mode_hook import _enable_client_hook
from ray._private.client_mode_hook import _explicitly_enable_client_mode
from typing import List
from typing import Tuple
from typing import List, Tuple, Dict, Any
def connect(conn_str: str,
secure: bool = False,
metadata: List[Tuple[str, str]] = None) -> None:
metadata: List[Tuple[str, str]] = None) -> Dict[str, Any]:
if ray.is_connected():
raise RuntimeError("Ray Client is already connected. "
"Maybe you called ray.util.connect twice by "
+10
View File
@@ -204,6 +204,14 @@ message ReleaseResponse {
repeated bool ok = 2;
}
// Request for connection information. Carries no fields; it is sent as the
// `connection_info` member of DataRequest.
message ConnectionInfoRequest {
}
// Connection information returned as the `connection_info` member of
// DataResponse.
message ConnectionInfoResponse {
// The number of data clients connected to the server, including the caller.
int32 num_clients = 1;
}
message DataRequest {
// An incrementing counter of request IDs on the Datapath,
// to match requests with responses asynchronously.
@@ -212,6 +220,7 @@ message DataRequest {
GetRequest get = 2;
PutRequest put = 3;
ReleaseRequest release = 4;
ConnectionInfoRequest connection_info = 5;
}
}
@@ -222,6 +231,7 @@ message DataResponse {
GetResponse get = 2;
PutResponse put = 3;
ReleaseResponse release = 4;
ConnectionInfoResponse connection_info = 5;
}
}