mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:00:36 +08:00
[ray_client] close ray connection upon client deactivation (#13919)
This commit is contained in:
@@ -26,6 +26,8 @@ py_test_module_list(
|
||||
"test_basic_3.py",
|
||||
"test_cancel.py",
|
||||
"test_cli.py",
|
||||
"test_client.py",
|
||||
"test_client_init.py",
|
||||
"test_component_failures_2.py",
|
||||
"test_component_failures_3.py",
|
||||
"test_error_ray_not_initialized.py",
|
||||
@@ -80,9 +82,7 @@ py_test_module_list(
|
||||
"test_asyncio.py",
|
||||
"test_autoscaler.py",
|
||||
"test_autoscaler_yaml.py",
|
||||
"test_client_init.py",
|
||||
"test_client_metadata.py",
|
||||
"test_client.py",
|
||||
"test_client_references.py",
|
||||
"test_client_terminate.py",
|
||||
"test_command_runner.py",
|
||||
|
||||
@@ -38,130 +38,146 @@ class C:
|
||||
return self.val
|
||||
|
||||
|
||||
def test_basic_preregister():
|
||||
@pytest.fixture
|
||||
def init_and_serve():
|
||||
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
|
||||
yield server_handle
|
||||
ray_client_server.shutdown_with_server(server_handle.grpc_server)
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def init_and_serve_lazy():
|
||||
cluster = ray.cluster_utils.Cluster()
|
||||
cluster.add_node(num_cpus=1, num_gpus=0)
|
||||
address = cluster.address
|
||||
|
||||
def connect():
|
||||
ray.init(address=address)
|
||||
|
||||
server_handle = ray_client_server.serve("localhost:50051", connect)
|
||||
yield server_handle
|
||||
ray_client_server.shutdown_with_server(server_handle.grpc_server)
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def test_basic_preregister(init_and_serve):
|
||||
from ray.util.client import ray
|
||||
server, _ = ray_client_server.init_and_serve("localhost:50051")
|
||||
try:
|
||||
ray.connect("localhost:50051")
|
||||
val = ray.get(hello_world.remote())
|
||||
print(val)
|
||||
assert val >= 20
|
||||
assert val <= 200
|
||||
c = C.remote(3)
|
||||
x = c.double.remote()
|
||||
y = c.double.remote()
|
||||
ray.wait([x, y])
|
||||
val = ray.get(c.get.remote())
|
||||
assert val == 12
|
||||
finally:
|
||||
ray.disconnect()
|
||||
ray_client_server.shutdown_with_server(server)
|
||||
time.sleep(2)
|
||||
ray.connect("localhost:50051")
|
||||
val = ray.get(hello_world.remote())
|
||||
print(val)
|
||||
assert val >= 20
|
||||
assert val <= 200
|
||||
c = C.remote(3)
|
||||
x = c.double.remote()
|
||||
y = c.double.remote()
|
||||
ray.wait([x, y])
|
||||
val = ray.get(c.get.remote())
|
||||
assert val == 12
|
||||
ray.disconnect()
|
||||
|
||||
|
||||
def test_num_clients():
|
||||
def test_num_clients(init_and_serve_lazy):
|
||||
# Tests num clients reporting; useful if you want to build an app that
|
||||
# load balances clients between Ray client servers.
|
||||
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
|
||||
server = server_handle.grpc_server
|
||||
try:
|
||||
api1 = RayAPIStub()
|
||||
info1 = api1.connect("localhost:50051")
|
||||
assert info1["num_clients"] == 1, info1
|
||||
api2 = RayAPIStub()
|
||||
info2 = api2.connect("localhost:50051")
|
||||
assert info2["num_clients"] == 2, info2
|
||||
|
||||
# Disconnect the first two clients.
|
||||
api1.disconnect()
|
||||
api2.disconnect()
|
||||
time.sleep(1)
|
||||
def get_job_id(api):
|
||||
return api.get_runtime_context().worker.current_job_id
|
||||
|
||||
api3 = RayAPIStub()
|
||||
info3 = api3.connect("localhost:50051")
|
||||
assert info3["num_clients"] == 1, info3
|
||||
api1 = RayAPIStub()
|
||||
info1 = api1.connect("localhost:50051")
|
||||
job_id_1 = get_job_id(api1)
|
||||
assert info1["num_clients"] == 1, info1
|
||||
api2 = RayAPIStub()
|
||||
info2 = api2.connect("localhost:50051")
|
||||
job_id_2 = get_job_id(api2)
|
||||
assert info2["num_clients"] == 2, info2
|
||||
|
||||
# Check info contains ray and python version.
|
||||
assert isinstance(info3["ray_version"], str), info3
|
||||
assert isinstance(info3["ray_commit"], str), info3
|
||||
assert isinstance(info3["python_version"], str), info3
|
||||
assert isinstance(info3["protocol_version"], str), info3
|
||||
api3.disconnect()
|
||||
finally:
|
||||
ray_client_server.shutdown_with_server(server)
|
||||
time.sleep(2)
|
||||
assert job_id_1 == job_id_2
|
||||
|
||||
# Disconnect the first two clients.
|
||||
api1.disconnect()
|
||||
api2.disconnect()
|
||||
time.sleep(1)
|
||||
|
||||
api3 = RayAPIStub()
|
||||
info3 = api3.connect("localhost:50051")
|
||||
job_id_3 = get_job_id(api3)
|
||||
assert info3["num_clients"] == 1, info3
|
||||
assert job_id_1 != job_id_3
|
||||
|
||||
# Check info contains ray and python version.
|
||||
assert isinstance(info3["ray_version"], str), info3
|
||||
assert isinstance(info3["ray_commit"], str), info3
|
||||
assert isinstance(info3["python_version"], str), info3
|
||||
assert isinstance(info3["protocol_version"], str), info3
|
||||
api3.disconnect()
|
||||
|
||||
|
||||
def test_python_version():
|
||||
def test_python_version(init_and_serve):
|
||||
server_handle = init_and_serve
|
||||
ray = RayAPIStub()
|
||||
info1 = ray.connect("localhost:50051")
|
||||
assert info1["python_version"] == ".".join(
|
||||
[str(x) for x in list(sys.version_info)[:3]])
|
||||
ray.disconnect()
|
||||
time.sleep(1)
|
||||
|
||||
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
|
||||
try:
|
||||
ray = RayAPIStub()
|
||||
info1 = ray.connect("localhost:50051")
|
||||
assert info1["python_version"] == ".".join(
|
||||
[str(x) for x in list(sys.version_info)[:3]])
|
||||
ray.disconnect()
|
||||
time.sleep(1)
|
||||
def mock_connection_response():
|
||||
return ray_client_pb2.ConnectionInfoResponse(
|
||||
num_clients=1,
|
||||
python_version="2.7.12",
|
||||
ray_version="",
|
||||
ray_commit="",
|
||||
protocol_version=CURRENT_PROTOCOL_VERSION,
|
||||
)
|
||||
|
||||
def mock_connection_response():
|
||||
return ray_client_pb2.ConnectionInfoResponse(
|
||||
num_clients=1,
|
||||
python_version="2.7.12",
|
||||
ray_version="",
|
||||
ray_commit="",
|
||||
protocol_version=CURRENT_PROTOCOL_VERSION,
|
||||
)
|
||||
# inject mock connection function
|
||||
server_handle.data_servicer._build_connection_response = \
|
||||
mock_connection_response
|
||||
|
||||
# inject mock connection function
|
||||
server_handle.data_servicer._build_connection_response = \
|
||||
mock_connection_response
|
||||
ray = RayAPIStub()
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = ray.connect("localhost:50051")
|
||||
|
||||
ray = RayAPIStub()
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = ray.connect("localhost:50051")
|
||||
|
||||
ray = RayAPIStub()
|
||||
info3 = ray.connect("localhost:50051", ignore_version=True)
|
||||
assert info3["num_clients"] == 1, info3
|
||||
ray.disconnect()
|
||||
finally:
|
||||
ray_client_server.shutdown_with_server(server_handle.grpc_server)
|
||||
time.sleep(2)
|
||||
ray = RayAPIStub()
|
||||
info3 = ray.connect("localhost:50051", ignore_version=True)
|
||||
assert info3["num_clients"] == 1, info3
|
||||
ray.disconnect()
|
||||
|
||||
|
||||
def test_protocol_version():
|
||||
def test_protocol_version(init_and_serve):
|
||||
server_handle = init_and_serve
|
||||
ray = RayAPIStub()
|
||||
info1 = ray.connect("localhost:50051")
|
||||
local_py_version = ".".join([str(x) for x in list(sys.version_info)[:3]])
|
||||
assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1
|
||||
ray.disconnect()
|
||||
time.sleep(1)
|
||||
|
||||
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
|
||||
try:
|
||||
ray = RayAPIStub()
|
||||
info1 = ray.connect("localhost:50051")
|
||||
local_py_version = ".".join(
|
||||
[str(x) for x in list(sys.version_info)[:3]])
|
||||
assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1
|
||||
ray.disconnect()
|
||||
time.sleep(1)
|
||||
def mock_connection_response():
|
||||
return ray_client_pb2.ConnectionInfoResponse(
|
||||
num_clients=1,
|
||||
python_version=local_py_version,
|
||||
ray_version="",
|
||||
ray_commit="",
|
||||
protocol_version="2050-01-01", # from the future
|
||||
)
|
||||
|
||||
def mock_connection_response():
|
||||
return ray_client_pb2.ConnectionInfoResponse(
|
||||
num_clients=1,
|
||||
python_version=local_py_version,
|
||||
ray_version="",
|
||||
ray_commit="",
|
||||
protocol_version="2050-01-01", # from the future
|
||||
)
|
||||
# inject mock connection function
|
||||
server_handle.data_servicer._build_connection_response = \
|
||||
mock_connection_response
|
||||
|
||||
# inject mock connection function
|
||||
server_handle.data_servicer._build_connection_response = \
|
||||
mock_connection_response
|
||||
ray = RayAPIStub()
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = ray.connect("localhost:50051")
|
||||
|
||||
ray = RayAPIStub()
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = ray.connect("localhost:50051")
|
||||
ray = RayAPIStub()
|
||||
info3 = ray.connect("localhost:50051", ignore_version=True)
|
||||
assert info3["num_clients"] == 1, info3
|
||||
ray.disconnect()
|
||||
|
||||
ray = RayAPIStub()
|
||||
info3 = ray.connect("localhost:50051", ignore_version=True)
|
||||
assert info3["num_clients"] == 1, info3
|
||||
ray.disconnect()
|
||||
finally:
|
||||
ray_client_server.shutdown_with_server(server_handle.grpc_server)
|
||||
time.sleep(2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
sys.exit(pytest.main(["-v", __file__] + sys.argv[1:]))
|
||||
|
||||
@@ -33,7 +33,7 @@ _ = Actor.remote()
|
||||
assert len(actor_table) == 1
|
||||
|
||||
job_table = ray.jobs()
|
||||
assert len(job_table) == 3 # dash, ray client server
|
||||
assert len(job_table) == 2 # dash
|
||||
|
||||
# Kill the driver process.
|
||||
p.kill()
|
||||
@@ -79,7 +79,7 @@ ray.get(_.value.remote())
|
||||
assert len(actor_table) == 1
|
||||
|
||||
job_table = ray.jobs()
|
||||
assert len(job_table) == 3 # dash, ray client server
|
||||
assert len(job_table) == 2 # dash
|
||||
|
||||
# Kill the driver process.
|
||||
p.kill()
|
||||
|
||||
@@ -3,12 +3,13 @@ import logging
|
||||
import grpc
|
||||
import sys
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from threading import Lock
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.util.client import CURRENT_PROTOCOL_VERSION
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.client.server.server import RayletServicer
|
||||
@@ -17,10 +18,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
def __init__(self, basic_service: "RayletServicer"):
|
||||
def __init__(self, basic_service: "RayletServicer",
|
||||
ray_connect_handler: Callable):
|
||||
self.basic_service = basic_service
|
||||
self._clients_lock = Lock()
|
||||
self._num_clients = 0 # guarded by self._clients_lock
|
||||
self.ray_connect_handler = ray_connect_handler
|
||||
|
||||
def Datapath(self, request_iterator, context):
|
||||
metadata = {k: v for k, v in context.invocation_metadata()}
|
||||
@@ -31,6 +34,9 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
logger.info(f"New data connection from client {client_id}")
|
||||
try:
|
||||
with self._clients_lock:
|
||||
with disable_client_hook():
|
||||
if self._num_clients == 0 and not ray.is_initialized():
|
||||
self.ray_connect_handler()
|
||||
self._num_clients += 1
|
||||
for req in request_iterator:
|
||||
resp = None
|
||||
@@ -63,9 +69,14 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
finally:
|
||||
logger.info(f"Lost data connection from client {client_id}")
|
||||
self.basic_service.release_all(client_id)
|
||||
|
||||
with self._clients_lock:
|
||||
self._num_clients -= 1
|
||||
|
||||
with disable_client_hook():
|
||||
if self._num_clients == 0:
|
||||
ray.shutdown()
|
||||
|
||||
def _build_connection_response(self):
|
||||
with self._clients_lock:
|
||||
cur_num_clients = self._num_clients
|
||||
|
||||
@@ -422,10 +422,17 @@ class ClientServerHandle:
|
||||
return getattr(self.grpc_server, attr)
|
||||
|
||||
|
||||
def serve(connection_str):
|
||||
def serve(connection_str, ray_connect_handler=None):
|
||||
def default_connect_handler():
|
||||
with disable_client_hook():
|
||||
if not ray.is_initialized():
|
||||
return ray.init()
|
||||
|
||||
ray_connect_handler = ray_connect_handler or default_connect_handler
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
task_servicer = RayletServicer()
|
||||
data_servicer = DataServicer(task_servicer)
|
||||
data_servicer = DataServicer(
|
||||
task_servicer, ray_connect_handler=ray_connect_handler)
|
||||
logs_servicer = LogstreamServicer()
|
||||
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
||||
task_servicer, server)
|
||||
@@ -448,7 +455,17 @@ def init_and_serve(connection_str, *args, **kwargs):
|
||||
with disable_client_hook():
|
||||
# Disable client mode inside the worker's environment
|
||||
info = ray.init(*args, **kwargs)
|
||||
server_handle = serve(connection_str)
|
||||
|
||||
def ray_connect_handler():
|
||||
# Ray client will disconnect from ray when
|
||||
# num_clients == 0.
|
||||
if ray.is_initialized():
|
||||
return info
|
||||
else:
|
||||
return ray.init(*args, **kwargs)
|
||||
|
||||
server_handle = serve(
|
||||
connection_str, ray_connect_handler=ray_connect_handler)
|
||||
return (server_handle, info)
|
||||
|
||||
|
||||
@@ -458,6 +475,19 @@ def shutdown_with_server(server, _exiting_interpreter=False):
|
||||
ray.shutdown(_exiting_interpreter)
|
||||
|
||||
|
||||
def create_ray_handler(redis_address, redis_password):
|
||||
def ray_connect_handler():
|
||||
if redis_address:
|
||||
if redis_password:
|
||||
ray.init(address=redis_address, _redis_password=redis_password)
|
||||
else:
|
||||
ray.init(address=redis_address)
|
||||
else:
|
||||
ray.init()
|
||||
|
||||
return ray_connect_handler
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -477,18 +507,13 @@ def main():
|
||||
help="Password for connecting to Redis")
|
||||
args = parser.parse_args()
|
||||
logging.basicConfig(level="INFO")
|
||||
if args.redis_address:
|
||||
if args.redis_password:
|
||||
ray.init(
|
||||
address=args.redis_address,
|
||||
_redis_password=args.redis_password)
|
||||
else:
|
||||
ray.init(address=args.redis_address)
|
||||
else:
|
||||
ray.init()
|
||||
|
||||
ray_connect_handler = create_ray_handler(args.redis_address,
|
||||
args.redis_password)
|
||||
|
||||
hostport = "%s:%d" % (args.host, args.port)
|
||||
logger.info(f"Starting Ray Client server on {hostport}")
|
||||
server = serve(hostport)
|
||||
server = serve(hostport, ray_connect_handler)
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1000)
|
||||
|
||||
@@ -68,6 +68,7 @@ class Worker:
|
||||
"""
|
||||
self.metadata = metadata if metadata else []
|
||||
self.channel = None
|
||||
self.server = None
|
||||
self._conn_state = grpc.ChannelConnectivity.IDLE
|
||||
self._client_id = make_client_id()
|
||||
self._converted: Dict[str, ClientStub] = {}
|
||||
@@ -83,7 +84,7 @@ class Worker:
|
||||
# looking like a gRPC connection, though it may be a proxy.
|
||||
conn_attempts = 0
|
||||
timeout = INITIAL_TIMEOUT_SEC
|
||||
ray_ready = False
|
||||
service_ready = False
|
||||
while conn_attempts < max(connection_retries, 1):
|
||||
conn_attempts += 1
|
||||
try:
|
||||
@@ -94,13 +95,8 @@ class Worker:
|
||||
# RayletDriverStub, allowing for unary requests.
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(
|
||||
self.channel)
|
||||
# Now the HTTP2 channel is ready, or proxied, but the
|
||||
# servicer may not be ready. Call is_initialized() and if
|
||||
# it throws, the servicer is not ready. On success, the
|
||||
# `ray_ready` result is checked.
|
||||
ray_ready = self.is_initialized()
|
||||
if ray_ready:
|
||||
# Ray is ready! Break out of the retry loop
|
||||
service_ready = bool(self.ping_server())
|
||||
if service_ready:
|
||||
break
|
||||
# Ray is not ready yet, wait a timeout
|
||||
time.sleep(timeout)
|
||||
@@ -120,9 +116,10 @@ class Worker:
|
||||
f"retry in {timeout}s...")
|
||||
timeout = backoff(timeout)
|
||||
|
||||
# If we made it through the loop without ray_ready it means we've used
|
||||
# up our retries and should error back to the user.
|
||||
if not ray_ready:
|
||||
# If we made it through the loop without service_ready
|
||||
# it means we've used up our retries and
|
||||
# should error back to the user.
|
||||
if not service_ready:
|
||||
raise ConnectionError("ray client connection timeout")
|
||||
|
||||
# Initialize the streams to finish protocol negotiation.
|
||||
@@ -377,6 +374,18 @@ class Worker:
|
||||
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
|
||||
return False
|
||||
|
||||
def ping_server(self) -> bool:
|
||||
"""Simple health check.
|
||||
|
||||
Piggybacks the IS_INITIALIZED call to check if the server provides
|
||||
an actual response.
|
||||
"""
|
||||
if self.server is not None:
|
||||
result = self.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
|
||||
return result is not None
|
||||
return False
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return self._conn_state == grpc.ChannelConnectivity.READY
|
||||
|
||||
|
||||
Reference in New Issue
Block a user