mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:53:14 +08:00
[ray_client]: Add python version check and test (and some minor fixes along the way) (#13722)
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
"""Client tests that run their own init (as with init_and_serve) live here"""
|
||||
import pytest
|
||||
|
||||
import time
|
||||
import sys
|
||||
|
||||
import ray.util.client.server.server as ray_client_server
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
|
||||
from ray.util.client import RayAPIStub
|
||||
|
||||
@@ -9,7 +13,8 @@ from ray.util.client import RayAPIStub
|
||||
def test_num_clients():
|
||||
# Tests num clients reporting; useful if you want to build an app that
|
||||
# load balances clients between Ray client servers.
|
||||
server, _ = ray_client_server.init_and_serve("localhost:50051")
|
||||
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
|
||||
server = server_handle.grpc_server
|
||||
try:
|
||||
api1 = RayAPIStub()
|
||||
info1 = api1.connect("localhost:50051")
|
||||
@@ -35,3 +40,39 @@ def test_num_clients():
|
||||
finally:
|
||||
ray_client_server.shutdown_with_server(server)
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def test_python_version():
    """Check the Python version reported to, and enforced by, the client.

    First connects to a freshly started server and asserts that the
    ``python_version`` entry of the connection info equals this
    interpreter's ``major.minor.micro`` string. Then swaps the data
    servicer's connection-response builder for one that reports
    ``2.7.12`` and asserts that a plain ``connect`` raises
    ``RuntimeError`` while ``connect(..., ignore_version=True)`` succeeds.
    """
    address = "localhost:50051"
    server_handle, _ = ray_client_server.init_and_serve(address)
    try:
        ray = RayAPIStub()
        info1 = ray.connect(address)
        expected_version = ".".join(str(part) for part in sys.version_info[:3])
        assert info1["python_version"] == expected_version
        ray.disconnect()
        time.sleep(1)

        def mock_connection_response():
            # Canned reply advertising a Python version that differs from
            # the one running this test.
            return ray_client_pb2.ConnectionInfoResponse(
                num_clients=1,
                python_version="2.7.12",
                ray_version="",
                ray_commit="",
            )

        # Inject the mock in place of the servicer's real response builder.
        data_servicer = server_handle.data_servicer
        data_servicer._build_connection_response = mock_connection_response

        # With the mismatching version reported, a default connect must fail.
        ray = RayAPIStub()
        with pytest.raises(RuntimeError):
            ray.connect(address)

        # The same connect with ignore_version=True must go through; the
        # returned info carries the mock's num_clients value.
        ray = RayAPIStub()
        info3 = ray.connect(address, ignore_version=True)
        assert info3["num_clients"] == 1, info3
        ray.disconnect()
    finally:
        ray_client_server.shutdown_with_server(server_handle.grpc_server)
        time.sleep(2)
|
||||
|
||||
@@ -1,39 +1,38 @@
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server_pair
|
||||
from ray.test_utils import wait_for_condition
|
||||
import ray as real_ray
|
||||
from ray.core.generated.gcs_pb2 import ActorTableData
|
||||
from ray.util.client.server.server import _get_current_servicer
|
||||
|
||||
|
||||
def server_object_ref_count(n):
|
||||
server = _get_current_servicer()
|
||||
def server_object_ref_count(server, n):
|
||||
assert server is not None
|
||||
|
||||
def test_cond():
|
||||
if len(server.object_refs) == 0:
|
||||
if len(server.task_servicer.object_refs) == 0:
|
||||
# No open clients
|
||||
return n == 0
|
||||
client_id = list(server.object_refs.keys())[0]
|
||||
return len(server.object_refs[client_id]) == n
|
||||
client_id = list(server.task_servicer.object_refs.keys())[0]
|
||||
return len(server.task_servicer.object_refs[client_id]) == n
|
||||
|
||||
return test_cond
|
||||
|
||||
|
||||
def server_actor_ref_count(n):
|
||||
server = _get_current_servicer()
|
||||
def server_actor_ref_count(server, n):
|
||||
assert server is not None
|
||||
|
||||
def test_cond():
|
||||
if len(server.actor_refs) == 0:
|
||||
if len(server.task_servicer.actor_refs) == 0:
|
||||
# No running actors
|
||||
return n == 0
|
||||
return len(server.actor_refs) == n
|
||||
return len(server.task_servicer.actor_refs) == n
|
||||
|
||||
return test_cond
|
||||
|
||||
|
||||
def test_delete_refs_on_disconnect(ray_start_regular):
|
||||
with ray_start_client_server() as ray:
|
||||
with ray_start_client_server_pair() as pair:
|
||||
ray, server = pair
|
||||
|
||||
@ray.remote
|
||||
def f(x):
|
||||
@@ -46,14 +45,14 @@ def test_delete_refs_on_disconnect(ray_start_regular):
|
||||
# in a different category, according to the raylet.
|
||||
assert len(real_ray.objects()) == 2
|
||||
# But we're maintaining the reference
|
||||
assert server_object_ref_count(3)()
|
||||
assert server_object_ref_count(server, 3)()
|
||||
# And can get the data
|
||||
assert ray.get(thing1) == 8
|
||||
|
||||
# Close the client
|
||||
ray.close()
|
||||
|
||||
wait_for_condition(server_object_ref_count(0), timeout=5)
|
||||
wait_for_condition(server_object_ref_count(server, 0), timeout=5)
|
||||
|
||||
def test_cond():
|
||||
return len(real_ray.objects()) == 0
|
||||
@@ -62,7 +61,8 @@ def test_delete_refs_on_disconnect(ray_start_regular):
|
||||
|
||||
|
||||
def test_delete_ref_on_object_deletion(ray_start_regular):
|
||||
with ray_start_client_server() as ray:
|
||||
with ray_start_client_server_pair() as pair:
|
||||
ray, server = pair
|
||||
vals = {
|
||||
"ref": ray.put("Hello World"),
|
||||
"ref2": ray.put("This value stays"),
|
||||
@@ -70,11 +70,12 @@ def test_delete_ref_on_object_deletion(ray_start_regular):
|
||||
|
||||
del vals["ref"]
|
||||
|
||||
wait_for_condition(server_object_ref_count(1), timeout=5)
|
||||
wait_for_condition(server_object_ref_count(server, 1), timeout=5)
|
||||
|
||||
|
||||
def test_delete_actor_on_disconnect(ray_start_regular):
|
||||
with ray_start_client_server() as ray:
|
||||
with ray_start_client_server_pair() as pair:
|
||||
ray, server = pair
|
||||
|
||||
@ray.remote
|
||||
class Accumulator:
|
||||
@@ -90,13 +91,13 @@ def test_delete_actor_on_disconnect(ray_start_regular):
|
||||
actor = Accumulator.remote()
|
||||
actor.inc.remote()
|
||||
|
||||
assert server_actor_ref_count(1)()
|
||||
assert server_actor_ref_count(server, 1)()
|
||||
|
||||
assert ray.get(actor.get.remote()) == 1
|
||||
|
||||
ray.close()
|
||||
|
||||
wait_for_condition(server_actor_ref_count(0), timeout=5)
|
||||
wait_for_condition(server_actor_ref_count(server, 0), timeout=5)
|
||||
|
||||
def test_cond():
|
||||
alive_actors = [
|
||||
@@ -109,7 +110,8 @@ def test_delete_actor_on_disconnect(ray_start_regular):
|
||||
|
||||
|
||||
def test_delete_actor(ray_start_regular):
|
||||
with ray_start_client_server() as ray:
|
||||
with ray_start_client_server_pair() as pair:
|
||||
ray, server = pair
|
||||
|
||||
@ray.remote
|
||||
class Accumulator:
|
||||
@@ -124,11 +126,11 @@ def test_delete_actor(ray_start_regular):
|
||||
actor2 = Accumulator.remote()
|
||||
actor2.inc.remote()
|
||||
|
||||
assert server_actor_ref_count(2)()
|
||||
assert server_actor_ref_count(server, 2)()
|
||||
|
||||
del actor
|
||||
|
||||
wait_for_condition(server_actor_ref_count(1), timeout=5)
|
||||
wait_for_condition(server_actor_ref_count(server, 1), timeout=5)
|
||||
|
||||
|
||||
def test_simple_multiple_references(ray_start_regular):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import List, Tuple, Dict, Any
|
||||
|
||||
import sys
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,7 +26,9 @@ class RayAPIStub:
|
||||
conn_str: str,
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
connection_retries: int = 3) -> Dict[str, Any]:
|
||||
connection_retries: int = 3,
|
||||
*,
|
||||
ignore_version: bool = False) -> Dict[str, Any]:
|
||||
"""Connect the Ray Client to a server.
|
||||
|
||||
Args:
|
||||
@@ -56,11 +59,25 @@ class RayAPIStub:
|
||||
metadata=metadata,
|
||||
connection_retries=connection_retries)
|
||||
self.api.worker = self.client_worker
|
||||
return self.client_worker.connection_info()
|
||||
conn_info = self.client_worker.connection_info()
|
||||
self._check_versions(conn_info, ignore_version)
|
||||
return conn_info
|
||||
except Exception:
|
||||
self.disconnect()
|
||||
raise
|
||||
|
||||
def _check_versions(self, conn_info, ignore_version: bool) -> None:
    """Compare the server's Python major.minor version with the local one.

    Args:
        conn_info: Mapping holding the server's connection info; only the
            ``"python_version"`` entry (a dotted version string) is read.
        ignore_version: When True, a mismatch is logged as a warning
            instead of raising.

    Raises:
        RuntimeError: If the major.minor versions differ and
            ``ignore_version`` is False.
    """
    local_major_minor = f"{sys.version_info[0]}.{sys.version_info[1]}"
    server_version = conn_info["python_version"]
    # Compare whole dot-separated components. A plain prefix test would
    # accept e.g. a "3.1" client against a "3.10.x" server.
    server_major_minor = ".".join(server_version.split(".")[:2])
    if server_major_minor == local_major_minor:
        return

    version_str = f"{local_major_minor}.{sys.version_info[2]}"
    msg = ("Python minor versions differ between client and server:"
           f" client is {version_str},"
           f" server is {server_version}")
    if ignore_version:
        logger.warning(msg)
    else:
        raise RuntimeError(msg)
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect the Ray Client.
|
||||
"""
|
||||
@@ -97,8 +114,9 @@ class RayAPIStub:
|
||||
if self._server is not None:
|
||||
raise Exception("Trying to start two instances of ray via client")
|
||||
import ray.util.client.server.server as ray_client_server
|
||||
self._server, address_info = ray_client_server.init_and_serve(
|
||||
server_handle, address_info = ray_client_server.init_and_serve(
|
||||
"localhost:50051", *args, **kwargs)
|
||||
self._server = server_handle.grpc_server
|
||||
self.connect("localhost:50051")
|
||||
self._connected_with_init = True
|
||||
return address_info
|
||||
|
||||
@@ -6,11 +6,18 @@ from ray.util.client import ray
|
||||
|
||||
@contextmanager
|
||||
def ray_start_client_server():
|
||||
with ray_start_client_server_pair() as pair:
|
||||
client, server = pair
|
||||
yield client
|
||||
|
||||
|
||||
@contextmanager
|
||||
def ray_start_client_server_pair():
|
||||
ray._inside_client_test = True
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
ray.connect("localhost:50051")
|
||||
try:
|
||||
yield ray
|
||||
yield ray, server
|
||||
finally:
|
||||
ray._inside_client_test = False
|
||||
ray.disconnect()
|
||||
|
||||
@@ -50,16 +50,8 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
resp = ray_client_pb2.DataResponse(
|
||||
release=ray_client_pb2.ReleaseResponse(ok=released))
|
||||
elif req_type == "connection_info":
|
||||
with self._clients_lock:
|
||||
cur_num_clients = self._num_clients
|
||||
info = ray_client_pb2.ConnectionInfoResponse(
|
||||
num_clients=cur_num_clients,
|
||||
python_version="{}.{}.{}".format(
|
||||
sys.version_info[0], sys.version_info[1],
|
||||
sys.version_info[2]),
|
||||
ray_version=ray.__version__,
|
||||
ray_commit=ray.__commit__)
|
||||
resp = ray_client_pb2.DataResponse(connection_info=info)
|
||||
resp = ray_client_pb2.DataResponse(
|
||||
connection_info=self._build_connection_response())
|
||||
else:
|
||||
raise Exception(f"Unreachable code: Request type "
|
||||
f"{req_type} not handled in Datapath")
|
||||
@@ -72,3 +64,13 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
self.basic_service.release_all(client_id)
|
||||
with self._clients_lock:
|
||||
self._num_clients -= 1
|
||||
|
||||
def _build_connection_response(self):
    """Build the ``ConnectionInfoResponse`` for a connection_info request.

    The client count is read while holding ``self._clients_lock``; the
    response also carries this interpreter's ``major.minor.micro`` version
    string and the ``ray.__version__`` / ``ray.__commit__`` values.
    """
    with self._clients_lock:
        client_count = self._num_clients
    interpreter_version = ".".join(str(part) for part in sys.version_info[:3])
    return ray_client_pb2.ConnectionInfoResponse(
        num_clients=client_count,
        python_version=interpreter_version,
        ray_version=ray.__version__,
        ray_commit=ray.__commit__,
    )
|
||||
|
||||
@@ -3,6 +3,7 @@ from concurrent import futures
|
||||
import grpc
|
||||
import base64
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
@@ -407,13 +408,18 @@ def decode_options(
|
||||
return opts
|
||||
|
||||
|
||||
_current_servicer: Optional[RayletServicer] = None
|
||||
@dataclass
|
||||
class ClientServerHandle:
|
||||
"""Holds the handles to the registered gRPC servicers and their server."""
|
||||
task_servicer: RayletServicer
|
||||
data_servicer: DataServicer
|
||||
logs_servicer: LogstreamServicer
|
||||
grpc_server: grpc.Server
|
||||
|
||||
|
||||
# Used by tests to peek inside the servicer
|
||||
def _get_current_servicer():
|
||||
global _current_servicer
|
||||
return _current_servicer
|
||||
# Add a hook for all the cases that previously
|
||||
# expected simply a gRPC server
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.grpc_server, attr)
|
||||
|
||||
|
||||
def serve(connection_str):
|
||||
@@ -421,8 +427,6 @@ def serve(connection_str):
|
||||
task_servicer = RayletServicer()
|
||||
data_servicer = DataServicer(task_servicer)
|
||||
logs_servicer = LogstreamServicer()
|
||||
global _current_servicer
|
||||
_current_servicer = task_servicer
|
||||
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
||||
task_servicer, server)
|
||||
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
|
||||
@@ -430,16 +434,22 @@ def serve(connection_str):
|
||||
ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
|
||||
logs_servicer, server)
|
||||
server.add_insecure_port(connection_str)
|
||||
current_handle = ClientServerHandle(
|
||||
task_servicer=task_servicer,
|
||||
data_servicer=data_servicer,
|
||||
logs_servicer=logs_servicer,
|
||||
grpc_server=server,
|
||||
)
|
||||
server.start()
|
||||
return server
|
||||
return current_handle
|
||||
|
||||
|
||||
def init_and_serve(connection_str, *args, **kwargs):
|
||||
with disable_client_hook():
|
||||
# Disable client mode inside the worker's environment
|
||||
info = ray.init(*args, **kwargs)
|
||||
server = serve(connection_str)
|
||||
return (server, info)
|
||||
server_handle = serve(connection_str)
|
||||
return (server_handle, info)
|
||||
|
||||
|
||||
def shutdown_with_server(server, _exiting_interpreter=False):
|
||||
|
||||
Reference in New Issue
Block a user