mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 15:16:34 +08:00
[client] Report number of currently active clients on connect (#13326)
* wip * update * update * reset worker * fix conn * fix * disable pycodestyle
This commit is contained in:
@@ -3,10 +3,36 @@ import time
|
||||
import sys
|
||||
import logging
|
||||
|
||||
import ray.util.client.server.server as ray_client_server
|
||||
from ray.util.client import RayAPIStub
|
||||
from ray.util.client.common import ClientObjectRef
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
|
||||
def test_num_clients(shutdown_only):
|
||||
# Tests num clients reporting; useful if you want to build an app that
|
||||
# load balances clients between Ray client servers.
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
try:
|
||||
api1 = RayAPIStub()
|
||||
info1 = api1.connect("localhost:50051")
|
||||
assert info1["num_clients"] == 1, info1
|
||||
api2 = RayAPIStub()
|
||||
info2 = api2.connect("localhost:50051")
|
||||
assert info2["num_clients"] == 2, info2
|
||||
|
||||
# Disconnect the first two clients.
|
||||
api1.disconnect()
|
||||
api2.disconnect()
|
||||
time.sleep(1)
|
||||
|
||||
api3 = RayAPIStub()
|
||||
info3 = api3.connect("localhost:50051")
|
||||
assert info3["num_clients"] == 1, info3
|
||||
finally:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
def test_real_ray_fallback(ray_start_regular_shared):
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Dict, Any
|
||||
|
||||
import logging
|
||||
|
||||
@@ -24,13 +24,16 @@ class RayAPIStub:
|
||||
def connect(self,
|
||||
conn_str: str,
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None) -> None:
|
||||
metadata: List[Tuple[str, str]] = None) -> Dict[str, Any]:
|
||||
"""Connect the Ray Client to a server.
|
||||
|
||||
Args:
|
||||
conn_str: Connection string, in the form "[host]:port"
|
||||
secure: Whether to use a TLS secured gRPC channel
|
||||
metadata: gRPC metadata to send on connect
|
||||
|
||||
Returns:
|
||||
Dictionary of connection info, e.g., {"num_clients": 1}.
|
||||
"""
|
||||
# Delay imports until connect to avoid circular imports.
|
||||
from ray.util.client.worker import Worker
|
||||
@@ -44,8 +47,15 @@ class RayAPIStub:
|
||||
# If we're calling a client connect specifically and we're not
|
||||
# currently in client mode, ensure we are.
|
||||
ray._private.client_mode_hook._explicitly_enable_client_mode()
|
||||
self.client_worker = Worker(conn_str, secure=secure, metadata=metadata)
|
||||
self.api.worker = self.client_worker
|
||||
|
||||
try:
|
||||
self.client_worker = Worker(
|
||||
conn_str, secure=secure, metadata=metadata)
|
||||
self.api.worker = self.client_worker
|
||||
return self.client_worker.connection_info()
|
||||
except Exception:
|
||||
self.disconnect()
|
||||
raise
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect the Ray Client.
|
||||
|
||||
@@ -93,6 +93,13 @@ class DataClient:
|
||||
del self.ready_data[req_id]
|
||||
return data
|
||||
|
||||
def ConnectionInfo(self,
|
||||
context=None) -> ray_client_pb2.ConnectionInfoResponse:
|
||||
datareq = ray_client_pb2.DataRequest(
|
||||
connection_info=ray_client_pb2.ConnectionInfoRequest())
|
||||
resp = self._blocking_send(datareq)
|
||||
return resp.connection_info
|
||||
|
||||
def GetObject(self, request: ray_client_pb2.GetRequest,
|
||||
context=None) -> ray_client_pb2.GetResponse:
|
||||
datareq = ray_client_pb2.DataRequest(get=request, )
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import grpc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from threading import Lock
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
@@ -15,6 +16,8 @@ logger = logging.getLogger(__name__)
|
||||
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
def __init__(self, basic_service: "RayletServicer"):
|
||||
self.basic_service = basic_service
|
||||
self._clients_lock = Lock()
|
||||
self._num_clients = 0 # guarded by self._clients_lock
|
||||
|
||||
def Datapath(self, request_iterator, context):
|
||||
metadata = {k: v for k, v in context.invocation_metadata()}
|
||||
@@ -24,6 +27,8 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
return
|
||||
logger.info(f"New data connection from client {client_id}")
|
||||
try:
|
||||
with self._clients_lock:
|
||||
self._num_clients += 1
|
||||
for req in request_iterator:
|
||||
resp = None
|
||||
req_type = req.WhichOneof("type")
|
||||
@@ -42,6 +47,12 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
released.append(rel)
|
||||
resp = ray_client_pb2.DataResponse(
|
||||
release=ray_client_pb2.ReleaseResponse(ok=released))
|
||||
elif req_type == "connection_info":
|
||||
with self._clients_lock:
|
||||
cur_num_clients = self._num_clients
|
||||
info = ray_client_pb2.ConnectionInfoResponse(
|
||||
num_clients=cur_num_clients)
|
||||
resp = ray_client_pb2.DataResponse(connection_info=info)
|
||||
else:
|
||||
raise Exception(f"Unreachable code: Request type "
|
||||
f"{req_type} not handled in Datapath")
|
||||
@@ -52,3 +63,5 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
finally:
|
||||
logger.info(f"Lost data connection from client {client_id}")
|
||||
self.basic_service.release_all(client_id)
|
||||
with self._clients_lock:
|
||||
self._num_clients -= 1
|
||||
|
||||
@@ -59,6 +59,13 @@ class Worker:
|
||||
self.log_client.set_logstream_level(logging.INFO)
|
||||
self.closed = False
|
||||
|
||||
def connection_info(self):
|
||||
try:
|
||||
data = self.data_client.ConnectionInfo()
|
||||
except grpc.RpcError as e:
|
||||
raise e.details()
|
||||
return {"num_clients": data.num_clients}
|
||||
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
to_get = []
|
||||
single = False
|
||||
|
||||
@@ -3,13 +3,12 @@ from ray.util.client import ray
|
||||
from ray._private.client_mode_hook import _enable_client_hook
|
||||
from ray._private.client_mode_hook import _explicitly_enable_client_mode
|
||||
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import List, Tuple, Dict, Any
|
||||
|
||||
|
||||
def connect(conn_str: str,
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None) -> None:
|
||||
metadata: List[Tuple[str, str]] = None) -> Dict[str, Any]:
|
||||
if ray.is_connected():
|
||||
raise RuntimeError("Ray Client is already connected. "
|
||||
"Maybe you called ray.util.connect twice by "
|
||||
|
||||
@@ -204,6 +204,14 @@ message ReleaseResponse {
|
||||
repeated bool ok = 2;
|
||||
}
|
||||
|
||||
message ConnectionInfoRequest {
|
||||
}
|
||||
|
||||
message ConnectionInfoResponse {
|
||||
// The number of data clients connected to the server, including the caller.
|
||||
int32 num_clients = 1;
|
||||
}
|
||||
|
||||
message DataRequest {
|
||||
// An incrementing counter of request IDs on the Datapath,
|
||||
// to match requests with responses asynchronously.
|
||||
@@ -212,6 +220,7 @@ message DataRequest {
|
||||
GetRequest get = 2;
|
||||
PutRequest put = 3;
|
||||
ReleaseRequest release = 4;
|
||||
ConnectionInfoRequest connection_info = 5;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,6 +231,7 @@ message DataResponse {
|
||||
GetResponse get = 2;
|
||||
PutResponse put = 3;
|
||||
ReleaseResponse release = 4;
|
||||
ConnectionInfoResponse connection_info = 5;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user