diff --git a/python/ray/tests/test_client_init.py b/python/ray/tests/test_client_init.py index 1949fe3fd..0c54f93ea 100644 --- a/python/ray/tests/test_client_init.py +++ b/python/ray/tests/test_client_init.py @@ -1,7 +1,11 @@ """Client tests that run their own init (as with init_and_serve) live here""" +import pytest + import time +import sys import ray.util.client.server.server as ray_client_server +import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.util.client import RayAPIStub @@ -9,7 +13,8 @@ from ray.util.client import RayAPIStub def test_num_clients(): # Tests num clients reporting; useful if you want to build an app that # load balances clients between Ray client servers. - server, _ = ray_client_server.init_and_serve("localhost:50051") + server_handle, _ = ray_client_server.init_and_serve("localhost:50051") + server = server_handle.grpc_server try: api1 = RayAPIStub() info1 = api1.connect("localhost:50051") @@ -35,3 +40,39 @@ def test_num_clients(): finally: ray_client_server.shutdown_with_server(server) time.sleep(2) + + +def test_python_version(): + + server_handle, _ = ray_client_server.init_and_serve("localhost:50051") + try: + ray = RayAPIStub() + info1 = ray.connect("localhost:50051") + assert info1["python_version"] == ".".join( + [str(x) for x in list(sys.version_info)[:3]]) + ray.disconnect() + time.sleep(1) + + def mock_connection_response(): + return ray_client_pb2.ConnectionInfoResponse( + num_clients=1, + python_version="2.7.12", + ray_version="", + ray_commit="", + ) + + # inject mock connection function + server_handle.data_servicer._build_connection_response = \ + mock_connection_response + + ray = RayAPIStub() + with pytest.raises(RuntimeError): + _ = ray.connect("localhost:50051") + + ray = RayAPIStub() + info3 = ray.connect("localhost:50051", ignore_version=True) + assert info3["num_clients"] == 1, info3 + ray.disconnect() + finally: + ray_client_server.shutdown_with_server(server_handle.grpc_server) + time.sleep(2) diff --git a/python/ray/tests/test_client_references.py b/python/ray/tests/test_client_references.py index 834fadfcf..8a4458e14 100644 --- a/python/ray/tests/test_client_references.py +++ b/python/ray/tests/test_client_references.py @@ -1,39 +1,38 @@ from ray.util.client.ray_client_helpers import ray_start_client_server +from ray.util.client.ray_client_helpers import ray_start_client_server_pair from ray.test_utils import wait_for_condition import ray as real_ray from ray.core.generated.gcs_pb2 import ActorTableData -from ray.util.client.server.server import _get_current_servicer -def server_object_ref_count(n): - server = _get_current_servicer() +def server_object_ref_count(server, n): assert server is not None def test_cond(): - if len(server.object_refs) == 0: + if len(server.task_servicer.object_refs) == 0: # No open clients return n == 0 - client_id = list(server.object_refs.keys())[0] - return len(server.object_refs[client_id]) == n + client_id = list(server.task_servicer.object_refs.keys())[0] + return len(server.task_servicer.object_refs[client_id]) == n return test_cond -def server_actor_ref_count(n): - server = _get_current_servicer() +def server_actor_ref_count(server, n): assert server is not None def test_cond(): - if len(server.actor_refs) == 0: + if len(server.task_servicer.actor_refs) == 0: # No running actors return n == 0 - return len(server.actor_refs) == n + return len(server.task_servicer.actor_refs) == n return test_cond def test_delete_refs_on_disconnect(ray_start_regular): - with ray_start_client_server() as ray: + with ray_start_client_server_pair() as pair: + ray, server = pair @ray.remote def f(x): @@ -46,14 +45,14 @@ def test_delete_refs_on_disconnect(ray_start_regular): # in a different category, according to the raylet. assert len(real_ray.objects()) == 2 # But we're maintaining the reference - assert server_object_ref_count(3)() + assert server_object_ref_count(server, 3)() # And can get the data assert ray.get(thing1) == 8 # Close the client ray.close() - wait_for_condition(server_object_ref_count(0), timeout=5) + wait_for_condition(server_object_ref_count(server, 0), timeout=5) def test_cond(): return len(real_ray.objects()) == 0 @@ -62,7 +61,8 @@ def test_delete_refs_on_disconnect(ray_start_regular): def test_delete_ref_on_object_deletion(ray_start_regular): - with ray_start_client_server() as ray: + with ray_start_client_server_pair() as pair: + ray, server = pair vals = { "ref": ray.put("Hello World"), "ref2": ray.put("This value stays"), @@ -70,11 +70,12 @@ def test_delete_ref_on_object_deletion(ray_start_regular): del vals["ref"] - wait_for_condition(server_object_ref_count(1), timeout=5) + wait_for_condition(server_object_ref_count(server, 1), timeout=5) def test_delete_actor_on_disconnect(ray_start_regular): - with ray_start_client_server() as ray: + with ray_start_client_server_pair() as pair: + ray, server = pair @ray.remote class Accumulator: @@ -90,13 +91,13 @@ def test_delete_actor_on_disconnect(ray_start_regular): actor = Accumulator.remote() actor.inc.remote() - assert server_actor_ref_count(1)() + assert server_actor_ref_count(server, 1)() assert ray.get(actor.get.remote()) == 1 ray.close() - wait_for_condition(server_actor_ref_count(0), timeout=5) + wait_for_condition(server_actor_ref_count(server, 0), timeout=5) def test_cond(): alive_actors = [ @@ -109,7 +110,8 @@ def test_delete_actor_on_disconnect(ray_start_regular): def test_delete_actor(ray_start_regular): - with ray_start_client_server() as ray: + with ray_start_client_server_pair() as pair: + ray, server = pair @ray.remote class Accumulator: @@ -124,11 +126,11 @@ def test_delete_actor(ray_start_regular): actor2 = Accumulator.remote() actor2.inc.remote() - assert server_actor_ref_count(2)() + assert server_actor_ref_count(server, 2)() del actor - wait_for_condition(server_actor_ref_count(1), timeout=5) + wait_for_condition(server_actor_ref_count(server, 1), timeout=5) def test_simple_multiple_references(ray_start_regular): diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py index 1c28dc53c..9a2d14877 100644 --- a/python/ray/util/client/__init__.py +++ b/python/ray/util/client/__init__.py @@ -1,5 +1,6 @@ from typing import List, Tuple, Dict, Any +import sys import logging logger = logging.getLogger(__name__) @@ -25,7 +26,9 @@ class RayAPIStub: conn_str: str, secure: bool = False, metadata: List[Tuple[str, str]] = None, - connection_retries: int = 3) -> Dict[str, Any]: + connection_retries: int = 3, + *, + ignore_version: bool = False) -> Dict[str, Any]: """Connect the Ray Client to a server. Args: @@ -56,11 +59,25 @@ class RayAPIStub: metadata=metadata, connection_retries=connection_retries) self.api.worker = self.client_worker - return self.client_worker.connection_info() + conn_info = self.client_worker.connection_info() + self._check_versions(conn_info, ignore_version) + return conn_info except Exception: self.disconnect() raise + def _check_versions(self, conn_info, ignore_version: bool) -> None: + local_major_minor = f"{sys.version_info[0]}.{sys.version_info[1]}" + if not conn_info["python_version"].startswith(local_major_minor): + version_str = f"{local_major_minor}.{sys.version_info[2]}" + msg = "Python minor versions differ between client and server:" + \ + f" client is {version_str}," + \ + f" server is {conn_info['python_version']}" + if ignore_version: + logger.warning(msg) + else: + raise RuntimeError(msg) + def disconnect(self): """Disconnect the Ray Client. """ @@ -97,8 +114,9 @@ class RayAPIStub: if self._server is not None: raise Exception("Trying to start two instances of ray via client") import ray.util.client.server.server as ray_client_server - self._server, address_info = ray_client_server.init_and_serve( + server_handle, address_info = ray_client_server.init_and_serve( "localhost:50051", *args, **kwargs) + self._server = server_handle.grpc_server self.connect("localhost:50051") self._connected_with_init = True return address_info diff --git a/python/ray/util/client/ray_client_helpers.py b/python/ray/util/client/ray_client_helpers.py index be5a2918c..77f09346d 100644 --- a/python/ray/util/client/ray_client_helpers.py +++ b/python/ray/util/client/ray_client_helpers.py @@ -6,11 +6,18 @@ from ray.util.client import ray @contextmanager def ray_start_client_server(): + with ray_start_client_server_pair() as pair: + client, server = pair + yield client + + +@contextmanager +def ray_start_client_server_pair(): ray._inside_client_test = True server = ray_client_server.serve("localhost:50051") ray.connect("localhost:50051") try: - yield ray + yield ray, server finally: ray._inside_client_test = False ray.disconnect() diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py index 7a7fb3eae..a01369e43 100644 --- a/python/ray/util/client/server/dataservicer.py +++ b/python/ray/util/client/server/dataservicer.py @@ -50,16 +50,8 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): resp = ray_client_pb2.DataResponse( release=ray_client_pb2.ReleaseResponse(ok=released)) elif req_type == "connection_info": - with self._clients_lock: - cur_num_clients = self._num_clients - info = ray_client_pb2.ConnectionInfoResponse( - num_clients=cur_num_clients, - python_version="{}.{}.{}".format( - sys.version_info[0], sys.version_info[1], - sys.version_info[2]), - ray_version=ray.__version__, - ray_commit=ray.__commit__) - resp = ray_client_pb2.DataResponse(connection_info=info) + resp = ray_client_pb2.DataResponse( + connection_info=self._build_connection_response()) else: raise Exception(f"Unreachable code: Request type " f"{req_type} not handled in Datapath") @@ -72,3 +64,13 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): self.basic_service.release_all(client_id) with self._clients_lock: self._num_clients -= 1 + + def _build_connection_response(self): + with self._clients_lock: + cur_num_clients = self._num_clients + return ray_client_pb2.ConnectionInfoResponse( + num_clients=cur_num_clients, + python_version="{}.{}.{}".format( + sys.version_info[0], sys.version_info[1], sys.version_info[2]), + ray_version=ray.__version__, + ray_commit=ray.__commit__) diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index 19a192337..6a7badaf7 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -3,6 +3,7 @@ from concurrent import futures import grpc import base64 from collections import defaultdict +from dataclasses import dataclass from typing import Any from typing import Dict @@ -407,13 +408,18 @@ def decode_options( return opts -_current_servicer: Optional[RayletServicer] = None +@dataclass +class ClientServerHandle: + """Holds the handles to the registered gRPC servicers and their server.""" + task_servicer: RayletServicer + data_servicer: DataServicer + logs_servicer: LogstreamServicer + grpc_server: grpc.Server - -# Used by tests to peek inside the servicer -def _get_current_servicer(): - global _current_servicer - return _current_servicer + # Add a hook for all the cases that previously + # expected simply a gRPC server + def __getattr__(self, attr): + return getattr(self.grpc_server, attr) def serve(connection_str): @@ -421,8 +427,6 @@ def serve(connection_str): task_servicer = RayletServicer() data_servicer = DataServicer(task_servicer) logs_servicer = LogstreamServicer() - global _current_servicer - _current_servicer = task_servicer ray_client_pb2_grpc.add_RayletDriverServicer_to_server( task_servicer, server) ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server( @@ -430,16 +434,22 @@ def serve(connection_str): ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server( logs_servicer, server) server.add_insecure_port(connection_str) + current_handle = ClientServerHandle( + task_servicer=task_servicer, + data_servicer=data_servicer, + logs_servicer=logs_servicer, + grpc_server=server, + ) server.start() - return server + return current_handle def init_and_serve(connection_str, *args, **kwargs): with disable_client_hook(): # Disable client mode inside the worker's environment info = ray.init(*args, **kwargs) - server = serve(connection_str) - return (server, info) + server_handle = serve(connection_str) + return (server_handle, info) def shutdown_with_server(server, _exiting_interpreter=False):