[ray_client] close ray connection upon client deactivation (#13919)

This commit is contained in:
Richard Liaw
2021-02-07 13:11:38 -08:00
committed by GitHub
parent 4b4941435d
commit 3a230fa1a4
7 changed files with 198 additions and 136 deletions
+1
View File
@@ -152,6 +152,7 @@ test_python() {
-python/ray/tests:test_basic_3 # timeout
-python/ray/tests:test_basic_3_client_mode
-python/ray/tests:test_cli
-python/ray/tests:test_client_init # timeout
-python/ray/tests:test_failure
-python/ray/tests:test_global_gc
-python/ray/tests:test_job
+2 -2
View File
@@ -26,6 +26,8 @@ py_test_module_list(
"test_basic_3.py",
"test_cancel.py",
"test_cli.py",
"test_client.py",
"test_client_init.py",
"test_component_failures_2.py",
"test_component_failures_3.py",
"test_error_ray_not_initialized.py",
@@ -80,9 +82,7 @@ py_test_module_list(
"test_asyncio.py",
"test_autoscaler.py",
"test_autoscaler_yaml.py",
"test_client_init.py",
"test_client_metadata.py",
"test_client.py",
"test_client_references.py",
"test_client_terminate.py",
"test_command_runner.py",
+122 -106
View File
@@ -38,130 +38,146 @@ class C:
return self.val
def test_basic_preregister():
@pytest.fixture
def init_and_serve():
    """Start Ray plus a Ray client server on localhost:50051 for one test.

    Yields:
        The server handle returned by ``ray_client_server.init_and_serve``.
        Tests use it to reach the servicers (e.g. ``data_servicer``).

    On teardown the gRPC server is passed to ``shutdown_with_server``.
    """
    server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
    yield server_handle
    ray_client_server.shutdown_with_server(server_handle.grpc_server)
    # NOTE(review): presumably gives the server time to release the port
    # before the next test binds it again -- confirm.
    time.sleep(2)
@pytest.fixture
def init_and_serve_lazy():
    """Serve a Ray client server that connects to Ray only on demand.

    A one-node cluster is started up front, but the server process itself
    does not call ``ray.init`` until the ``connect`` handler passed to
    ``serve`` is invoked.

    Yields:
        The server handle returned by ``ray_client_server.serve``.
    """
    cluster = ray.cluster_utils.Cluster()
    try:
        cluster.add_node(num_cpus=1, num_gpus=0)
        address = cluster.address

        def connect():
            ray.init(address=address)

        server_handle = ray_client_server.serve("localhost:50051", connect)
        yield server_handle
        ray_client_server.shutdown_with_server(server_handle.grpc_server)
        # NOTE(review): presumably gives the server time to release the
        # port before the next test binds it again -- confirm.
        time.sleep(2)
    finally:
        # The cluster nodes were started by this fixture, so stop them here;
        # otherwise they outlive the test.
        cluster.shutdown()
def test_basic_preregister(init_and_serve):
from ray.util.client import ray
server, _ = ray_client_server.init_and_serve("localhost:50051")
try:
ray.connect("localhost:50051")
val = ray.get(hello_world.remote())
print(val)
assert val >= 20
assert val <= 200
c = C.remote(3)
x = c.double.remote()
y = c.double.remote()
ray.wait([x, y])
val = ray.get(c.get.remote())
assert val == 12
finally:
ray.disconnect()
ray_client_server.shutdown_with_server(server)
time.sleep(2)
ray.connect("localhost:50051")
val = ray.get(hello_world.remote())
print(val)
assert val >= 20
assert val <= 200
c = C.remote(3)
x = c.double.remote()
y = c.double.remote()
ray.wait([x, y])
val = ray.get(c.get.remote())
assert val == 12
ray.disconnect()
def test_num_clients():
def test_num_clients(init_and_serve_lazy):
# Tests num clients reporting; useful if you want to build an app that
# load balances clients between Ray client servers.
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
server = server_handle.grpc_server
try:
api1 = RayAPIStub()
info1 = api1.connect("localhost:50051")
assert info1["num_clients"] == 1, info1
api2 = RayAPIStub()
info2 = api2.connect("localhost:50051")
assert info2["num_clients"] == 2, info2
# Disconnect the first two clients.
api1.disconnect()
api2.disconnect()
time.sleep(1)
def get_job_id(api):
return api.get_runtime_context().worker.current_job_id
api3 = RayAPIStub()
info3 = api3.connect("localhost:50051")
assert info3["num_clients"] == 1, info3
api1 = RayAPIStub()
info1 = api1.connect("localhost:50051")
job_id_1 = get_job_id(api1)
assert info1["num_clients"] == 1, info1
api2 = RayAPIStub()
info2 = api2.connect("localhost:50051")
job_id_2 = get_job_id(api2)
assert info2["num_clients"] == 2, info2
# Check info contains ray and python version.
assert isinstance(info3["ray_version"], str), info3
assert isinstance(info3["ray_commit"], str), info3
assert isinstance(info3["python_version"], str), info3
assert isinstance(info3["protocol_version"], str), info3
api3.disconnect()
finally:
ray_client_server.shutdown_with_server(server)
time.sleep(2)
assert job_id_1 == job_id_2
# Disconnect the first two clients.
api1.disconnect()
api2.disconnect()
time.sleep(1)
api3 = RayAPIStub()
info3 = api3.connect("localhost:50051")
job_id_3 = get_job_id(api3)
assert info3["num_clients"] == 1, info3
assert job_id_1 != job_id_3
# Check info contains ray and python version.
assert isinstance(info3["ray_version"], str), info3
assert isinstance(info3["ray_commit"], str), info3
assert isinstance(info3["python_version"], str), info3
assert isinstance(info3["protocol_version"], str), info3
api3.disconnect()
def test_python_version():
def test_python_version(init_and_serve):
server_handle = init_and_serve
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
assert info1["python_version"] == ".".join(
[str(x) for x in list(sys.version_info)[:3]])
ray.disconnect()
time.sleep(1)
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
try:
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
assert info1["python_version"] == ".".join(
[str(x) for x in list(sys.version_info)[:3]])
ray.disconnect()
time.sleep(1)
def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version="2.7.12",
ray_version="",
ray_commit="",
protocol_version=CURRENT_PROTOCOL_VERSION,
)
def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version="2.7.12",
ray_version="",
ray_commit="",
protocol_version=CURRENT_PROTOCOL_VERSION,
)
# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response
# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response
ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")
ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")
ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()
finally:
ray_client_server.shutdown_with_server(server_handle.grpc_server)
time.sleep(2)
ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()
def test_protocol_version():
def test_protocol_version(init_and_serve):
server_handle = init_and_serve
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
local_py_version = ".".join([str(x) for x in list(sys.version_info)[:3]])
assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1
ray.disconnect()
time.sleep(1)
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
try:
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
local_py_version = ".".join(
[str(x) for x in list(sys.version_info)[:3]])
assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1
ray.disconnect()
time.sleep(1)
def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version=local_py_version,
ray_version="",
ray_commit="",
protocol_version="2050-01-01", # from the future
)
def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version=local_py_version,
ray_version="",
ray_commit="",
protocol_version="2050-01-01", # from the future
)
# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response
# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response
ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")
ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")
ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()
ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()
finally:
ray_client_server.shutdown_with_server(server_handle.grpc_server)
time.sleep(2)
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__] + sys.argv[1:]))
+2 -2
View File
@@ -33,7 +33,7 @@ _ = Actor.remote()
assert len(actor_table) == 1
job_table = ray.jobs()
assert len(job_table) == 3 # dash, ray client server
assert len(job_table) == 2 # dash
# Kill the driver process.
p.kill()
@@ -79,7 +79,7 @@ ray.get(_.value.remote())
assert len(actor_table) == 1
job_table = ray.jobs()
assert len(job_table) == 3 # dash, ray client server
assert len(job_table) == 2 # dash
# Kill the driver process.
p.kill()
+13 -2
View File
@@ -3,12 +3,13 @@ import logging
import grpc
import sys
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from threading import Lock
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client import CURRENT_PROTOCOL_VERSION
from ray._private.client_mode_hook import disable_client_hook
if TYPE_CHECKING:
from ray.util.client.server.server import RayletServicer
@@ -17,10 +18,12 @@ logger = logging.getLogger(__name__)
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, basic_service: "RayletServicer"):
def __init__(self, basic_service: "RayletServicer",
ray_connect_handler: Callable):
self.basic_service = basic_service
self._clients_lock = Lock()
self._num_clients = 0 # guarded by self._clients_lock
self.ray_connect_handler = ray_connect_handler
def Datapath(self, request_iterator, context):
metadata = {k: v for k, v in context.invocation_metadata()}
@@ -31,6 +34,9 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
logger.info(f"New data connection from client {client_id}")
try:
with self._clients_lock:
with disable_client_hook():
if self._num_clients == 0 and not ray.is_initialized():
self.ray_connect_handler()
self._num_clients += 1
for req in request_iterator:
resp = None
@@ -63,9 +69,14 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
finally:
logger.info(f"Lost data connection from client {client_id}")
self.basic_service.release_all(client_id)
with self._clients_lock:
self._num_clients -= 1
with disable_client_hook():
if self._num_clients == 0:
ray.shutdown()
def _build_connection_response(self):
with self._clients_lock:
cur_num_clients = self._num_clients
+38 -13
View File
@@ -422,10 +422,17 @@ class ClientServerHandle:
return getattr(self.grpc_server, attr)
def serve(connection_str):
def serve(connection_str, ray_connect_handler=None):
def default_connect_handler():
with disable_client_hook():
if not ray.is_initialized():
return ray.init()
ray_connect_handler = ray_connect_handler or default_connect_handler
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer()
data_servicer = DataServicer(task_servicer)
data_servicer = DataServicer(
task_servicer, ray_connect_handler=ray_connect_handler)
logs_servicer = LogstreamServicer()
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
@@ -448,7 +455,17 @@ def init_and_serve(connection_str, *args, **kwargs):
with disable_client_hook():
# Disable client mode inside the worker's environment
info = ray.init(*args, **kwargs)
server_handle = serve(connection_str)
def ray_connect_handler():
# Ray client will disconnect from ray when
# num_clients == 0.
if ray.is_initialized():
return info
else:
return ray.init(*args, **kwargs)
server_handle = serve(
connection_str, ray_connect_handler=ray_connect_handler)
return (server_handle, info)
@@ -458,6 +475,19 @@ def shutdown_with_server(server, _exiting_interpreter=False):
ray.shutdown(_exiting_interpreter)
def create_ray_handler(redis_address, redis_password):
    """Build a zero-argument callable that connects this process to Ray.

    Args:
        redis_address: Address of an existing cluster. When falsy, the
            returned handler calls ``ray.init()`` with no arguments.
        redis_password: Password forwarded as ``_redis_password``. It is
            only used when ``redis_address`` is also given.

    Returns:
        A function taking no arguments that calls ``ray.init`` with the
        settings captured above. Nothing is initialized until it is called.
    """

    def ray_connect_handler():
        init_kwargs = {}
        if redis_address:
            init_kwargs["address"] = redis_address
            # The password is meaningless without an address, so it is only
            # forwarded together with one.
            if redis_password:
                init_kwargs["_redis_password"] = redis_password
        ray.init(**init_kwargs)

    return ray_connect_handler
