[ray_client]: Add more retry logic (#13478)

This commit is contained in:
Barak Michener
2021-01-23 23:11:39 -08:00
committed by GitHub
parent b7dd7ddb52
commit e675e5b75a
4 changed files with 112 additions and 39 deletions
+1
View File
@@ -79,6 +79,7 @@ py_test_module_list(
"test_asyncio.py",
"test_autoscaler.py",
"test_autoscaler_yaml.py",
"test_client_init.py",
"test_client_metadata.py",
"test_client.py",
"test_client_references.py",
+21 -30
View File
@@ -2,42 +2,13 @@ import pytest
import time
import sys
import logging
import threading
import ray.util.client.server.server as ray_client_server
from ray.util.client import RayAPIStub
from ray.util.client.common import ClientObjectRef
from ray.util.client.ray_client_helpers import ray_start_client_server
def test_num_clients(shutdown_only):
    """Check the client count reported in the info returned by connect().

    Useful if you want to build an app that load balances clients between
    Ray client servers.
    """
    server = ray_client_server.serve("localhost:50051")
    try:
        api1 = RayAPIStub()
        info1 = api1.connect("localhost:50051")
        assert info1["num_clients"] == 1, info1

        api2 = RayAPIStub()
        info2 = api2.connect("localhost:50051")
        assert info2["num_clients"] == 2, info2

        # Disconnect the first two clients.
        api1.disconnect()
        api2.disconnect()
        time.sleep(1)

        api3 = RayAPIStub()
        info3 = api3.connect("localhost:50051")
        assert info3["num_clients"] == 1, info3

        # Check info contains ray and python version.
        assert isinstance(info3["ray_version"], str), info3
        assert isinstance(info3["ray_commit"], str), info3
        assert isinstance(info3["python_version"], str), info3

        # Disconnect the last client as well so that no client is still
        # connected when the server is stopped below.
        api3.disconnect()
    finally:
        server.stop(0)
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_real_ray_fallback(ray_start_regular_shared):
with ray_start_client_server() as ray:
@@ -373,5 +344,25 @@ def test_internal_kv(ray_start_regular_shared):
assert ray._internal_kv_get("apple") == b""
def test_startup_retry(ray_start_regular_shared):
    """Exercise the client's connection retry behaviour at startup.

    First connects with ``connection_retries=1`` while no server has been
    started by this test and expects a ConnectionError. Then starts a client
    in a background thread, brings the server up three seconds later, and
    waits for the client thread to finish.
    """
    from ray.util.client import ray as ray_client
    ray_client._inside_client_test = True
    server = None
    try:
        with pytest.raises(ConnectionError):
            ray_client.connect("localhost:50051", connection_retries=1)

        def run_client():
            ray_client.connect("localhost:50051")
            ray_client.disconnect()

        thread = threading.Thread(target=run_client, daemon=True)
        thread.start()
        # Give the client time to start retrying before the server exists.
        time.sleep(3)
        server = ray_client_server.serve("localhost:50051")
        thread.join()
    finally:
        # Always undo the global state this test touched, even when an
        # assertion above fails, so later tests are not affected.
        if server is not None:
            server.stop(0)
        ray_client._inside_client_test = False
if __name__ == "__main__":
    # Run this file's tests verbosely and use pytest's result as exit code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
+37
View File
@@ -0,0 +1,37 @@
"""Client tests that run their own init (as with init_and_serve) live here"""
import time
import ray.util.client.server.server as ray_client_server
from ray.util.client import RayAPIStub
def test_num_clients():
    """Check the client count reported in the info returned by connect().

    Useful if you want to build an app that load balances clients between
    Ray client servers.
    """
    address = "localhost:50051"
    server, _ = ray_client_server.init_and_serve(address)
    try:
        first_client = RayAPIStub()
        first_info = first_client.connect(address)
        assert first_info["num_clients"] == 1, first_info

        second_client = RayAPIStub()
        second_info = second_client.connect(address)
        assert second_info["num_clients"] == 2, second_info

        # Disconnect the first two clients.
        first_client.disconnect()
        second_client.disconnect()
        time.sleep(1)

        third_client = RayAPIStub()
        third_info = third_client.connect(address)
        assert third_info["num_clients"] == 1, third_info

        # Check info contains ray and python version.
        for key in ("ray_version", "ray_commit", "python_version"):
            assert isinstance(third_info[key], str), third_info

        third_client.disconnect()
    finally:
        ray_client_server.shutdown_with_server(server)
        time.sleep(2)
+53 -9
View File
@@ -5,6 +5,7 @@ to the server.
import base64
import json
import logging
import time
import uuid
from collections import defaultdict
from typing import Any
@@ -33,6 +34,13 @@ INITIAL_TIMEOUT_SEC = 5
MAX_TIMEOUT_SEC = 30


def backoff(timeout: int) -> int:
    """Return the next, longer timeout to use for a retry.

    Adds 5 to ``timeout`` and caps the result at ``MAX_TIMEOUT_SEC``.
    """
    return min(timeout + 5, MAX_TIMEOUT_SEC)
class Worker:
def __init__(self,
conn_str: str = "",
@@ -59,23 +67,59 @@ class Worker:
else:
self.channel = grpc.insecure_channel(conn_str)
# Retry the connection until the channel responds to something
# looking like a gRPC connection, though it may be a proxy.
conn_attempts = 0
timeout = INITIAL_TIMEOUT_SEC
while conn_attempts < connection_retries + 1:
ray_ready = False
while conn_attempts < max(connection_retries, 1):
conn_attempts += 1
try:
# Let gRPC wait for us to see if the channel becomes ready.
# If it throws, we couldn't connect.
grpc.channel_ready_future(self.channel).result(timeout=timeout)
break
# The HTTP2 channel is ready. Wrap the channel with the
# RayletDriverStub, allowing for unary requests.
self.server = ray_client_pb2_grpc.RayletDriverStub(
self.channel)
# Now the HTTP2 channel is ready, or proxied, but the
# servicer may not be ready. Call is_initialized() and if
# it throws, the servicer is not ready. On success, the
# `ray_ready` result is checked.
ray_ready = self.is_initialized()
if ray_ready:
# Ray is ready! Break out of the retry loop
break
# Ray is not ready yet, wait a timeout
time.sleep(timeout)
except grpc.FutureTimeoutError:
if conn_attempts >= connection_retries:
raise ConnectionError("ray client connection timeout")
logger.info(f"Couldn't connect in {timeout} seconds, retrying")
timeout = timeout + 5
if timeout > MAX_TIMEOUT_SEC:
timeout = MAX_TIMEOUT_SEC
logger.info(
f"Couldn't connect channel in {timeout} seconds, retrying")
# Note that channel_ready_future constitutes its own timeout,
# which is why we do not sleep here.
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.UNAVAILABLE:
# UNAVAILABLE is gRPC's retryable error,
# so we do that here.
logger.info("Ray client server unavailable, "
f"retrying in {timeout}s...")
logger.debug(f"Received when checking init: {e.details()}")
# Ray is not ready yet, wait a timeout
time.sleep(timeout)
else:
# Any other gRPC error gets a reraise
raise e
# Fallthrough, backoff, and retry at the top of the loop
logger.info("Waiting for Ray to become ready on the server, "
f"retry in {timeout}s...")
timeout = backoff(timeout)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
# If we made it through the loop without ray_ready it means we've used
# up our retries and should error back to the user.
if not ray_ready:
raise ConnectionError("ray client connection timeout")
# Initialize the streams to finish protocol negotiation.
self.data_client = DataClient(self.channel, self._client_id,
self.metadata)
self.reference_count: Dict[bytes, int] = defaultdict(int)