mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[ray_client]: Monitor client stream errors (#13386)
This commit is contained in:
@@ -364,5 +364,32 @@ def test_startup_retry(ray_start_regular_shared):
|
||||
ray_client._inside_client_test = False
|
||||
|
||||
|
||||
def test_dataclient_server_drop(ray_start_regular_shared):
|
||||
from ray.util.client import ray as ray_client
|
||||
ray_client._inside_client_test = True
|
||||
|
||||
@ray_client.remote
|
||||
def f(x):
|
||||
time.sleep(4)
|
||||
return x
|
||||
|
||||
def stop_server(server):
|
||||
time.sleep(2)
|
||||
server.stop(0)
|
||||
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
ray_client.connect("localhost:50051")
|
||||
thread = threading.Thread(target=stop_server, args=(server, ))
|
||||
thread.start()
|
||||
x = f.remote(2)
|
||||
with pytest.raises(ConnectionError):
|
||||
_ = ray_client.get(x)
|
||||
thread.join()
|
||||
ray_client.disconnect()
|
||||
ray_client._inside_client_test = False
|
||||
# Wait for f(x) to finish before ray.shutdown() in the fixture
|
||||
time.sleep(3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -89,7 +89,9 @@ class RayAPIStub:
|
||||
return getattr(self.api, key)
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return self.client_worker is not None
|
||||
if self.client_worker is None:
|
||||
return False
|
||||
return self.client_worker.is_connected()
|
||||
|
||||
def init(self, *args, **kwargs):
|
||||
if self._server is not None:
|
||||
|
||||
@@ -37,6 +37,7 @@ class DataClient:
|
||||
self._req_id = 0
|
||||
self._client_id = client_id
|
||||
self._metadata = metadata
|
||||
self._in_shutdown = False
|
||||
self.data_thread.start()
|
||||
|
||||
def _next_id(self) -> int:
|
||||
@@ -67,9 +68,19 @@ class DataClient:
|
||||
self.ready_data[response.req_id] = response
|
||||
self.cv.notify_all()
|
||||
except grpc.RpcError as e:
|
||||
if grpc.StatusCode.CANCELLED == e.code():
|
||||
with self.cv:
|
||||
self._in_shutdown = True
|
||||
self.cv.notify_all()
|
||||
if e.code() == grpc.StatusCode.CANCELLED:
|
||||
# Gracefully shutting down
|
||||
logger.info("Cancelling data channel")
|
||||
elif e.code() == grpc.StatusCode.UNAVAILABLE:
|
||||
# TODO(barakmich): The server may have
|
||||
# dropped. In theory, we can retry, as per
|
||||
# https://grpc.github.io/grpc/core/md_doc_statuscodes.html but
|
||||
# in practice we may need to think about the correct semantics
|
||||
# here.
|
||||
logger.info("Server disconnected from data channel")
|
||||
else:
|
||||
logger.error(
|
||||
f"Got Error from data channel -- shutting down: {e}")
|
||||
@@ -88,7 +99,11 @@ class DataClient:
|
||||
self.request_queue.put(req)
|
||||
data = None
|
||||
with self.cv:
|
||||
self.cv.wait_for(lambda: req_id in self.ready_data)
|
||||
self.cv.wait_for(
|
||||
lambda: req_id in self.ready_data or self._in_shutdown)
|
||||
if self._in_shutdown:
|
||||
raise ConnectionError(
|
||||
f"cannot send request {req}: data channel shutting down")
|
||||
data = self.ready_data[req_id]
|
||||
del self.ready_data[req_id]
|
||||
return data
|
||||
|
||||
@@ -44,8 +44,18 @@ class LogstreamClient:
|
||||
self.stdstream(level=record.level, msg=record.msg)
|
||||
self.log(level=record.level, msg=record.msg)
|
||||
except grpc.RpcError as e:
|
||||
if grpc.StatusCode.CANCELLED != e.code():
|
||||
# Not just shutting down normally
|
||||
if e.code() == grpc.StatusCode.CANCELLED:
|
||||
# Graceful shutdown. We've cancelled our own connection.
|
||||
logger.info("Cancelling logs channel")
|
||||
elif e.code() == grpc.StatusCode.UNAVAILABLE:
|
||||
# TODO(barakmich): The server may have
|
||||
# dropped. In theory, we can retry, as per
|
||||
# https://grpc.github.io/grpc/core/md_doc_statuscodes.html but
|
||||
# in practice we may need to think about the correct semantics
|
||||
# here.
|
||||
logger.info("Server disconnected from logs channel")
|
||||
else:
|
||||
# Some other, unhandled, gRPC error
|
||||
logger.error(
|
||||
f"Got Error from logger channel -- shutting down: {e}")
|
||||
raise e
|
||||
|
||||
@@ -60,6 +60,7 @@ class Worker:
|
||||
"""
|
||||
self.metadata = metadata if metadata else []
|
||||
self.channel = None
|
||||
self._conn_state = grpc.ChannelConnectivity.IDLE
|
||||
self._client_id = make_client_id()
|
||||
if secure:
|
||||
credentials = grpc.ssl_channel_credentials()
|
||||
@@ -67,6 +68,8 @@ class Worker:
|
||||
else:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
|
||||
self.channel.subscribe(self._on_channel_state_change)
|
||||
|
||||
# Retry the connection until the channel responds to something
|
||||
# looking like a gRPC connection, though it may be a proxy.
|
||||
conn_attempts = 0
|
||||
@@ -128,6 +131,10 @@ class Worker:
|
||||
self.log_client.set_logstream_level(logging.INFO)
|
||||
self.closed = False
|
||||
|
||||
def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity):
|
||||
logger.debug(f"client gRPC channel state change: {conn_state}")
|
||||
self._conn_state = conn_state
|
||||
|
||||
def connection_info(self):
|
||||
try:
|
||||
data = self.data_client.ConnectionInfo()
|
||||
@@ -357,6 +364,9 @@ class Worker:
|
||||
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
|
||||
return False
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return self._conn_state == grpc.ChannelConnectivity.READY
|
||||
|
||||
|
||||
def make_client_id() -> str:
|
||||
id = uuid.uuid4()
|
||||
|
||||
Reference in New Issue
Block a user