[ray_client]: Add python version check and test (and some minor fixes along the way) (#13722)

This commit is contained in:
Barak Michener
2021-02-01 13:04:38 -08:00
committed by GitHub
parent 754bee9282
commit 55566bc797
6 changed files with 127 additions and 47 deletions
+42 -1
View File
@@ -1,7 +1,11 @@
"""Client tests that run their own init (as with init_and_serve) live here"""
import pytest
import time
import sys
import ray.util.client.server.server as ray_client_server
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.util.client import RayAPIStub
@@ -9,7 +13,8 @@ from ray.util.client import RayAPIStub
def test_num_clients():
# Tests num clients reporting; useful if you want to build an app that
# load balances clients between Ray client servers.
server, _ = ray_client_server.init_and_serve("localhost:50051")
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
server = server_handle.grpc_server
try:
api1 = RayAPIStub()
info1 = api1.connect("localhost:50051")
@@ -35,3 +40,39 @@ def test_num_clients():
finally:
ray_client_server.shutdown_with_server(server)
time.sleep(2)
def test_python_version():
    """Check that connecting enforces the server's reported Python version.

    A stock server must report this interpreter's ``major.minor.micro``
    version. After the servicer's connection-info builder is swapped for one
    that reports "2.7.12", a plain connect must raise ``RuntimeError``, while
    a connect with ``ignore_version=True`` must succeed.
    """
    address = "localhost:50051"
    server_handle, _ = ray_client_server.init_and_serve(address)
    try:
        # Unmodified server: the reported version is our own interpreter's.
        client = RayAPIStub()
        reported = client.connect(address)
        expected = ".".join(str(part) for part in sys.version_info[:3])
        assert reported["python_version"] == expected
        client.disconnect()
        # NOTE(review): presumably lets the server register the disconnect
        # before the next client connects -- confirm.
        time.sleep(1)

        def mock_connection_response():
            return ray_client_pb2.ConnectionInfoResponse(
                num_clients=1,
                python_version="2.7.12",
                ray_version="",
                ray_commit="",
            )

        # Make the data servicer report the fake version from now on.
        server_handle.data_servicer._build_connection_response = \
            mock_connection_response

        # A mismatching version is rejected by default...
        client = RayAPIStub()
        with pytest.raises(RuntimeError):
            client.connect(address)

        # ...but tolerated when the caller opts out of the check.
        client = RayAPIStub()
        overridden = client.connect(address, ignore_version=True)
        assert overridden["num_clients"] == 1, overridden
        client.disconnect()
    finally:
        ray_client_server.shutdown_with_server(server_handle.grpc_server)
        time.sleep(2)
+23 -21
View File
@@ -1,39 +1,38 @@
from ray.util.client.ray_client_helpers import ray_start_client_server
from ray.util.client.ray_client_helpers import ray_start_client_server_pair
from ray.test_utils import wait_for_condition
import ray as real_ray
from ray.core.generated.gcs_pb2 import ActorTableData
from ray.util.client.server.server import _get_current_servicer
def server_object_ref_count(n):
server = _get_current_servicer()
def server_object_ref_count(server, n):
assert server is not None
def test_cond():
if len(server.object_refs) == 0:
if len(server.task_servicer.object_refs) == 0:
# No open clients
return n == 0
client_id = list(server.object_refs.keys())[0]
return len(server.object_refs[client_id]) == n
client_id = list(server.task_servicer.object_refs.keys())[0]
return len(server.task_servicer.object_refs[client_id]) == n
return test_cond
def server_actor_ref_count(n):
server = _get_current_servicer()
def server_actor_ref_count(server, n):
assert server is not None
def test_cond():
if len(server.actor_refs) == 0:
if len(server.task_servicer.actor_refs) == 0:
# No running actors
return n == 0
return len(server.actor_refs) == n
return len(server.task_servicer.actor_refs) == n
return test_cond
def test_delete_refs_on_disconnect(ray_start_regular):
with ray_start_client_server() as ray:
with ray_start_client_server_pair() as pair:
ray, server = pair
@ray.remote
def f(x):
@@ -46,14 +45,14 @@ def test_delete_refs_on_disconnect(ray_start_regular):
# in a different category, according to the raylet.
assert len(real_ray.objects()) == 2
# But we're maintaining the reference
assert server_object_ref_count(3)()
assert server_object_ref_count(server, 3)()
# And can get the data
assert ray.get(thing1) == 8
# Close the client
ray.close()
wait_for_condition(server_object_ref_count(0), timeout=5)
wait_for_condition(server_object_ref_count(server, 0), timeout=5)
def test_cond():
return len(real_ray.objects()) == 0
@@ -62,7 +61,8 @@ def test_delete_refs_on_disconnect(ray_start_regular):
def test_delete_ref_on_object_deletion(ray_start_regular):
with ray_start_client_server() as ray:
with ray_start_client_server_pair() as pair:
ray, server = pair
vals = {
"ref": ray.put("Hello World"),
"ref2": ray.put("This value stays"),
@@ -70,11 +70,12 @@ def test_delete_ref_on_object_deletion(ray_start_regular):
del vals["ref"]
wait_for_condition(server_object_ref_count(1), timeout=5)
wait_for_condition(server_object_ref_count(server, 1), timeout=5)
def test_delete_actor_on_disconnect(ray_start_regular):
with ray_start_client_server() as ray:
with ray_start_client_server_pair() as pair:
ray, server = pair
@ray.remote
class Accumulator:
@@ -90,13 +91,13 @@ def test_delete_actor_on_disconnect(ray_start_regular):
actor = Accumulator.remote()
actor.inc.remote()
assert server_actor_ref_count(1)()
assert server_actor_ref_count(server, 1)()
assert ray.get(actor.get.remote()) == 1
ray.close()
wait_for_condition(server_actor_ref_count(0), timeout=5)
wait_for_condition(server_actor_ref_count(server, 0), timeout=5)
def test_cond():
alive_actors = [
@@ -109,7 +110,8 @@ def test_delete_actor_on_disconnect(ray_start_regular):
def test_delete_actor(ray_start_regular):
with ray_start_client_server() as ray:
with ray_start_client_server_pair() as pair:
ray, server = pair
@ray.remote
class Accumulator:
@@ -124,11 +126,11 @@ def test_delete_actor(ray_start_regular):
actor2 = Accumulator.remote()
actor2.inc.remote()
assert server_actor_ref_count(2)()
assert server_actor_ref_count(server, 2)()
del actor
wait_for_condition(server_actor_ref_count(1), timeout=5)
wait_for_condition(server_actor_ref_count(server, 1), timeout=5)
def test_simple_multiple_references(ray_start_regular):
+21 -3
View File
@@ -1,5 +1,6 @@
from typing import List, Tuple, Dict, Any
import sys
import logging
logger = logging.getLogger(__name__)
@@ -25,7 +26,9 @@ class RayAPIStub:
conn_str: str,
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
connection_retries: int = 3) -> Dict[str, Any]:
connection_retries: int = 3,
*,
ignore_version: bool = False) -> Dict[str, Any]:
"""Connect the Ray Client to a server.
Args:
@@ -56,11 +59,25 @@ class RayAPIStub:
metadata=metadata,
connection_retries=connection_retries)
self.api.worker = self.client_worker
return self.client_worker.connection_info()
conn_info = self.client_worker.connection_info()
self._check_versions(conn_info, ignore_version)
return conn_info
except Exception:
self.disconnect()
raise
def _check_versions(self, conn_info, ignore_version: bool) -> None:
    """Verify that the server runs the same Python major.minor as the client.

    Args:
        conn_info: Connection info returned by the server; its
            "python_version" entry is a dotted version string such as
            "3.8.5".
        ignore_version: If True, a mismatch is only logged as a warning
            instead of raising.

    Raises:
        RuntimeError: If the major/minor versions differ and
            ``ignore_version`` is False.
    """
    local_major_minor = f"{sys.version_info[0]}.{sys.version_info[1]}"
    server_version = conn_info["python_version"]
    # Compare the version components exactly. A plain prefix test would
    # treat e.g. "3.10.2" as matching a local "3.1".
    server_major_minor = ".".join(server_version.split(".")[:2])
    if server_major_minor == local_major_minor:
        return
    version_str = f"{local_major_minor}.{sys.version_info[2]}"
    msg = "Python minor versions differ between client and server:" + \
        f" client is {version_str}," + \
        f" server is {server_version}"
    if ignore_version:
        logger.warning(msg)
    else:
        raise RuntimeError(msg)
def disconnect(self):
"""Disconnect the Ray Client.
"""
@@ -97,8 +114,9 @@ class RayAPIStub:
if self._server is not None:
raise Exception("Trying to start two instances of ray via client")
import ray.util.client.server.server as ray_client_server
self._server, address_info = ray_client_server.init_and_serve(
server_handle, address_info = ray_client_server.init_and_serve(
"localhost:50051", *args, **kwargs)
self._server = server_handle.grpc_server
self.connect("localhost:50051")
self._connected_with_init = True
return address_info
+8 -1
View File
@@ -6,11 +6,18 @@ from ray.util.client import ray
@contextmanager
def ray_start_client_server():
    """Context manager that yields only the client from the client/server pair.

    Wraps ``ray_start_client_server_pair`` for callers that do not need the
    server object.
    """
    with ray_start_client_server_pair() as (client, _server):
        yield client
@contextmanager
def ray_start_client_server_pair():
ray._inside_client_test = True
server = ray_client_server.serve("localhost:50051")
ray.connect("localhost:50051")
try:
yield ray
yield ray, server
finally:
ray._inside_client_test = False
ray.disconnect()
+12 -10
View File
@@ -50,16 +50,8 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
resp = ray_client_pb2.DataResponse(
release=ray_client_pb2.ReleaseResponse(ok=released))
elif req_type == "connection_info":
with self._clients_lock:
cur_num_clients = self._num_clients
info = ray_client_pb2.ConnectionInfoResponse(
num_clients=cur_num_clients,
python_version="{}.{}.{}".format(
sys.version_info[0], sys.version_info[1],
sys.version_info[2]),
ray_version=ray.__version__,
ray_commit=ray.__commit__)
resp = ray_client_pb2.DataResponse(connection_info=info)
resp = ray_client_pb2.DataResponse(
connection_info=self._build_connection_response())
else:
raise Exception(f"Unreachable code: Request type "
f"{req_type} not handled in Datapath")
@@ -72,3 +64,13 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
self.basic_service.release_all(client_id)
with self._clients_lock:
self._num_clients -= 1
def _build_connection_response(self):
    """Build the ConnectionInfoResponse describing this server.

    The response carries the current client count, this interpreter's
    ``major.minor.micro`` version, and the Ray version and commit.
    """
    # Read the counter under the lock; the message is built after release.
    with self._clients_lock:
        num_clients = self._num_clients
    major, minor, micro = sys.version_info[:3]
    return ray_client_pb2.ConnectionInfoResponse(
        num_clients=num_clients,
        python_version=f"{major}.{minor}.{micro}",
        ray_version=ray.__version__,
        ray_commit=ray.__commit__)
+21 -11
View File
@@ -3,6 +3,7 @@ from concurrent import futures
import grpc
import base64
from collections import defaultdict
from dataclasses import dataclass
from typing import Any
from typing import Dict
@@ -407,13 +408,18 @@ def decode_options(
return opts
_current_servicer: Optional[RayletServicer] = None
@dataclass
class ClientServerHandle:
"""Holds the handles to the registered gRPC servicers and their server."""
task_servicer: RayletServicer
data_servicer: DataServicer
logs_servicer: LogstreamServicer
grpc_server: grpc.Server
# Used by tests to peek inside the servicer
def _get_current_servicer():
global _current_servicer
return _current_servicer
# Add a hook for all the cases that previously
# expected simply a gRPC server
def __getattr__(self, attr):
return getattr(self.grpc_server, attr)
def serve(connection_str):
@@ -421,8 +427,6 @@ def serve(connection_str):
task_servicer = RayletServicer()
data_servicer = DataServicer(task_servicer)
logs_servicer = LogstreamServicer()
global _current_servicer
_current_servicer = task_servicer
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
@@ -430,16 +434,22 @@ def serve(connection_str):
ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
logs_servicer, server)
server.add_insecure_port(connection_str)
current_handle = ClientServerHandle(
task_servicer=task_servicer,
data_servicer=data_servicer,
logs_servicer=logs_servicer,
grpc_server=server,
)
server.start()
return server
return current_handle
def init_and_serve(connection_str, *args, **kwargs):
with disable_client_hook():
# Disable client mode inside the worker's environment
info = ray.init(*args, **kwargs)
server = serve(connection_str)
return (server, info)
server_handle = serve(connection_str)
return (server_handle, info)
def shutdown_with_server(server, _exiting_interpreter=False):