mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[ray_client]: Add more retry logic (#13478)
This commit is contained in:
@@ -79,6 +79,7 @@ py_test_module_list(
|
||||
"test_asyncio.py",
|
||||
"test_autoscaler.py",
|
||||
"test_autoscaler_yaml.py",
|
||||
"test_client_init.py",
|
||||
"test_client_metadata.py",
|
||||
"test_client.py",
|
||||
"test_client_references.py",
|
||||
|
||||
@@ -2,42 +2,13 @@ import pytest
|
||||
import time
|
||||
import sys
|
||||
import logging
|
||||
import threading
|
||||
|
||||
import ray.util.client.server.server as ray_client_server
|
||||
from ray.util.client import RayAPIStub
|
||||
from ray.util.client.common import ClientObjectRef
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
|
||||
def test_num_clients(shutdown_only):
|
||||
# Tests num clients reporting; useful if you want to build an app that
|
||||
# load balances clients between Ray client servers.
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
try:
|
||||
api1 = RayAPIStub()
|
||||
info1 = api1.connect("localhost:50051")
|
||||
assert info1["num_clients"] == 1, info1
|
||||
api2 = RayAPIStub()
|
||||
info2 = api2.connect("localhost:50051")
|
||||
assert info2["num_clients"] == 2, info2
|
||||
|
||||
# Disconnect the first two clients.
|
||||
api1.disconnect()
|
||||
api2.disconnect()
|
||||
time.sleep(1)
|
||||
|
||||
api3 = RayAPIStub()
|
||||
info3 = api3.connect("localhost:50051")
|
||||
assert info3["num_clients"] == 1, info3
|
||||
|
||||
# Check info contains ray and python version.
|
||||
assert isinstance(info3["ray_version"], str), info3
|
||||
assert isinstance(info3["ray_commit"], str), info3
|
||||
assert isinstance(info3["python_version"], str), info3
|
||||
finally:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
|
||||
def test_real_ray_fallback(ray_start_regular_shared):
|
||||
with ray_start_client_server() as ray:
|
||||
@@ -373,5 +344,25 @@ def test_internal_kv(ray_start_regular_shared):
|
||||
assert ray._internal_kv_get("apple") == b""
|
||||
|
||||
|
||||
def test_startup_retry(ray_start_regular_shared):
|
||||
from ray.util.client import ray as ray_client
|
||||
ray_client._inside_client_test = True
|
||||
|
||||
with pytest.raises(ConnectionError):
|
||||
ray_client.connect("localhost:50051", connection_retries=1)
|
||||
|
||||
def run_client():
|
||||
ray_client.connect("localhost:50051")
|
||||
ray_client.disconnect()
|
||||
|
||||
thread = threading.Thread(target=run_client, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(3)
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
thread.join()
|
||||
server.stop(0)
|
||||
ray_client._inside_client_test = False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Client tests that run their own init (as with init_and_serve) live here"""
|
||||
import time
|
||||
|
||||
import ray.util.client.server.server as ray_client_server
|
||||
|
||||
from ray.util.client import RayAPIStub
|
||||
|
||||
|
||||
def test_num_clients():
|
||||
# Tests num clients reporting; useful if you want to build an app that
|
||||
# load balances clients between Ray client servers.
|
||||
server, _ = ray_client_server.init_and_serve("localhost:50051")
|
||||
try:
|
||||
api1 = RayAPIStub()
|
||||
info1 = api1.connect("localhost:50051")
|
||||
assert info1["num_clients"] == 1, info1
|
||||
api2 = RayAPIStub()
|
||||
info2 = api2.connect("localhost:50051")
|
||||
assert info2["num_clients"] == 2, info2
|
||||
|
||||
# Disconnect the first two clients.
|
||||
api1.disconnect()
|
||||
api2.disconnect()
|
||||
time.sleep(1)
|
||||
|
||||
api3 = RayAPIStub()
|
||||
info3 = api3.connect("localhost:50051")
|
||||
assert info3["num_clients"] == 1, info3
|
||||
|
||||
# Check info contains ray and python version.
|
||||
assert isinstance(info3["ray_version"], str), info3
|
||||
assert isinstance(info3["ray_commit"], str), info3
|
||||
assert isinstance(info3["python_version"], str), info3
|
||||
api3.disconnect()
|
||||
finally:
|
||||
ray_client_server.shutdown_with_server(server)
|
||||
time.sleep(2)
|
||||
@@ -5,6 +5,7 @@ to the server.
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
@@ -33,6 +34,13 @@ INITIAL_TIMEOUT_SEC = 5
|
||||
MAX_TIMEOUT_SEC = 30
|
||||
|
||||
|
||||
def backoff(timeout: int) -> int:
|
||||
timeout = timeout + 5
|
||||
if timeout > MAX_TIMEOUT_SEC:
|
||||
timeout = MAX_TIMEOUT_SEC
|
||||
return timeout
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(self,
|
||||
conn_str: str = "",
|
||||
@@ -59,23 +67,59 @@ class Worker:
|
||||
else:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
|
||||
# Retry the connection until the channel responds to something
|
||||
# looking like a gRPC connection, though it may be a proxy.
|
||||
conn_attempts = 0
|
||||
timeout = INITIAL_TIMEOUT_SEC
|
||||
while conn_attempts < connection_retries + 1:
|
||||
ray_ready = False
|
||||
while conn_attempts < max(connection_retries, 1):
|
||||
conn_attempts += 1
|
||||
try:
|
||||
# Let gRPC wait for us to see if the channel becomes ready.
|
||||
# If it throws, we couldn't connect.
|
||||
grpc.channel_ready_future(self.channel).result(timeout=timeout)
|
||||
break
|
||||
# The HTTP2 channel is ready. Wrap the channel with the
|
||||
# RayletDriverStub, allowing for unary requests.
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(
|
||||
self.channel)
|
||||
# Now the HTTP2 channel is ready, or proxied, but the
|
||||
# servicer may not be ready. Call is_initialized() and if
|
||||
# it throws, the servicer is not ready. On success, the
|
||||
# `ray_ready` result is checked.
|
||||
ray_ready = self.is_initialized()
|
||||
if ray_ready:
|
||||
# Ray is ready! Break out of the retry loop
|
||||
break
|
||||
# Ray is not ready yet, wait a timeout
|
||||
time.sleep(timeout)
|
||||
except grpc.FutureTimeoutError:
|
||||
if conn_attempts >= connection_retries:
|
||||
raise ConnectionError("ray client connection timeout")
|
||||
logger.info(f"Couldn't connect in {timeout} seconds, retrying")
|
||||
timeout = timeout + 5
|
||||
if timeout > MAX_TIMEOUT_SEC:
|
||||
timeout = MAX_TIMEOUT_SEC
|
||||
logger.info(
|
||||
f"Couldn't connect channel in {timeout} seconds, retrying")
|
||||
# Note that channel_ready_future constitutes its own timeout,
|
||||
# which is why we do not sleep here.
|
||||
except grpc.RpcError as e:
|
||||
if e.code() == grpc.StatusCode.UNAVAILABLE:
|
||||
# UNAVAILABLE is gRPC's retryable error,
|
||||
# so we do that here.
|
||||
logger.info("Ray client server unavailable, "
|
||||
f"retrying in {timeout}s...")
|
||||
logger.debug(f"Received when checking init: {e.details()}")
|
||||
# Ray is not ready yet, wait a timeout
|
||||
time.sleep(timeout)
|
||||
else:
|
||||
# Any other gRPC error gets a reraise
|
||||
raise e
|
||||
# Fallthrough, backoff, and retry at the top of the loop
|
||||
logger.info("Waiting for Ray to become ready on the server, "
|
||||
f"retry in {timeout}s...")
|
||||
timeout = backoff(timeout)
|
||||
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
# If we made it through the loop without ray_ready it means we've used
|
||||
# up our retries and should error back to the user.
|
||||
if not ray_ready:
|
||||
raise ConnectionError("ray client connection timeout")
|
||||
|
||||
# Initialize the streams to finish protocol negotiation.
|
||||
self.data_client = DataClient(self.channel, self._client_id,
|
||||
self.metadata)
|
||||
self.reference_count: Dict[bytes, int] = defaultdict(int)
|
||||
|
||||
Reference in New Issue
Block a user