[ray_client] Include multiple facets of the Ray API (#12736)

This commit is contained in:
Barak Michener
2020-12-10 19:09:34 -08:00
committed by GitHub
parent 8d1ad25545
commit b7f246c451
11 changed files with 530 additions and 73 deletions
+4 -4
View File
@@ -3,9 +3,8 @@ from traceback import format_exception
import colorama
import ray
import ray.cloudpickle as pickle
from ray.core.generated.common_pb2 import RayException, Language
from ray.core.generated.common_pb2 import RayException, Language, PYTHON
import setproctitle
@@ -17,7 +16,7 @@ class RayError(Exception):
exc_info = (type(self), self, self.__traceback__)
formatted_exception_string = "\n".join(format_exception(*exc_info))
return RayException(
language=ray.Language.PYTHON.value(),
language=PYTHON,
serialized_exception=pickle.dumps(self),
formatted_exception_string=formatted_exception_string
).SerializeToString()
@@ -26,7 +25,7 @@ class RayError(Exception):
def from_bytes(b):
ray_exception = RayException()
ray_exception.ParseFromString(b)
if ray_exception.language == ray.Language.PYTHON.value():
if ray_exception.language == PYTHON:
return pickle.loads(ray_exception.serialized_exception)
else:
return CrossLanguageError(ray_exception)
@@ -81,6 +80,7 @@ class RayTaskError(RayError):
pid=None,
ip=None):
"""Initialize a RayTaskError."""
import ray
if proctitle:
self.proctitle = proctitle
else:
+94 -6
View File
@@ -11,8 +11,10 @@
from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any, Union, Optional
import ray.core.generated.ray_client_pb2 as ray_client_pb2
if TYPE_CHECKING:
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientStub
from ray.experimental.client.common import ClientObjectRef
from ray._raylet import ObjectRef
@@ -29,13 +31,13 @@ class APIImpl(ABC):
"""
@abstractmethod
def get(self, *args, **kwargs) -> Any:
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
"""
get is the hook stub passed on to replace `ray.get`
Args:
args: opaque arguments
kwargs: opaque keyword arguments
vals: [Client]ObjectRef or list of these refs to retrieve.
timeout: Optional timeout in milliseconds
"""
pass
@@ -103,6 +105,39 @@ class APIImpl(ABC):
"""
pass
@abstractmethod
def kill(self, actor, *, no_restart=True):
"""
kill forcibly stops an actor running in the cluster
Args:
no_restart: Whether this actor should be restarted if it's a
restartable actor.
"""
pass
@abstractmethod
def cancel(self, obj, *, force=False, recursive=True):
"""
Cancels a task on the cluster.
If the specified task is pending execution, it will not be executed. If
the task is currently executing, the behavior depends on the ``force``
flag, as per `ray.cancel()`
Only non-actor tasks can be canceled. Canceled tasks will not be
retried (max_retries will not be respected).
Args:
object_ref (ObjectRef): ObjectRef returned by the task
that should be canceled.
force (boolean): Whether to force-kill a running task by killing
the worker that is running the task.
recursive (boolean): Whether to try to cancel tasks submitted by the
task specified.
"""
pass
class ClientAPI(APIImpl):
"""
@@ -113,8 +148,8 @@ class ClientAPI(APIImpl):
def __init__(self, worker):
self.worker = worker
def get(self, *args, **kwargs):
return self.worker.get(*args, **kwargs)
def get(self, vals, *, timeout=None):
return self.worker.get(vals, timeout=timeout)
def put(self, *args, **kwargs):
return self.worker.put(*args, **kwargs)
@@ -131,6 +166,59 @@ class ClientAPI(APIImpl):
def close(self) -> None:
return self.worker.close()
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
return self.worker.terminate_actor(actor, no_restart)
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
return self.worker.terminate_task(obj, force, recursive)
# Various metadata methods for the client that are defined in the protocol.
def is_initialized(self) -> bool:
""" True if our client is connected, and if the server is initialized.
Returns:
A boolean determining if the client is connected and
server initialized.
"""
return self.worker.is_initialized()
def nodes(self):
"""Get a list of the nodes in the cluster (for debugging only).
Returns:
Information about the Ray clients in the cluster.
"""
return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.NODES)
def cluster_resources(self):
"""Get the current total cluster resources.
Note that this information can grow stale as nodes are added to or
removed from the cluster.
Returns:
A dictionary mapping resource name to the total quantity of that
resource in the cluster.
"""
return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES)
def available_resources(self):
"""Get the current available cluster resources.
This is different from `cluster_resources` in that this will return idle
(available) resources rather than total resources.
Note that this information can grow stale as tasks start and finish.
Returns:
A dictionary mapping resource name to the total quantity of that
resource in the cluster.
"""
return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES)
def __getattr__(self, key: str):
if not key.startswith("_"):
raise NotImplementedError(
+29 -7
View File
@@ -4,10 +4,13 @@ from typing import Any
from typing import Dict
from ray import cloudpickle
import base64
class ClientBaseRef:
def __init__(self, id):
def __init__(self, id, handle=None):
self.id = id
self.handle = handle
def __repr__(self):
return "%s(%s)" % (
@@ -21,9 +24,14 @@ class ClientBaseRef:
def binary(self):
return self.id
@classmethod
def from_remote_ref(cls, ref: ray_client_pb2.RemoteRef):
return cls(id=ref.id, handle=ref.handle)
class ClientObjectRef(ClientBaseRef):
pass
def _unpack_ref(self):
return cloudpickle.loads(self.handle)
class ClientActorRef(ClientBaseRef):
@@ -88,7 +96,7 @@ class ClientRemoteFunc(ClientStub):
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.FUNCTION
task.name = self._name
task.payload_id = self._ref.id
task.payload_id = self._ref.handle
return task
@@ -130,7 +138,7 @@ class ClientActorClass(ClientStub):
def remote(self, *args, **kwargs):
# Actually instantiate the actor
ref = ray.call_remote(self, *args, **kwargs)
return ClientActorHandle(ClientActorRef(ref.id), self)
return ClientActorHandle(ClientActorRef(ref.id, ref.handle), self)
def __repr__(self):
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
@@ -146,7 +154,7 @@ class ClientActorClass(ClientStub):
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.ACTOR
task.name = self._name
task.payload_id = self._ref.id
task.payload_id = self._ref.handle
return task
@@ -174,7 +182,7 @@ class ClientActorHandle(ClientStub):
def _get_ray_remote_impl(self):
if self._real_actor_handle is None:
self._real_actor_handle = cloudpickle.loads(self.actor_ref.id)
self._real_actor_handle = cloudpickle.loads(self.actor_ref.handle)
return self._real_actor_handle
def __getstate__(self) -> Dict:
@@ -190,6 +198,10 @@ class ClientActorHandle(ClientStub):
self.actor_class = state["actor_class"]
self._real_actor_handle = state["_real_actor_handle"]
@property
def _actor_id(self):
return self.actor_ref.id
def __getattr__(self, key):
return ClientRemoteMethod(self, key)
@@ -244,7 +256,7 @@ class ClientRemoteMethod(ClientStub):
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.METHOD
task.name = self.method_name
task.payload_id = self.actor_handle.actor_ref.id
task.payload_id = self.actor_handle.actor_ref.handle
return task
@@ -266,3 +278,13 @@ def convert_to_arg(val):
out.local = ray_client_pb2.Arg.Locality.INTERNED
out.data = cloudpickle.dumps(val)
return out
def encode_exception(exception) -> str:
data = cloudpickle.dumps(exception)
return base64.standard_b64encode(data).decode()
def decode_exception(data) -> Exception:
data = base64.standard_b64decode(data)
return cloudpickle.loads(data)
@@ -8,6 +8,7 @@
# making into the core-ray module are contained and well-defined.
from typing import Any
from typing import Optional
from typing import Union
import ray
@@ -24,8 +25,14 @@ class CoreRayAPI(APIImpl):
to core ray when passed client stubs.
"""
def get(self, *args, **kwargs):
return ray.get(*args, **kwargs)
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
if isinstance(vals, list):
if isinstance(vals[0], ClientObjectRef):
return ray.get(
[val._unpack_ref() for val in vals], timeout=timeout)
elif isinstance(vals, ClientObjectRef):
return ray.get(vals._unpack_ref(), timeout=timeout)
return ray.get(vals, timeout=timeout)
def put(self, vals: Any, *args,
**kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]:
@@ -43,6 +50,15 @@ class CoreRayAPI(APIImpl):
def close(self) -> None:
return None
def kill(self, actor, *, no_restart=True):
return ray.kill(actor, no_restart=no_restart)
def cancel(self, obj, *, force=False, recursive=True):
return ray.cancel(obj, force=force, recursive=recursive)
def is_initialized(self) -> bool:
return ray.is_initialized()
# Allow for generic fallback to ray.* in remote methods. This allows calls
# like ray.nodes() to be run in remote functions even though the client
# doesn't currently support them.
+112 -22
View File
@@ -3,12 +3,15 @@ from concurrent import futures
import grpc
from ray import cloudpickle
import ray
import ray.state
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import time
import inspect
import json
from ray.experimental.client import stash_api_for_tests, _set_server_api
from ray.experimental.client.common import convert_from_arg
from ray.experimental.client.common import encode_exception
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.server.core_ray_api import RayServerAPI
@@ -23,19 +26,81 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
self.registered_actor_classes = {}
self._test_mode = test_mode
def ClusterInfo(self, request,
context=None) -> ray_client_pb2.ClusterInfoResponse:
resp = ray_client_pb2.ClusterInfoResponse()
resp.type = request.type
if request.type == ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES:
resources = ray.cluster_resources()
# Normalize resources into floats
# (the function may return values that are ints)
float_resources = {k: float(v) for k, v in resources.items()}
resp.resource_table.CopyFrom(
ray_client_pb2.ClusterInfoResponse.ResourceTable(
table=float_resources))
elif request.type == ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES:
resources = ray.available_resources()
# Normalize resources into floats
# (the function may return values that are ints)
float_resources = {k: float(v) for k, v in resources.items()}
resp.resource_table.CopyFrom(
ray_client_pb2.ClusterInfoResponse.ResourceTable(
table=float_resources))
else:
resp.json = self._return_debug_cluster_info(request, context)
return resp
def _return_debug_cluster_info(self, request, context=None) -> str:
data = None
if request.type == ray_client_pb2.ClusterInfoType.NODES:
data = ray.nodes()
elif request.type == ray_client_pb2.ClusterInfoType.IS_INITIALIZED:
data = ray.is_initialized()
else:
raise TypeError("Unsupported cluster info type")
return json.dumps(data)
def Terminate(self, request, context=None):
if request.WhichOneof("terminate_type") == "task_object":
try:
object_ref = cloudpickle.loads(request.task_object.handle)
ray.cancel(
object_ref,
force=request.task_object.force,
recursive=request.task_object.recursive)
except Exception as e:
return_exception_in_context(e, context)
elif request.WhichOneof("terminate_type") == "actor":
try:
actor_ref = cloudpickle.loads(request.actor.handle)
ray.kill(actor_ref, no_restart=request.actor.no_restart)
except Exception as e:
return_exception_in_context(e, context)
else:
raise RuntimeError(
"Client requested termination without providing a valid terminate_type"
)
return ray_client_pb2.TerminateResponse(ok=True)
def GetObject(self, request, context=None):
if request.id not in self.object_refs:
request_ref = cloudpickle.loads(request.handle)
if request_ref.binary() not in self.object_refs:
return ray_client_pb2.GetResponse(valid=False)
objectref = self.object_refs[request.id]
objectref = self.object_refs[request_ref.binary()]
logger.info("get: %s" % objectref)
item = ray.get(objectref)
try:
item = ray.get(objectref, timeout=request.timeout)
except Exception as e:
return_exception_in_context(e, context)
item_ser = cloudpickle.dumps(item)
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
obj = cloudpickle.loads(request.data)
objectref = self._put_and_retain_obj(obj)
return ray_client_pb2.PutResponse(id=objectref.binary())
pickled_ref = cloudpickle.dumps(objectref)
return ray_client_pb2.PutResponse(
ref=make_remote_ref(objectref.binary(), pickled_ref))
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
objectref = ray.put(obj)
@@ -44,14 +109,14 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
return objectref
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
object_refs = [cloudpickle.loads(o) for o in request.object_refs]
object_refs = [cloudpickle.loads(o) for o in request.object_handles]
num_returns = request.num_returns
timeout = request.timeout
object_refs_ids = []
for object_ref in object_refs:
if object_ref.id not in self.object_refs:
if object_ref.binary() not in self.object_refs:
return ray_client_pb2.WaitResponse(valid=False)
object_refs_ids.append(self.object_refs[object_ref.id])
object_refs_ids.append(self.object_refs[object_ref.binary()])
try:
ready_object_refs, remaining_object_refs = ray.wait(
object_refs_ids,
@@ -63,11 +128,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
logger.info("wait: %s %s" % (str(ready_object_refs),
str(remaining_object_refs)))
ready_object_ids = [
ready_object_ref.binary() for ready_object_ref in ready_object_refs
make_remote_ref(
id=ready_object_ref.binary(),
handle=cloudpickle.dumps(ready_object_ref),
) for ready_object_ref in ready_object_refs
]
remaining_object_ids = [
remaining_object_ref.binary()
for remaining_object_ref in remaining_object_refs
make_remote_ref(
id=remaining_object_ref.binary(),
handle=cloudpickle.dumps(remaining_object_ref),
) for remaining_object_ref in remaining_object_refs
]
return ray_client_pb2.WaitResponse(
valid=True,
@@ -103,41 +173,46 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
with stash_api_for_tests(self._test_mode):
output = getattr(actor_handle, task.name).remote(*arglist)
self.object_refs[output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
pickled_ref = cloudpickle.dumps(output)
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(output.binary(), pickled_ref))
def _schedule_actor(self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
with stash_api_for_tests(self._test_mode):
if task.payload_id not in self.registered_actor_classes:
actor_class_ref = self.object_refs[task.payload_id]
payload_ref = cloudpickle.loads(task.payload_id)
if payload_ref.binary() not in self.registered_actor_classes:
actor_class_ref = self.object_refs[payload_ref.binary()]
actor_class = ray.get(actor_class_ref)
if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that "
"isn't a class.")
reg_class = ray.remote(actor_class)
self.registered_actor_classes[task.payload_id] = reg_class
remote_class = self.registered_actor_classes[task.payload_id]
self.registered_actor_classes[payload_ref.binary()] = reg_class
remote_class = self.registered_actor_classes[payload_ref.binary()]
arglist = _convert_args(task.args, prepared_args)
actor = remote_class.remote(*arglist)
actorhandle = cloudpickle.dumps(actor)
self.actor_refs[actorhandle] = actor
return ray_client_pb2.ClientTaskTicket(return_id=actorhandle)
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(actor._actor_id.binary(), actorhandle))
def _schedule_function(
self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
if task.payload_id not in self.function_refs:
funcref = self.object_refs[task.payload_id]
payload_ref = cloudpickle.loads(task.payload_id)
if payload_ref.binary() not in self.function_refs:
funcref = self.object_refs[payload_ref.binary()]
func = ray.get(funcref)
if not inspect.isfunction(func):
raise Exception("Attempting to schedule function that "
"isn't a function.")
self.function_refs[task.payload_id] = ray.remote(func)
remote_func = self.function_refs[task.payload_id]
self.function_refs[payload_ref.binary()] = ray.remote(func)
remote_func = self.function_refs[payload_ref.binary()]
arglist = _convert_args(task.args, prepared_args)
# Prepare call if we're in a test
with stash_api_for_tests(self._test_mode):
@@ -145,7 +220,9 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
if output.binary() in self.object_refs:
raise Exception("already found it")
self.object_refs[output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
pickled_output = cloudpickle.dumps(output)
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(output.binary(), pickled_output))
def _convert_args(arg_list, prepared_args=None):
@@ -155,12 +232,25 @@ def _convert_args(arg_list, prepared_args=None):
for arg in arg_list:
t = convert_from_arg(arg)
if isinstance(t, ClientObjectRef):
out.append(ray.ObjectRef(t.id))
out.append(t._unpack_ref())
else:
out.append(t)
return out
def make_remote_ref(id: bytes, handle: bytes) -> ray_client_pb2.RemoteRef:
return ray_client_pb2.RemoteRef(
id=id,
handle=handle,
)
def return_exception_in_context(err, context):
if context is not None:
context.set_details(encode_exception(err))
context.set_code(grpc.StatusCode.INTERNAL)
def serve(connection_str, test_mode=False):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer(test_mode=test_mode)
+78 -19
View File
@@ -3,19 +3,25 @@ It implements the Ray API functions that are forwarded through grpc calls
to the server.
"""
import inspect
import json
import logging
from typing import Any
from typing import List
from typing import Tuple
from typing import Optional
import ray.cloudpickle as cloudpickle
from ray.util.inspect import is_cython
import grpc
from ray.exceptions import TaskCancelledError
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.experimental.client.common import convert_to_arg
from ray.experimental.client.common import decode_exception
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientRemoteFunc
logger = logging.getLogger(__name__)
@@ -35,6 +41,7 @@ class Worker:
metadata: additional metadata passed in the grpc request headers.
"""
self.metadata = metadata
self.channel = None
if stub is None:
if secure:
credentials = grpc.ssl_channel_credentials()
@@ -45,28 +52,32 @@ class Worker:
else:
self.server = stub
def get(self, ids):
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
to_get = []
single = False
if isinstance(ids, list):
to_get = [x.id for x in ids]
elif isinstance(ids, ClientObjectRef):
to_get = [ids.id]
if isinstance(vals, list):
to_get = [x.handle for x in vals]
elif isinstance(vals, ClientObjectRef):
to_get = [vals.handle]
single = True
else:
raise Exception("Can't get something that's not a "
"list of IDs or just an ID: %s" % type(ids))
out = [self._get(x) for x in to_get]
"list of IDs or just an ID: %s" % type(vals))
if timeout is None:
timeout = 0
out = [self._get(x, timeout) for x in to_get]
if single:
out = out[0]
return out
def _get(self, id: bytes):
req = ray_client_pb2.GetRequest(id=id)
data = self.server.GetObject(req, metadata=self.metadata)
def _get(self, handle: bytes, timeout: float):
req = ray_client_pb2.GetRequest(handle=handle, timeout=timeout)
try:
data = self.server.GetObject(req, metadata=self.metadata)
except grpc.RpcError as e:
raise decode_exception(e.details())
if not data.valid:
raise Exception(
"Client GetObject returned invalid data: id invalid?")
raise TaskCancelledError(handle)
return cloudpickle.loads(data.data)
def put(self, vals):
@@ -87,7 +98,7 @@ class Worker:
data = cloudpickle.dumps(val)
req = ray_client_pb2.PutRequest(data=data)
resp = self.server.PutObject(req, metadata=self.metadata)
return ClientObjectRef(resp.id)
return ClientObjectRef.from_remote_ref(resp.ref)
def wait(self,
object_refs: List[ClientObjectRef],
@@ -99,8 +110,8 @@ class Worker:
for ref in object_refs:
assert isinstance(ref, ClientObjectRef)
data = {
"object_refs": [
cloudpickle.dumps(object_ref) for object_ref in object_refs
"object_handles": [
object_ref.handle for object_ref in object_refs
],
"num_returns": num_returns,
"timeout": timeout if timeout else -1
@@ -111,10 +122,12 @@ class Worker:
# TODO(ameer): improve error/exceptions messages.
raise Exception("Client Wait request failed. Reference invalid?")
client_ready_object_ids = [
ClientObjectRef(id) for id in resp.ready_object_ids
ClientObjectRef.from_remote_ref(ref)
for ref in resp.ready_object_ids
]
client_remaining_object_ids = [
ClientObjectRef(id) for id in resp.remaining_object_ids
ClientObjectRef.from_remote_ref(ref)
for ref in resp.remaining_object_ids
]
return (client_ready_object_ids, client_remaining_object_ids)
@@ -138,7 +151,53 @@ class Worker:
task.args.append(pb_arg)
logging.debug("Scheduling %s" % task)
ticket = self.server.Schedule(task, metadata=self.metadata)
return ClientObjectRef(ticket.return_id)
return ClientObjectRef.from_remote_ref(ticket.return_ref)
def close(self):
self.channel.close()
self.server = None
if self.channel:
self.channel.close()
def terminate_actor(self, actor: ClientActorHandle,
no_restart: bool) -> None:
if not isinstance(actor, ClientActorHandle):
raise ValueError("ray.kill() only supported for actors. "
"Got: {}.".format(type(actor)))
term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
term_actor.handle = actor.actor_ref.handle
term_actor.no_restart = no_restart
try:
term = ray_client_pb2.TerminateRequest(actor=term_actor)
self.server.Terminate(term)
except grpc.RpcError as e:
raise decode_exception(e.details())
def terminate_task(self, obj: ClientObjectRef, force: bool,
recursive: bool) -> None:
if not isinstance(obj, ClientObjectRef):
raise TypeError(
"ray.cancel() only supported for non-actor object refs. "
f"Got: {type(obj)}.")
term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
term_object.handle = obj.handle
term_object.force = force
term_object.recursive = recursive
try:
term = ray_client_pb2.TerminateRequest(task_object=term_object)
self.server.Terminate(term)
except grpc.RpcError as e:
raise decode_exception(e.details())
def get_cluster_info(self, type: ray_client_pb2.ClusterInfoType.TypeEnum):
req = ray_client_pb2.ClusterInfoRequest()
req.type = type
resp = self.server.ClusterInfo(req)
if resp.WhichOneof("response_type") == "resource_table":
return resp.resource_table.table
return json.loads(resp.json)
def is_initialized(self) -> bool:
if self.server is not None:
return self.get_cluster_info(
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
return False
+2
View File
@@ -95,6 +95,8 @@ py_test_module_list(
"test_dask_callback.py",
"test_debug_tools.py",
"test_experimental_client.py",
"test_experimental_client_metadata.py",
"test_experimental_client_terminate.py",
"test_job.py",
"test_memstat.py",
"test_metrics_agent.py",
+6 -7
View File
@@ -10,10 +10,12 @@ from ray.experimental.client.common import ClientObjectRef
def ray_start_client_server():
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
yield ray
ray.disconnect()
server.stop(0)
reset_api()
try:
yield ray
finally:
ray.disconnect()
server.stop(0)
reset_api()
def test_real_ray_fallback(ray_start_regular_shared):
@@ -35,9 +37,6 @@ def test_real_ray_fallback(ray_start_regular_shared):
nodes = ray.get(get_nodes.remote())
assert len(nodes) == 1, nodes
with pytest.raises(NotImplementedError):
print(ray.nodes())
def test_nested_function(ray_start_regular_shared):
with ray_start_client_server() as ray:
@@ -0,0 +1,25 @@
from ray.tests.test_experimental_client import ray_start_client_server
def test_get_ray_metadata(ray_start_regular_shared):
"""
Test the ClusterInfo client data pathway and API surface
"""
with ray_start_client_server() as ray:
ip_address = ray_start_regular_shared["node_ip_address"]
initialized = ray.is_initialized()
assert initialized
nodes = ray.nodes()
assert len(nodes) == 1, nodes
assert nodes[0]["NodeManagerAddress"] == ip_address
current_node_id = "node:" + ip_address
cluster_resources = ray.cluster_resources()
available_resources = ray.available_resources()
assert cluster_resources["CPU"] == 1.0
assert current_node_id in cluster_resources
assert current_node_id in available_resources
@@ -0,0 +1,99 @@
import pytest
import asyncio
from ray.tests.test_experimental_client import ray_start_client_server
from ray.test_utils import wait_for_condition
from ray.exceptions import TaskCancelledError
from ray.exceptions import RayTaskError
from ray.exceptions import WorkerCrashedError
from ray.exceptions import ObjectLostError
from ray.exceptions import GetTimeoutError
def valid_exceptions(use_force):
if use_force:
return (RayTaskError, TaskCancelledError, WorkerCrashedError,
ObjectLostError)
else:
return (RayTaskError, TaskCancelledError)
def _all_actors_dead(ray):
import ray as real_ray
def _all_actors_dead_internal():
return all(actor["State"] == real_ray.gcs_utils.ActorTableData.DEAD
for actor in list(real_ray.actors().values()))
return _all_actors_dead_internal
def test_kill_actor_immediately_after_creation(ray_start_regular):
with ray_start_client_server() as ray:
@ray.remote
class A:
pass
a = A.remote()
b = A.remote()
ray.kill(a)
ray.kill(b)
wait_for_condition(_all_actors_dead(ray), timeout=10)
@pytest.mark.parametrize("use_force", [True, False])
def test_cancel_chain(ray_start_regular, use_force):
with ray_start_client_server() as ray:
@ray.remote
class SignalActor:
def __init__(self):
self.ready_event = asyncio.Event()
def send(self, clear=False):
self.ready_event.set()
if clear:
self.ready_event.clear()
async def wait(self, should_wait=True):
if should_wait:
await self.ready_event.wait()
signaler = SignalActor.remote()
@ray.remote
def wait_for(t):
return ray.get(t[0])
obj1 = wait_for.remote([signaler.wait.remote()])
obj2 = wait_for.remote([obj1])
obj3 = wait_for.remote([obj2])
obj4 = wait_for.remote([obj3])
assert len(ray.wait([obj1], timeout=.1)[0]) == 0
ray.cancel(obj1, force=use_force)
for ob in [obj1, obj2, obj3, obj4]:
with pytest.raises(valid_exceptions(use_force)):
ray.get(ob)
signaler2 = SignalActor.remote()
obj1 = wait_for.remote([signaler2.wait.remote()])
obj2 = wait_for.remote([obj1])
obj3 = wait_for.remote([obj2])
obj4 = wait_for.remote([obj3])
assert len(ray.wait([obj3], timeout=.1)[0]) == 0
ray.cancel(obj3, force=use_force)
for ob in [obj3, obj4]:
with pytest.raises(valid_exceptions(use_force)):
ray.get(ob)
with pytest.raises(GetTimeoutError):
ray.get(obj1, timeout=.1)
with pytest.raises(GetTimeoutError):
ray.get(obj2, timeout=.1)
signaler2.send.remote()
ray.get(obj1)