From 3a230fa1a439a7c6b56099d450faf5702ac5b4ae Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sun, 7 Feb 2021 13:11:38 -0800 Subject: [PATCH] [ray_client] close ray connection upon client deactivation (#13919) --- ci/travis/ci.sh | 1 + python/ray/tests/BUILD | 4 +- python/ray/tests/test_client_init.py | 228 ++++++++++-------- python/ray/tests/test_job.py | 4 +- python/ray/util/client/server/dataservicer.py | 15 +- python/ray/util/client/server/server.py | 51 +++- python/ray/util/client/worker.py | 31 ++- 7 files changed, 198 insertions(+), 136 deletions(-) diff --git a/ci/travis/ci.sh b/ci/travis/ci.sh index 2d381ba24..61b74b082 100755 --- a/ci/travis/ci.sh +++ b/ci/travis/ci.sh @@ -152,6 +152,7 @@ test_python() { -python/ray/tests:test_basic_3 # timeout -python/ray/tests:test_basic_3_client_mode -python/ray/tests:test_cli + -python/ray/tests:test_client_init # timeout -python/ray/tests:test_failure -python/ray/tests:test_global_gc -python/ray/tests:test_job diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 4ef81d504..2572c50c2 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -26,6 +26,8 @@ py_test_module_list( "test_basic_3.py", "test_cancel.py", "test_cli.py", + "test_client.py", + "test_client_init.py", "test_component_failures_2.py", "test_component_failures_3.py", "test_error_ray_not_initialized.py", @@ -80,9 +82,7 @@ py_test_module_list( "test_asyncio.py", "test_autoscaler.py", "test_autoscaler_yaml.py", - "test_client_init.py", "test_client_metadata.py", - "test_client.py", "test_client_references.py", "test_client_terminate.py", "test_command_runner.py", diff --git a/python/ray/tests/test_client_init.py b/python/ray/tests/test_client_init.py index 6b6ce8a42..8053ab577 100644 --- a/python/ray/tests/test_client_init.py +++ b/python/ray/tests/test_client_init.py @@ -38,130 +38,146 @@ class C: return self.val -def test_basic_preregister(): +@pytest.fixture +def init_and_serve(): + server_handle, _ = ray_client_server.init_and_serve("localhost:50051") + yield server_handle + ray_client_server.shutdown_with_server(server_handle.grpc_server) + time.sleep(2) + + +@pytest.fixture +def init_and_serve_lazy(): + cluster = ray.cluster_utils.Cluster() + cluster.add_node(num_cpus=1, num_gpus=0) + address = cluster.address + + def connect(): + ray.init(address=address) + + server_handle = ray_client_server.serve("localhost:50051", connect) + yield server_handle + ray_client_server.shutdown_with_server(server_handle.grpc_server) + time.sleep(2) + + +def test_basic_preregister(init_and_serve): from ray.util.client import ray - server, _ = ray_client_server.init_and_serve("localhost:50051") - try: - ray.connect("localhost:50051") - val = ray.get(hello_world.remote()) - print(val) - assert val >= 20 - assert val <= 200 - c = C.remote(3) - x = c.double.remote() - y = c.double.remote() - ray.wait([x, y]) - val = ray.get(c.get.remote()) - assert val == 12 - finally: - ray.disconnect() - ray_client_server.shutdown_with_server(server) - time.sleep(2) + ray.connect("localhost:50051") + val = ray.get(hello_world.remote()) + print(val) + assert val >= 20 + assert val <= 200 + c = C.remote(3) + x = c.double.remote() + y = c.double.remote() + ray.wait([x, y]) + val = ray.get(c.get.remote()) + assert val == 12 + ray.disconnect() -def test_num_clients(): +def test_num_clients(init_and_serve_lazy): # Tests num clients reporting; useful if you want to build an app that # load balances clients between Ray client servers. - server_handle, _ = ray_client_server.init_and_serve("localhost:50051") - server = server_handle.grpc_server - try: - api1 = RayAPIStub() - info1 = api1.connect("localhost:50051") - assert info1["num_clients"] == 1, info1 - api2 = RayAPIStub() - info2 = api2.connect("localhost:50051") - assert info2["num_clients"] == 2, info2 - # Disconnect the first two clients. - api1.disconnect() - api2.disconnect() - time.sleep(1) + def get_job_id(api): + return api.get_runtime_context().worker.current_job_id - api3 = RayAPIStub() - info3 = api3.connect("localhost:50051") - assert info3["num_clients"] == 1, info3 + api1 = RayAPIStub() + info1 = api1.connect("localhost:50051") + job_id_1 = get_job_id(api1) + assert info1["num_clients"] == 1, info1 + api2 = RayAPIStub() + info2 = api2.connect("localhost:50051") + job_id_2 = get_job_id(api2) + assert info2["num_clients"] == 2, info2 - # Check info contains ray and python version. - assert isinstance(info3["ray_version"], str), info3 - assert isinstance(info3["ray_commit"], str), info3 - assert isinstance(info3["python_version"], str), info3 - assert isinstance(info3["protocol_version"], str), info3 - api3.disconnect() - finally: - ray_client_server.shutdown_with_server(server) - time.sleep(2) + assert job_id_1 == job_id_2 + + # Disconnect the first two clients. + api1.disconnect() + api2.disconnect() + time.sleep(1) + + api3 = RayAPIStub() + info3 = api3.connect("localhost:50051") + job_id_3 = get_job_id(api3) + assert info3["num_clients"] == 1, info3 + assert job_id_1 != job_id_3 + + # Check info contains ray and python version. + assert isinstance(info3["ray_version"], str), info3 + assert isinstance(info3["ray_commit"], str), info3 + assert isinstance(info3["python_version"], str), info3 + assert isinstance(info3["protocol_version"], str), info3 + api3.disconnect() -def test_python_version(): +def test_python_version(init_and_serve): + server_handle = init_and_serve + ray = RayAPIStub() + info1 = ray.connect("localhost:50051") + assert info1["python_version"] == ".".join( + [str(x) for x in list(sys.version_info)[:3]]) + ray.disconnect() + time.sleep(1) - server_handle, _ = ray_client_server.init_and_serve("localhost:50051") - try: - ray = RayAPIStub() - info1 = ray.connect("localhost:50051") - assert info1["python_version"] == ".".join( - [str(x) for x in list(sys.version_info)[:3]]) - ray.disconnect() - time.sleep(1) + def mock_connection_response(): + return ray_client_pb2.ConnectionInfoResponse( + num_clients=1, + python_version="2.7.12", + ray_version="", + ray_commit="", + protocol_version=CURRENT_PROTOCOL_VERSION, + ) - def mock_connection_response(): - return ray_client_pb2.ConnectionInfoResponse( - num_clients=1, - python_version="2.7.12", - ray_version="", - ray_commit="", - protocol_version=CURRENT_PROTOCOL_VERSION, - ) + # inject mock connection function + server_handle.data_servicer._build_connection_response = \ + mock_connection_response - # inject mock connection function - server_handle.data_servicer._build_connection_response = \ - mock_connection_response + ray = RayAPIStub() + with pytest.raises(RuntimeError): + _ = ray.connect("localhost:50051") - ray = RayAPIStub() - with pytest.raises(RuntimeError): - _ = ray.connect("localhost:50051") - - ray = RayAPIStub() - info3 = ray.connect("localhost:50051", ignore_version=True) - assert info3["num_clients"] == 1, info3 - ray.disconnect() - finally: - ray_client_server.shutdown_with_server(server_handle.grpc_server) - time.sleep(2) + ray = RayAPIStub() + info3 = ray.connect("localhost:50051", ignore_version=True) + assert info3["num_clients"] == 1, info3 + ray.disconnect() -def test_protocol_version(): +def test_protocol_version(init_and_serve): + server_handle = init_and_serve + ray = RayAPIStub() + info1 = ray.connect("localhost:50051") + local_py_version = ".".join([str(x) for x in list(sys.version_info)[:3]]) + assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1 + ray.disconnect() + time.sleep(1) - server_handle, _ = ray_client_server.init_and_serve("localhost:50051") - try: - ray = RayAPIStub() - info1 = ray.connect("localhost:50051") - local_py_version = ".".join( - [str(x) for x in list(sys.version_info)[:3]]) - assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1 - ray.disconnect() - time.sleep(1) + def mock_connection_response(): + return ray_client_pb2.ConnectionInfoResponse( + num_clients=1, + python_version=local_py_version, + ray_version="", + ray_commit="", + protocol_version="2050-01-01", # from the future + ) - def mock_connection_response(): - return ray_client_pb2.ConnectionInfoResponse( - num_clients=1, - python_version=local_py_version, - ray_version="", - ray_commit="", - protocol_version="2050-01-01", # from the future - ) + # inject mock connection function + server_handle.data_servicer._build_connection_response = \ + mock_connection_response - # inject mock connection function - server_handle.data_servicer._build_connection_response = \ - mock_connection_response + ray = RayAPIStub() + with pytest.raises(RuntimeError): + _ = ray.connect("localhost:50051") - ray = RayAPIStub() - with pytest.raises(RuntimeError): - _ = ray.connect("localhost:50051") + ray = RayAPIStub() + info3 = ray.connect("localhost:50051", ignore_version=True) + assert info3["num_clients"] == 1, info3 + ray.disconnect() - ray = RayAPIStub() - info3 = ray.connect("localhost:50051", ignore_version=True) - assert info3["num_clients"] == 1, info3 - ray.disconnect() - finally: - ray_client_server.shutdown_with_server(server_handle.grpc_server) - time.sleep(2) + +if __name__ == "__main__": + import pytest + sys.exit(pytest.main(["-v", __file__] + sys.argv[1:])) diff --git a/python/ray/tests/test_job.py b/python/ray/tests/test_job.py index cc7909dd8..15313d7ba 100644 --- a/python/ray/tests/test_job.py +++ b/python/ray/tests/test_job.py @@ -33,7 +33,7 @@ _ = Actor.remote() assert len(actor_table) == 1 job_table = ray.jobs() - assert len(job_table) == 3 # dash, ray client server + assert len(job_table) == 2 # dash # Kill the driver process. p.kill() @@ -79,7 +79,7 @@ ray.get(_.value.remote()) assert len(actor_table) == 1 job_table = ray.jobs() - assert len(job_table) == 3 # dash, ray client server + assert len(job_table) == 2 # dash # Kill the driver process. p.kill() diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py index 82ddc85c6..c9e345219 100644 --- a/python/ray/util/client/server/dataservicer.py +++ b/python/ray/util/client/server/dataservicer.py @@ -3,12 +3,13 @@ import logging import grpc import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from threading import Lock import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.util.client import CURRENT_PROTOCOL_VERSION +from ray._private.client_mode_hook import disable_client_hook if TYPE_CHECKING: from ray.util.client.server.server import RayletServicer @@ -17,10 +18,12 @@ logger = logging.getLogger(__name__) class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): - def __init__(self, basic_service: "RayletServicer"): + def __init__(self, basic_service: "RayletServicer", + ray_connect_handler: Callable): self.basic_service = basic_service self._clients_lock = Lock() self._num_clients = 0 # guarded by self._clients_lock + self.ray_connect_handler = ray_connect_handler def Datapath(self, request_iterator, context): metadata = {k: v for k, v in context.invocation_metadata()} @@ -31,6 +34,9 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): logger.info(f"New data connection from client {client_id}") try: with self._clients_lock: + with disable_client_hook(): + if self._num_clients == 0 and not ray.is_initialized(): + self.ray_connect_handler() self._num_clients += 1 for req in request_iterator: resp = None @@ -63,9 +69,14 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): finally: logger.info(f"Lost data connection from client {client_id}") self.basic_service.release_all(client_id) + with self._clients_lock: self._num_clients -= 1 + with disable_client_hook(): + if self._num_clients == 0: + ray.shutdown() + def _build_connection_response(self): with self._clients_lock: cur_num_clients = self._num_clients diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index 6a7badaf7..6e65c929b 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -422,10 +422,17 @@ class ClientServerHandle: return getattr(self.grpc_server, attr) -def serve(connection_str): +def serve(connection_str, ray_connect_handler=None): + def default_connect_handler(): + with disable_client_hook(): + if not ray.is_initialized(): + return ray.init() + + ray_connect_handler = ray_connect_handler or default_connect_handler server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) task_servicer = RayletServicer() - data_servicer = DataServicer(task_servicer) + data_servicer = DataServicer( + task_servicer, ray_connect_handler=ray_connect_handler) logs_servicer = LogstreamServicer() ray_client_pb2_grpc.add_RayletDriverServicer_to_server( task_servicer, server) @@ -448,7 +455,17 @@ def init_and_serve(connection_str, *args, **kwargs): with disable_client_hook(): # Disable client mode inside the worker's environment info = ray.init(*args, **kwargs) - server_handle = serve(connection_str) + + def ray_connect_handler(): + # Ray client will disconnect from ray when + # num_clients == 0. + if ray.is_initialized(): + return info + else: + return ray.init(*args, **kwargs) + + server_handle = serve( + connection_str, ray_connect_handler=ray_connect_handler) return (server_handle, info) @@ -458,6 +475,19 @@ def shutdown_with_server(server, _exiting_interpreter=False): ray.shutdown(_exiting_interpreter) +def create_ray_handler(redis_address, redis_password): + def ray_connect_handler(): + if redis_address: + if redis_password: + ray.init(address=redis_address, _redis_password=redis_password) + else: + ray.init(address=redis_address) + else: + ray.init() + + return ray_connect_handler + + def main(): import argparse parser = argparse.ArgumentParser() @@ -477,18 +507,13 @@ def main(): help="Password for connecting to Redis") args = parser.parse_args() logging.basicConfig(level="INFO") - if args.redis_address: - if args.redis_password: - ray.init( - address=args.redis_address, - _redis_password=args.redis_password) - else: - ray.init(address=args.redis_address) - else: - ray.init() + + ray_connect_handler = create_ray_handler(args.redis_address, + args.redis_password) + hostport = "%s:%d" % (args.host, args.port) logger.info(f"Starting Ray Client server on {hostport}") - server = serve(hostport) + server = serve(hostport, ray_connect_handler) try: while True: time.sleep(1000) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 3f04c80a4..db9a1cc63 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -68,6 +68,7 @@ class Worker: """ self.metadata = metadata if metadata else [] self.channel = None + self.server = None self._conn_state = grpc.ChannelConnectivity.IDLE self._client_id = make_client_id() self._converted: Dict[str, ClientStub] = {} @@ -83,7 +84,7 @@ class Worker: # looking like a gRPC connection, though it may be a proxy. conn_attempts = 0 timeout = INITIAL_TIMEOUT_SEC - ray_ready = False + service_ready = False while conn_attempts < max(connection_retries, 1): conn_attempts += 1 try: @@ -94,13 +95,8 @@ class Worker: # RayletDriverStub, allowing for unary requests. self.server = ray_client_pb2_grpc.RayletDriverStub( self.channel) - # Now the HTTP2 channel is ready, or proxied, but the - # servicer may not be ready. Call is_initialized() and if - # it throws, the servicer is not ready. On success, the - # `ray_ready` result is checked. - ray_ready = self.is_initialized() - if ray_ready: - # Ray is ready! Break out of the retry loop + service_ready = bool(self.ping_server()) + if service_ready: break # Ray is not ready yet, wait a timeout time.sleep(timeout) @@ -120,9 +116,10 @@ class Worker: f"retry in {timeout}s...") timeout = backoff(timeout) - # If we made it through the loop without ray_ready it means we've used - # up our retries and should error back to the user. - if not ray_ready: + # If we made it through the loop without service_ready + # it means we've used up our retries and + # should error back to the user. + if not service_ready: raise ConnectionError("ray client connection timeout") # Initialize the streams to finish protocol negotiation. @@ -377,6 +374,18 @@ class Worker: ray_client_pb2.ClusterInfoType.IS_INITIALIZED) return False + def ping_server(self) -> bool: + """Simple health check. + + Piggybacks the IS_INITIALIZED call to check if the server provides + an actual response. + """ + if self.server is not None: + result = self.get_cluster_info( + ray_client_pb2.ClusterInfoType.IS_INITIALIZED) + return result is not None + return False + def is_connected(self) -> bool: return self._conn_state == grpc.ChannelConnectivity.READY