diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 7f4c61bb1..8fe8b21c3 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -79,6 +79,7 @@ py_test_module_list( "test_asyncio.py", "test_autoscaler.py", "test_autoscaler_yaml.py", + "test_client_init.py", "test_client_metadata.py", "test_client.py", "test_client_references.py", diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index 21bb807fd..dc5de2470 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -2,42 +2,13 @@ import pytest import time import sys import logging +import threading import ray.util.client.server.server as ray_client_server -from ray.util.client import RayAPIStub from ray.util.client.common import ClientObjectRef from ray.util.client.ray_client_helpers import ray_start_client_server -def test_num_clients(shutdown_only): - # Tests num clients reporting; useful if you want to build an app that - # load balances clients between Ray client servers. - server = ray_client_server.serve("localhost:50051") - try: - api1 = RayAPIStub() - info1 = api1.connect("localhost:50051") - assert info1["num_clients"] == 1, info1 - api2 = RayAPIStub() - info2 = api2.connect("localhost:50051") - assert info2["num_clients"] == 2, info2 - - # Disconnect the first two clients. - api1.disconnect() - api2.disconnect() - time.sleep(1) - - api3 = RayAPIStub() - info3 = api3.connect("localhost:50051") - assert info3["num_clients"] == 1, info3 - - # Check info contains ray and python version. - assert isinstance(info3["ray_version"], str), info3 - assert isinstance(info3["ray_commit"], str), info3 - assert isinstance(info3["python_version"], str), info3 - finally: - server.stop(0) - - @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") def test_real_ray_fallback(ray_start_regular_shared): with ray_start_client_server() as ray: @@ -373,5 +344,25 @@ def test_internal_kv(ray_start_regular_shared): assert ray._internal_kv_get("apple") == b"" +def test_startup_retry(ray_start_regular_shared): + from ray.util.client import ray as ray_client + ray_client._inside_client_test = True + + with pytest.raises(ConnectionError): + ray_client.connect("localhost:50051", connection_retries=1) + + def run_client(): + ray_client.connect("localhost:50051") + ray_client.disconnect() + + thread = threading.Thread(target=run_client, daemon=True) + thread.start() + time.sleep(3) + server = ray_client_server.serve("localhost:50051") + thread.join() + server.stop(0) + ray_client._inside_client_test = False + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_client_init.py b/python/ray/tests/test_client_init.py new file mode 100644 index 000000000..1949fe3fd --- /dev/null +++ b/python/ray/tests/test_client_init.py @@ -0,0 +1,37 @@ +"""Client tests that run their own init (as with init_and_serve) live here""" +import time + +import ray.util.client.server.server as ray_client_server + +from ray.util.client import RayAPIStub + + +def test_num_clients(): + # Tests num clients reporting; useful if you want to build an app that + # load balances clients between Ray client servers. + server, _ = ray_client_server.init_and_serve("localhost:50051") + try: + api1 = RayAPIStub() + info1 = api1.connect("localhost:50051") + assert info1["num_clients"] == 1, info1 + api2 = RayAPIStub() + info2 = api2.connect("localhost:50051") + assert info2["num_clients"] == 2, info2 + + # Disconnect the first two clients. + api1.disconnect() + api2.disconnect() + time.sleep(1) + + api3 = RayAPIStub() + info3 = api3.connect("localhost:50051") + assert info3["num_clients"] == 1, info3 + + # Check info contains ray and python version. + assert isinstance(info3["ray_version"], str), info3 + assert isinstance(info3["ray_commit"], str), info3 + assert isinstance(info3["python_version"], str), info3 + api3.disconnect() + finally: + ray_client_server.shutdown_with_server(server) + time.sleep(2) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 3c6401fda..d62173be7 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -5,6 +5,7 @@ to the server. import base64 import json import logging +import time import uuid from collections import defaultdict from typing import Any @@ -33,6 +34,13 @@ INITIAL_TIMEOUT_SEC = 5 MAX_TIMEOUT_SEC = 30 +def backoff(timeout: int) -> int: + timeout = timeout + 5 + if timeout > MAX_TIMEOUT_SEC: + timeout = MAX_TIMEOUT_SEC + return timeout + + class Worker: def __init__(self, conn_str: str = "", @@ -59,23 +67,59 @@ class Worker: else: self.channel = grpc.insecure_channel(conn_str) + # Retry the connection until the channel responds to something + # looking like a gRPC connection, though it may be a proxy. conn_attempts = 0 timeout = INITIAL_TIMEOUT_SEC - while conn_attempts < connection_retries + 1: + ray_ready = False + while conn_attempts < max(connection_retries, 1): conn_attempts += 1 try: + # Let gRPC wait for us to see if the channel becomes ready. + # If it throws, we couldn't connect. grpc.channel_ready_future(self.channel).result(timeout=timeout) - break + # The HTTP2 channel is ready. Wrap the channel with the + # RayletDriverStub, allowing for unary requests. + self.server = ray_client_pb2_grpc.RayletDriverStub( + self.channel) + # Now the HTTP2 channel is ready, or proxied, but the + # servicer may not be ready. Call is_initialized() and if + # it throws, the servicer is not ready. On success, the + # `ray_ready` result is checked. + ray_ready = self.is_initialized() + if ray_ready: + # Ray is ready! Break out of the retry loop + break + # Ray is not ready yet, wait a timeout + time.sleep(timeout) except grpc.FutureTimeoutError: - if conn_attempts >= connection_retries: - raise ConnectionError("ray client connection timeout") - logger.info(f"Couldn't connect in {timeout} seconds, retrying") - timeout = timeout + 5 - if timeout > MAX_TIMEOUT_SEC: - timeout = MAX_TIMEOUT_SEC + logger.info( + f"Couldn't connect channel in {timeout} seconds, retrying") + # Note that channel_ready_future constitutes its own timeout, + # which is why we do not sleep here. + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + # UNAVAILABLE is gRPC's retryable error, + # so we do that here. + logger.info("Ray client server unavailable, " + f"retrying in {timeout}s...") + logger.debug(f"Received when checking init: {e.details()}") + # Ray is not ready yet, wait a timeout + time.sleep(timeout) + else: + # Any other gRPC error gets a reraise + raise e + # Fallthrough, backoff, and retry at the top of the loop + logger.info("Waiting for Ray to become ready on the server, " + f"retry in {timeout}s...") + timeout = backoff(timeout) - self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) + # If we made it through the loop without ray_ready it means we've used + # up our retries and should error back to the user. + if not ray_ready: + raise ConnectionError("ray client connection timeout") + # Initialize the streams to finish protocol negotiation. self.data_client = DataClient(self.channel, self._client_id, self.metadata) self.reference_count: Dict[bytes, int] = defaultdict(int)