diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index dc5de2470..30d6faccb 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -364,5 +364,32 @@ def test_startup_retry(ray_start_regular_shared): ray_client._inside_client_test = False +def test_dataclient_server_drop(ray_start_regular_shared): + from ray.util.client import ray as ray_client + ray_client._inside_client_test = True + + @ray_client.remote + def f(x): + time.sleep(4) + return x + + def stop_server(server): + time.sleep(2) + server.stop(0) + + server = ray_client_server.serve("localhost:50051") + ray_client.connect("localhost:50051") + thread = threading.Thread(target=stop_server, args=(server, )) + thread.start() + x = f.remote(2) + with pytest.raises(ConnectionError): + _ = ray_client.get(x) + thread.join() + ray_client.disconnect() + ray_client._inside_client_test = False + # Wait for f(x) to finish before ray.shutdown() in the fixture + time.sleep(3) + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py index 02aab93ff..1c28dc53c 100644 --- a/python/ray/util/client/__init__.py +++ b/python/ray/util/client/__init__.py @@ -89,7 +89,9 @@ class RayAPIStub: return getattr(self.api, key) def is_connected(self) -> bool: - return self.client_worker is not None + if self.client_worker is None: + return False + return self.client_worker.is_connected() def init(self, *args, **kwargs): if self._server is not None: diff --git a/python/ray/util/client/dataclient.py b/python/ray/util/client/dataclient.py index 6e29ea927..a0750b790 100644 --- a/python/ray/util/client/dataclient.py +++ b/python/ray/util/client/dataclient.py @@ -37,6 +37,7 @@ class DataClient: self._req_id = 0 self._client_id = client_id self._metadata = metadata + self._in_shutdown = False self.data_thread.start() def _next_id(self) -> int: @@ -67,9 +68,19 @@ class DataClient: self.ready_data[response.req_id] = response self.cv.notify_all() except grpc.RpcError as e: - if grpc.StatusCode.CANCELLED == e.code(): + with self.cv: + self._in_shutdown = True + self.cv.notify_all() + if e.code() == grpc.StatusCode.CANCELLED: # Gracefully shutting down logger.info("Cancelling data channel") + elif e.code() == grpc.StatusCode.UNAVAILABLE: + # TODO(barakmich): The server may have + # dropped. In theory, we can retry, as per + # https://grpc.github.io/grpc/core/md_doc_statuscodes.html but + # in practice we may need to think about the correct semantics + # here. + logger.info("Server disconnected from data channel") else: logger.error( f"Got Error from data channel -- shutting down: {e}") @@ -88,7 +99,11 @@ class DataClient: self.request_queue.put(req) data = None with self.cv: - self.cv.wait_for(lambda: req_id in self.ready_data) + self.cv.wait_for( + lambda: req_id in self.ready_data or self._in_shutdown) + if self._in_shutdown: + raise ConnectionError( + f"cannot send request {req}: data channel shutting down") data = self.ready_data[req_id] del self.ready_data[req_id] return data diff --git a/python/ray/util/client/logsclient.py b/python/ray/util/client/logsclient.py index 0e4d02846..f7902024d 100644 --- a/python/ray/util/client/logsclient.py +++ b/python/ray/util/client/logsclient.py @@ -44,8 +44,18 @@ class LogstreamClient: self.stdstream(level=record.level, msg=record.msg) self.log(level=record.level, msg=record.msg) except grpc.RpcError as e: - if grpc.StatusCode.CANCELLED != e.code(): - # Not just shutting down normally + if e.code() == grpc.StatusCode.CANCELLED: + # Graceful shutdown. We've cancelled our own connection. + logger.info("Cancelling logs channel") + elif e.code() == grpc.StatusCode.UNAVAILABLE: + # TODO(barakmich): The server may have + # dropped. In theory, we can retry, as per + # https://grpc.github.io/grpc/core/md_doc_statuscodes.html but + # in practice we may need to think about the correct semantics + # here. + logger.info("Server disconnected from logs channel") + else: + # Some other, unhandled, gRPC error logger.error( f"Got Error from logger channel -- shutting down: {e}") raise e diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index d62173be7..9f2f189c6 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -60,6 +60,7 @@ class Worker: """ self.metadata = metadata if metadata else [] self.channel = None + self._conn_state = grpc.ChannelConnectivity.IDLE self._client_id = make_client_id() if secure: credentials = grpc.ssl_channel_credentials() @@ -67,6 +68,8 @@ class Worker: else: self.channel = grpc.insecure_channel(conn_str) + self.channel.subscribe(self._on_channel_state_change) + # Retry the connection until the channel responds to something # looking like a gRPC connection, though it may be a proxy. conn_attempts = 0 @@ -128,6 +131,10 @@ class Worker: self.log_client.set_logstream_level(logging.INFO) self.closed = False + def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity): + logger.debug(f"client gRPC channel state change: {conn_state}") + self._conn_state = conn_state + def connection_info(self): try: data = self.data_client.ConnectionInfo() @@ -357,6 +364,9 @@ class Worker: ray_client_pb2.ClusterInfoType.IS_INITIALIZED) return False + def is_connected(self) -> bool: + return self._conn_state == grpc.ChannelConnectivity.READY + def make_client_id() -> str: id = uuid.uuid4()