def main():
import argparse
parser = argparse.ArgumentParser()
@@ -477,18 +507,13 @@ def main():
help="Password for connecting to Redis")
args = parser.parse_args()
logging.basicConfig(level="INFO")
if args.redis_address:
if args.redis_password:
ray.init(
address=args.redis_address,
_redis_password=args.redis_password)
else:
ray.init(address=args.redis_address)
else:
ray.init()
ray_connect_handler = create_ray_handler(args.redis_address,
args.redis_password)
hostport = "%s:%d" % (args.host, args.port)
logger.info(f"Starting Ray Client server on {hostport}")
server = serve(hostport)
server = serve(hostport, ray_connect_handler)
try:
while True:
time.sleep(1000)
+20 -11
View File
@@ -68,6 +68,7 @@ class Worker:
"""
self.metadata = metadata if metadata else []
self.channel = None
self.server = None
self._conn_state = grpc.ChannelConnectivity.IDLE
self._client_id = make_client_id()
self._converted: Dict[str, ClientStub] = {}
@@ -83,7 +84,7 @@ class Worker:
# looking like a gRPC connection, though it may be a proxy.
conn_attempts = 0
timeout = INITIAL_TIMEOUT_SEC
ray_ready = False
service_ready = False
while conn_attempts < max(connection_retries, 1):
conn_attempts += 1
try:
@@ -94,13 +95,8 @@ class Worker:
# RayletDriverStub, allowing for unary requests.
self.server = ray_client_pb2_grpc.RayletDriverStub(
self.channel)
# Now the HTTP2 channel is ready, or proxied, but the
# servicer may not be ready. Call is_initialized() and if
# it throws, the servicer is not ready. On success, the
# `ray_ready` result is checked.
ray_ready = self.is_initialized()
if ray_ready:
# Ray is ready! Break out of the retry loop
service_ready = bool(self.ping_server())
if service_ready:
break
# The server did not answer the ping yet; wait before retrying
time.sleep(timeout)
@@ -120,9 +116,10 @@ class Worker:
f"retry in {timeout}s...")
timeout = backoff(timeout)
# If we made it through the loop without ray_ready it means we've used
# up our retries and should error back to the user.
if not ray_ready:
# If we made it through the loop without service_ready
# it means we've used up our retries and
# should error back to the user.
if not service_ready:
raise ConnectionError("ray client connection timeout")
# Initialize the streams to finish protocol negotiation.
@@ -377,6 +374,18 @@ class Worker:
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
return False
def ping_server(self) -> bool:
    """Simple health check.

    Piggybacks the IS_INITIALIZED call to check if the server provides
    an actual response.

    Returns:
        False when no server stub has been created yet. Otherwise True if
        the cluster-info call returned anything other than ``None``; the
        value of the response itself is not inspected.
    """
    if self.server is None:
        return False
    result = self.get_cluster_info(
        ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
    return result is not None
def is_connected(self) -> bool:
return self._conn_state == grpc.ChannelConnectivity.READY