From b7f246c4516fa726c6f7ae8b5d1dac206cdb77d5 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Thu, 10 Dec 2020 19:09:34 -0800 Subject: [PATCH] [ray_client] Include multiple facets of the Ray API (#12736) --- python/ray/exceptions.py | 8 +- python/ray/experimental/client/api.py | 100 ++++++++++++- python/ray/experimental/client/common.py | 36 ++++- .../client/server/core_ray_api.py | 20 ++- .../ray/experimental/client/server/server.py | 134 +++++++++++++++--- python/ray/experimental/client/worker.py | 97 ++++++++++--- python/ray/tests/BUILD | 2 + python/ray/tests/test_experimental_client.py | 13 +- .../test_experimental_client_metadata.py | 25 ++++ .../test_experimental_client_terminate.py | 99 +++++++++++++ src/ray/protobuf/ray_client.proto | 69 ++++++++- 11 files changed, 530 insertions(+), 73 deletions(-) create mode 100644 python/ray/tests/test_experimental_client_metadata.py create mode 100644 python/ray/tests/test_experimental_client_terminate.py diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 0456a3a6d..b5a0b477c 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -3,9 +3,8 @@ from traceback import format_exception import colorama -import ray import ray.cloudpickle as pickle -from ray.core.generated.common_pb2 import RayException, Language +from ray.core.generated.common_pb2 import RayException, Language, PYTHON import setproctitle @@ -17,7 +16,7 @@ class RayError(Exception): exc_info = (type(self), self, self.__traceback__) formatted_exception_string = "\n".join(format_exception(*exc_info)) return RayException( - language=ray.Language.PYTHON.value(), + language=PYTHON, serialized_exception=pickle.dumps(self), formatted_exception_string=formatted_exception_string ).SerializeToString() @@ -26,7 +25,7 @@ class RayError(Exception): def from_bytes(b): ray_exception = RayException() ray_exception.ParseFromString(b) - if ray_exception.language == ray.Language.PYTHON.value(): + if ray_exception.language == PYTHON: 
return pickle.loads(ray_exception.serialized_exception) else: return CrossLanguageError(ray_exception) @@ -81,6 +80,7 @@ class RayTaskError(RayError): pid=None, ip=None): """Initialize a RayTaskError.""" + import ray if proctitle: self.proctitle = proctitle else: diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py index 66ec61c17..13443149c 100644 --- a/python/ray/experimental/client/api.py +++ b/python/ray/experimental/client/api.py @@ -11,8 +11,10 @@ from abc import ABC from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Union, Optional +import ray.core.generated.ray_client_pb2 as ray_client_pb2 if TYPE_CHECKING: + from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientStub from ray.experimental.client.common import ClientObjectRef from ray._raylet import ObjectRef @@ -29,13 +31,13 @@ class APIImpl(ABC): """ @abstractmethod - def get(self, *args, **kwargs) -> Any: + def get(self, vals, *, timeout: Optional[float] = None) -> Any: """ get is the hook stub passed on to replace `ray.get` Args: - args: opaque arguments - kwargs: opaque keyword arguments + vals: [Client]ObjectRef or list of these refs to retrieve. + timeout: Optional timeout in milliseconds """ pass @@ -103,6 +105,39 @@ class APIImpl(ABC): """ pass + @abstractmethod + def kill(self, actor, *, no_restart=True): + """ + kill forcibly stops an actor running in the cluster + + Args: + no_restart: Whether this actor should be restarted if it's a + restartable actor. + """ + pass + + @abstractmethod + def cancel(self, obj, *, force=False, recursive=True): + """ + Cancels a task on the cluster. + + If the specified task is pending execution, it will not be executed. If + the task is currently executing, the behavior depends on the ``force`` + flag, as per `ray.cancel()` + + Only non-actor tasks can be canceled. 
Canceled tasks will not be + retried (max_retries will not be respected). + + Args: + obj (ObjectRef): ObjectRef returned by the task + that should be canceled. + force (boolean): Whether to force-kill a running task by killing + the worker that is running the task. + recursive (boolean): Whether to try to cancel tasks submitted by the + task specified. + """ + pass + class ClientAPI(APIImpl): """ @@ -113,8 +148,8 @@ class ClientAPI(APIImpl): def __init__(self, worker): self.worker = worker - def get(self, *args, **kwargs): - return self.worker.get(*args, **kwargs) + def get(self, vals, *, timeout=None): + return self.worker.get(vals, timeout=timeout) def put(self, *args, **kwargs): return self.worker.put(*args, **kwargs) @@ -131,6 +166,59 @@ class ClientAPI(APIImpl): def close(self) -> None: return self.worker.close() + def kill(self, actor: "ClientActorHandle", *, no_restart=True): + return self.worker.terminate_actor(actor, no_restart) + + def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True): + return self.worker.terminate_task(obj, force, recursive) + + # Various metadata methods for the client that are defined in the protocol. + def is_initialized(self) -> bool: + """ True if our client is connected, and if the server is initialized. + + Returns: + A boolean determining if the client is connected and + server initialized. + """ + return self.worker.is_initialized() + + def nodes(self): + """Get a list of the nodes in the cluster (for debugging only). + + Returns: + Information about the Ray nodes in the cluster. + """ + return self.worker.get_cluster_info( + ray_client_pb2.ClusterInfoType.NODES) + + def cluster_resources(self): + """Get the current total cluster resources. + + Note that this information can grow stale as nodes are added to or + removed from the cluster. + + Returns: + A dictionary mapping resource name to the total quantity of that + resource in the cluster. 
+ """ + return self.worker.get_cluster_info( + ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES) + + def available_resources(self): + """Get the current available cluster resources. + + This is different from `cluster_resources` in that this will return idle + (available) resources rather than total resources. + + Note that this information can grow stale as tasks start and finish. + + Returns: + A dictionary mapping resource name to the available amount of that + resource in the cluster. + """ + return self.worker.get_cluster_info( + ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES) + def __getattr__(self, key: str): if not key.startswith("_"): raise NotImplementedError( diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index cea5825e3..24b012790 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -4,10 +4,13 @@ from typing import Any from typing import Dict from ray import cloudpickle +import base64 + class ClientBaseRef: - def __init__(self, id): + def __init__(self, id, handle=None): self.id = id + self.handle = handle def __repr__(self): return "%s(%s)" % ( @@ -21,9 +24,14 @@ class ClientBaseRef: def binary(self): return self.id + @classmethod + def from_remote_ref(cls, ref: ray_client_pb2.RemoteRef): + return cls(id=ref.id, handle=ref.handle) + class ClientObjectRef(ClientBaseRef): - pass + def _unpack_ref(self): + return cloudpickle.loads(self.handle) class ClientActorRef(ClientBaseRef): @@ -88,7 +96,7 @@ class ClientRemoteFunc(ClientStub): task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.FUNCTION task.name = self._name - task.payload_id = self._ref.id + task.payload_id = self._ref.handle return task @@ -130,7 +138,7 @@ class ClientActorClass(ClientStub): def remote(self, *args, **kwargs): # Actually instantiate the actor ref = ray.call_remote(self, *args, **kwargs) - return 
ClientActorHandle(ClientActorRef(ref.id, ref.handle), self) def __repr__(self): return "ClientRemoteActor(%s, %s)" % (self._name, self._ref) @@ -146,7 +154,7 @@ class ClientActorClass(ClientStub): task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.ACTOR task.name = self._name - task.payload_id = self._ref.id + task.payload_id = self._ref.handle return task @@ -174,7 +182,7 @@ class ClientActorHandle(ClientStub): def _get_ray_remote_impl(self): if self._real_actor_handle is None: - self._real_actor_handle = cloudpickle.loads(self.actor_ref.id) + self._real_actor_handle = cloudpickle.loads(self.actor_ref.handle) return self._real_actor_handle def __getstate__(self) -> Dict: @@ -190,6 +198,10 @@ class ClientActorHandle(ClientStub): self.actor_class = state["actor_class"] self._real_actor_handle = state["_real_actor_handle"] + @property + def _actor_id(self): + return self.actor_ref.id + def __getattr__(self, key): return ClientRemoteMethod(self, key) @@ -244,7 +256,7 @@ class ClientRemoteMethod(ClientStub): task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.METHOD task.name = self.method_name - task.payload_id = self.actor_handle.actor_ref.id + task.payload_id = self.actor_handle.actor_ref.handle return task @@ -266,3 +278,13 @@ def convert_to_arg(val): out.local = ray_client_pb2.Arg.Locality.INTERNED out.data = cloudpickle.dumps(val) return out + + +def encode_exception(exception) -> str: + data = cloudpickle.dumps(exception) + return base64.standard_b64encode(data).decode() + + +def decode_exception(data) -> Exception: + data = base64.standard_b64decode(data) + return cloudpickle.loads(data) diff --git a/python/ray/experimental/client/server/core_ray_api.py b/python/ray/experimental/client/server/core_ray_api.py index 83cbc36c0..6513021a8 100644 --- a/python/ray/experimental/client/server/core_ray_api.py +++ b/python/ray/experimental/client/server/core_ray_api.py @@ -8,6 +8,7 @@ # making into the core-ray module are 
contained and well-defined. from typing import Any +from typing import Optional from typing import Union import ray @@ -24,8 +25,14 @@ class CoreRayAPI(APIImpl): to core ray when passed client stubs. """ - def get(self, *args, **kwargs): - return ray.get(*args, **kwargs) + def get(self, vals, *, timeout: Optional[float] = None) -> Any: + if isinstance(vals, list): + if isinstance(vals[0], ClientObjectRef): + return ray.get( + [val._unpack_ref() for val in vals], timeout=timeout) + elif isinstance(vals, ClientObjectRef): + return ray.get(vals._unpack_ref(), timeout=timeout) + return ray.get(vals, timeout=timeout) def put(self, vals: Any, *args, **kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]: @@ -43,6 +50,15 @@ class CoreRayAPI(APIImpl): def close(self) -> None: return None + def kill(self, actor, *, no_restart=True): + return ray.kill(actor, no_restart=no_restart) + + def cancel(self, obj, *, force=False, recursive=True): + return ray.cancel(obj, force=force, recursive=recursive) + + def is_initialized(self) -> bool: + return ray.is_initialized() + # Allow for generic fallback to ray.* in remote methods. This allows calls # like ray.nodes() to be run in remote functions even though the client # doesn't currently support them. 
diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index a2958f6d1..4db4cd22a 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -3,12 +3,15 @@ from concurrent import futures import grpc from ray import cloudpickle import ray +import ray.state import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import time import inspect +import json from ray.experimental.client import stash_api_for_tests, _set_server_api from ray.experimental.client.common import convert_from_arg +from ray.experimental.client.common import encode_exception from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.server.core_ray_api import RayServerAPI @@ -23,19 +26,81 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): self.registered_actor_classes = {} self._test_mode = test_mode + def ClusterInfo(self, request, + context=None) -> ray_client_pb2.ClusterInfoResponse: + resp = ray_client_pb2.ClusterInfoResponse() + resp.type = request.type + if request.type == ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES: + resources = ray.cluster_resources() + # Normalize resources into floats + # (the function may return values that are ints) + float_resources = {k: float(v) for k, v in resources.items()} + resp.resource_table.CopyFrom( + ray_client_pb2.ClusterInfoResponse.ResourceTable( + table=float_resources)) + elif request.type == ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES: + resources = ray.available_resources() + # Normalize resources into floats + # (the function may return values that are ints) + float_resources = {k: float(v) for k, v in resources.items()} + resp.resource_table.CopyFrom( + ray_client_pb2.ClusterInfoResponse.ResourceTable( + table=float_resources)) + else: + resp.json = self._return_debug_cluster_info(request, context) + return 
resp + + def _return_debug_cluster_info(self, request, context=None) -> str: + data = None + if request.type == ray_client_pb2.ClusterInfoType.NODES: + data = ray.nodes() + elif request.type == ray_client_pb2.ClusterInfoType.IS_INITIALIZED: + data = ray.is_initialized() + else: + raise TypeError("Unsupported cluster info type") + return json.dumps(data) + + def Terminate(self, request, context=None): + if request.WhichOneof("terminate_type") == "task_object": + try: + object_ref = cloudpickle.loads(request.task_object.handle) + ray.cancel( + object_ref, + force=request.task_object.force, + recursive=request.task_object.recursive) + except Exception as e: + return_exception_in_context(e, context) + elif request.WhichOneof("terminate_type") == "actor": + try: + actor_ref = cloudpickle.loads(request.actor.handle) + ray.kill(actor_ref, no_restart=request.actor.no_restart) + except Exception as e: + return_exception_in_context(e, context) + else: + raise RuntimeError( + "Client requested termination without providing a valid terminate_type" + ) + return ray_client_pb2.TerminateResponse(ok=True) + def GetObject(self, request, context=None): - if request.id not in self.object_refs: + request_ref = cloudpickle.loads(request.handle) + if request_ref.binary() not in self.object_refs: return ray_client_pb2.GetResponse(valid=False) - objectref = self.object_refs[request.id] + objectref = self.object_refs[request_ref.binary()] logger.info("get: %s" % objectref) - item = ray.get(objectref) + try: + item = ray.get(objectref, timeout=request.timeout) + except Exception as e: + return_exception_in_context(e, context) item_ser = cloudpickle.dumps(item) return ray_client_pb2.GetResponse(valid=True, data=item_ser) def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse: obj = cloudpickle.loads(request.data) objectref = self._put_and_retain_obj(obj) - return ray_client_pb2.PutResponse(id=objectref.binary()) + pickled_ref = cloudpickle.dumps(objectref) + return 
ray_client_pb2.PutResponse( + ref=make_remote_ref(objectref.binary(), pickled_ref)) def _put_and_retain_obj(self, obj) -> ray.ObjectRef: objectref = ray.put(obj) @@ -44,14 +109,14 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): return objectref def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse: - object_refs = [cloudpickle.loads(o) for o in request.object_refs] + object_refs = [cloudpickle.loads(o) for o in request.object_handles] num_returns = request.num_returns timeout = request.timeout object_refs_ids = [] for object_ref in object_refs: - if object_ref.id not in self.object_refs: + if object_ref.binary() not in self.object_refs: return ray_client_pb2.WaitResponse(valid=False) - object_refs_ids.append(self.object_refs[object_ref.id]) + object_refs_ids.append(self.object_refs[object_ref.binary()]) try: ready_object_refs, remaining_object_refs = ray.wait( object_refs_ids, @@ -63,11 +128,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): logger.info("wait: %s %s" % (str(ready_object_refs), str(remaining_object_refs))) ready_object_ids = [ - ready_object_ref.binary() for ready_object_ref in ready_object_refs + make_remote_ref( + id=ready_object_ref.binary(), + handle=cloudpickle.dumps(ready_object_ref), + ) for ready_object_ref in ready_object_refs ] remaining_object_ids = [ - remaining_object_ref.binary() - for remaining_object_ref in remaining_object_refs + make_remote_ref( + id=remaining_object_ref.binary(), + handle=cloudpickle.dumps(remaining_object_ref), + ) for remaining_object_ref in remaining_object_refs ] return ray_client_pb2.WaitResponse( valid=True, @@ -103,41 +173,46 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): with stash_api_for_tests(self._test_mode): output = getattr(actor_handle, task.name).remote(*arglist) self.object_refs[output.binary()] = output - return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) + pickled_ref = cloudpickle.dumps(output) + 
return ray_client_pb2.ClientTaskTicket( + return_ref=make_remote_ref(output.binary(), pickled_ref)) def _schedule_actor(self, task: ray_client_pb2.ClientTask, context=None, prepared_args=None) -> ray_client_pb2.ClientTaskTicket: with stash_api_for_tests(self._test_mode): - if task.payload_id not in self.registered_actor_classes: - actor_class_ref = self.object_refs[task.payload_id] + payload_ref = cloudpickle.loads(task.payload_id) + if payload_ref.binary() not in self.registered_actor_classes: + actor_class_ref = self.object_refs[payload_ref.binary()] actor_class = ray.get(actor_class_ref) if not inspect.isclass(actor_class): raise Exception("Attempting to schedule actor that " "isn't a class.") reg_class = ray.remote(actor_class) - self.registered_actor_classes[task.payload_id] = reg_class - remote_class = self.registered_actor_classes[task.payload_id] + self.registered_actor_classes[payload_ref.binary()] = reg_class + remote_class = self.registered_actor_classes[payload_ref.binary()] arglist = _convert_args(task.args, prepared_args) actor = remote_class.remote(*arglist) actorhandle = cloudpickle.dumps(actor) self.actor_refs[actorhandle] = actor - return ray_client_pb2.ClientTaskTicket(return_id=actorhandle) + return ray_client_pb2.ClientTaskTicket( + return_ref=make_remote_ref(actor._actor_id.binary(), actorhandle)) def _schedule_function( self, task: ray_client_pb2.ClientTask, context=None, prepared_args=None) -> ray_client_pb2.ClientTaskTicket: - if task.payload_id not in self.function_refs: - funcref = self.object_refs[task.payload_id] + payload_ref = cloudpickle.loads(task.payload_id) + if payload_ref.binary() not in self.function_refs: + funcref = self.object_refs[payload_ref.binary()] func = ray.get(funcref) if not inspect.isfunction(func): raise Exception("Attempting to schedule function that " "isn't a function.") - self.function_refs[task.payload_id] = ray.remote(func) - remote_func = self.function_refs[task.payload_id] + 
self.function_refs[payload_ref.binary()] = ray.remote(func) + remote_func = self.function_refs[payload_ref.binary()] arglist = _convert_args(task.args, prepared_args) # Prepare call if we're in a test with stash_api_for_tests(self._test_mode): @@ -145,7 +220,9 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): if output.binary() in self.object_refs: raise Exception("already found it") self.object_refs[output.binary()] = output - return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) + pickled_output = cloudpickle.dumps(output) + return ray_client_pb2.ClientTaskTicket( + return_ref=make_remote_ref(output.binary(), pickled_output)) def _convert_args(arg_list, prepared_args=None): @@ -155,12 +232,25 @@ def _convert_args(arg_list, prepared_args=None): for arg in arg_list: t = convert_from_arg(arg) if isinstance(t, ClientObjectRef): - out.append(ray.ObjectRef(t.id)) + out.append(t._unpack_ref()) else: out.append(t) return out +def make_remote_ref(id: bytes, handle: bytes) -> ray_client_pb2.RemoteRef: + return ray_client_pb2.RemoteRef( + id=id, + handle=handle, + ) + + +def return_exception_in_context(err, context): + if context is not None: + context.set_details(encode_exception(err)) + context.set_code(grpc.StatusCode.INTERNAL) + + def serve(connection_str, test_mode=False): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) task_servicer = RayletServicer(test_mode=test_mode) diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index 87e5f6897..0a108e4f2 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -3,19 +3,25 @@ It implements the Ray API functions that are forwarded through grpc calls to the server. 
""" import inspect +import json import logging +from typing import Any from typing import List from typing import Tuple +from typing import Optional import ray.cloudpickle as cloudpickle from ray.util.inspect import is_cython import grpc +from ray.exceptions import TaskCancelledError import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.experimental.client.common import convert_to_arg +from ray.experimental.client.common import decode_exception from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientActorClass +from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientRemoteFunc logger = logging.getLogger(__name__) @@ -35,6 +41,7 @@ class Worker: metadata: additional metadata passed in the grpc request headers. """ self.metadata = metadata + self.channel = None if stub is None: if secure: credentials = grpc.ssl_channel_credentials() @@ -45,28 +52,32 @@ class Worker: else: self.server = stub - def get(self, ids): + def get(self, vals, *, timeout: Optional[float] = None) -> Any: to_get = [] single = False - if isinstance(ids, list): - to_get = [x.id for x in ids] - elif isinstance(ids, ClientObjectRef): - to_get = [ids.id] + if isinstance(vals, list): + to_get = [x.handle for x in vals] + elif isinstance(vals, ClientObjectRef): + to_get = [vals.handle] single = True else: raise Exception("Can't get something that's not a " - "list of IDs or just an ID: %s" % type(ids)) - out = [self._get(x) for x in to_get] + "list of IDs or just an ID: %s" % type(vals)) + if timeout is None: + timeout = 0 + out = [self._get(x, timeout) for x in to_get] if single: out = out[0] return out - def _get(self, id: bytes): - req = ray_client_pb2.GetRequest(id=id) - data = self.server.GetObject(req, metadata=self.metadata) + def _get(self, handle: bytes, timeout: float): + req = 
ray_client_pb2.GetRequest(handle=handle, timeout=timeout) + try: + data = self.server.GetObject(req, metadata=self.metadata) + except grpc.RpcError as e: + raise decode_exception(e.details()) if not data.valid: - raise Exception( - "Client GetObject returned invalid data: id invalid?") + raise TaskCancelledError(handle) return cloudpickle.loads(data.data) def put(self, vals): @@ -87,7 +98,7 @@ class Worker: data = cloudpickle.dumps(val) req = ray_client_pb2.PutRequest(data=data) resp = self.server.PutObject(req, metadata=self.metadata) - return ClientObjectRef(resp.id) + return ClientObjectRef.from_remote_ref(resp.ref) def wait(self, object_refs: List[ClientObjectRef], @@ -99,8 +110,8 @@ class Worker: for ref in object_refs: assert isinstance(ref, ClientObjectRef) data = { - "object_refs": [ - cloudpickle.dumps(object_ref) for object_ref in object_refs + "object_handles": [ + object_ref.handle for object_ref in object_refs ], "num_returns": num_returns, "timeout": timeout if timeout else -1 @@ -111,10 +122,12 @@ class Worker: # TODO(ameer): improve error/exceptions messages. raise Exception("Client Wait request failed. 
Reference invalid?") client_ready_object_ids = [ - ClientObjectRef(id) for id in resp.ready_object_ids + ClientObjectRef.from_remote_ref(ref) + for ref in resp.ready_object_ids ] client_remaining_object_ids = [ - ClientObjectRef(id) for id in resp.remaining_object_ids + ClientObjectRef.from_remote_ref(ref) + for ref in resp.remaining_object_ids ] return (client_ready_object_ids, client_remaining_object_ids) @@ -138,7 +151,53 @@ class Worker: task.args.append(pb_arg) logging.debug("Scheduling %s" % task) ticket = self.server.Schedule(task, metadata=self.metadata) - return ClientObjectRef(ticket.return_id) + return ClientObjectRef.from_remote_ref(ticket.return_ref) def close(self): - self.channel.close() + self.server = None + if self.channel: + self.channel.close() + + def terminate_actor(self, actor: ClientActorHandle, + no_restart: bool) -> None: + if not isinstance(actor, ClientActorHandle): + raise ValueError("ray.kill() only supported for actors. " + "Got: {}.".format(type(actor))) + term_actor = ray_client_pb2.TerminateRequest.ActorTerminate() + term_actor.handle = actor.actor_ref.handle + term_actor.no_restart = no_restart + try: + term = ray_client_pb2.TerminateRequest(actor=term_actor) + self.server.Terminate(term) + except grpc.RpcError as e: + raise decode_exception(e.details()) + + def terminate_task(self, obj: ClientObjectRef, force: bool, + recursive: bool) -> None: + if not isinstance(obj, ClientObjectRef): + raise TypeError( + "ray.cancel() only supported for non-actor object refs. 
" + f"Got: {type(obj)}.") + term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate() + term_object.handle = obj.handle + term_object.force = force + term_object.recursive = recursive + try: + term = ray_client_pb2.TerminateRequest(task_object=term_object) + self.server.Terminate(term) + except grpc.RpcError as e: + raise decode_exception(e.details()) + + def get_cluster_info(self, type: ray_client_pb2.ClusterInfoType.TypeEnum): + req = ray_client_pb2.ClusterInfoRequest() + req.type = type + resp = self.server.ClusterInfo(req) + if resp.WhichOneof("response_type") == "resource_table": + return resp.resource_table.table + return json.loads(resp.json) + + def is_initialized(self) -> bool: + if self.server is not None: + return self.get_cluster_info( + ray_client_pb2.ClusterInfoType.IS_INITIALIZED) + return False diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index ad9ab4ec0..573131fec 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -95,6 +95,8 @@ py_test_module_list( "test_dask_callback.py", "test_debug_tools.py", "test_experimental_client.py", + "test_experimental_client_metadata.py", + "test_experimental_client_terminate.py", "test_job.py", "test_memstat.py", "test_metrics_agent.py", diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index 430574dd2..cbe52675f 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -10,10 +10,12 @@ from ray.experimental.client.common import ClientObjectRef def ray_start_client_server(): server = ray_client_server.serve("localhost:50051", test_mode=True) ray.connect("localhost:50051") - yield ray - ray.disconnect() - server.stop(0) - reset_api() + try: + yield ray + finally: + ray.disconnect() + server.stop(0) + reset_api() def test_real_ray_fallback(ray_start_regular_shared): @@ -35,9 +37,6 @@ def test_real_ray_fallback(ray_start_regular_shared): nodes = 
ray.get(get_nodes.remote()) assert len(nodes) == 1, nodes - with pytest.raises(NotImplementedError): - print(ray.nodes()) - def test_nested_function(ray_start_regular_shared): with ray_start_client_server() as ray: diff --git a/python/ray/tests/test_experimental_client_metadata.py b/python/ray/tests/test_experimental_client_metadata.py new file mode 100644 index 000000000..d0bb86c9e --- /dev/null +++ b/python/ray/tests/test_experimental_client_metadata.py @@ -0,0 +1,25 @@ +from ray.tests.test_experimental_client import ray_start_client_server + + +def test_get_ray_metadata(ray_start_regular_shared): + """ + Test the ClusterInfo client data pathway and API surface + """ + with ray_start_client_server() as ray: + ip_address = ray_start_regular_shared["node_ip_address"] + + initialized = ray.is_initialized() + assert initialized + + nodes = ray.nodes() + assert len(nodes) == 1, nodes + assert nodes[0]["NodeManagerAddress"] == ip_address + + current_node_id = "node:" + ip_address + + cluster_resources = ray.cluster_resources() + available_resources = ray.available_resources() + + assert cluster_resources["CPU"] == 1.0 + assert current_node_id in cluster_resources + assert current_node_id in available_resources diff --git a/python/ray/tests/test_experimental_client_terminate.py b/python/ray/tests/test_experimental_client_terminate.py new file mode 100644 index 000000000..e44c617e6 --- /dev/null +++ b/python/ray/tests/test_experimental_client_terminate.py @@ -0,0 +1,99 @@ +import pytest +import asyncio +from ray.tests.test_experimental_client import ray_start_client_server +from ray.test_utils import wait_for_condition +from ray.exceptions import TaskCancelledError +from ray.exceptions import RayTaskError +from ray.exceptions import WorkerCrashedError +from ray.exceptions import ObjectLostError +from ray.exceptions import GetTimeoutError + + +def valid_exceptions(use_force): + if use_force: + return (RayTaskError, TaskCancelledError, WorkerCrashedError, + 
ObjectLostError) + else: + return (RayTaskError, TaskCancelledError) + + +def _all_actors_dead(ray): + import ray as real_ray + + def _all_actors_dead_internal(): + return all(actor["State"] == real_ray.gcs_utils.ActorTableData.DEAD + for actor in list(real_ray.actors().values())) + + return _all_actors_dead_internal + + +def test_kill_actor_immediately_after_creation(ray_start_regular): + with ray_start_client_server() as ray: + + @ray.remote + class A: + pass + + a = A.remote() + b = A.remote() + + ray.kill(a) + ray.kill(b) + wait_for_condition(_all_actors_dead(ray), timeout=10) + + +@pytest.mark.parametrize("use_force", [True, False]) +def test_cancel_chain(ray_start_regular, use_force): + with ray_start_client_server() as ray: + + @ray.remote + class SignalActor: + def __init__(self): + self.ready_event = asyncio.Event() + + def send(self, clear=False): + self.ready_event.set() + if clear: + self.ready_event.clear() + + async def wait(self, should_wait=True): + if should_wait: + await self.ready_event.wait() + + signaler = SignalActor.remote() + + @ray.remote + def wait_for(t): + return ray.get(t[0]) + + obj1 = wait_for.remote([signaler.wait.remote()]) + obj2 = wait_for.remote([obj1]) + obj3 = wait_for.remote([obj2]) + obj4 = wait_for.remote([obj3]) + + assert len(ray.wait([obj1], timeout=.1)[0]) == 0 + ray.cancel(obj1, force=use_force) + for ob in [obj1, obj2, obj3, obj4]: + with pytest.raises(valid_exceptions(use_force)): + ray.get(ob) + + signaler2 = SignalActor.remote() + obj1 = wait_for.remote([signaler2.wait.remote()]) + obj2 = wait_for.remote([obj1]) + obj3 = wait_for.remote([obj2]) + obj4 = wait_for.remote([obj3]) + + assert len(ray.wait([obj3], timeout=.1)[0]) == 0 + ray.cancel(obj3, force=use_force) + for ob in [obj3, obj4]: + with pytest.raises(valid_exceptions(use_force)): + ray.get(ob) + + with pytest.raises(GetTimeoutError): + ray.get(obj1, timeout=.1) + + with pytest.raises(GetTimeoutError): + ray.get(obj2, timeout=.1) + + signaler2.send.remote() 
+ ray.get(obj1) diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index fd8fe5345..d4c392321 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -42,8 +42,13 @@ message ClientTask { repeated Arg args = 4; } +message RemoteRef { + bytes id = 1; + bytes handle = 2; +} + message ClientTaskTicket { - bytes return_id = 1; + RemoteRef return_ref = 1; } message PutRequest { @@ -51,27 +56,75 @@ message PutRequest { } message PutResponse { - bytes id = 1; + RemoteRef ref = 1; } message GetRequest { - bytes id = 1; + bytes handle = 1; + float timeout = 2; } message GetResponse { bool valid = 1; bytes data = 2; } + message WaitRequest { - repeated bytes object_refs = 1; + repeated bytes object_handles = 1; int64 num_returns = 2; double timeout = 3; } message WaitResponse { bool valid = 1; - repeated bytes ready_object_ids = 2; - repeated bytes remaining_object_ids = 3; + repeated RemoteRef ready_object_ids = 2; + repeated RemoteRef remaining_object_ids = 3; +} + +message ClusterInfoType { + // Namespace the enum, as it collides in the overall package. 
+ enum TypeEnum { + IS_INITIALIZED = 0; + NODES = 1; + CLUSTER_RESOURCES = 2; + AVAILABLE_RESOURCES = 3; + } +} + +message ClusterInfoRequest { + ClusterInfoType.TypeEnum type = 1; +} + +message ClusterInfoResponse { + message ResourceTable { + map<string, double> table = 1; + } + ClusterInfoType.TypeEnum type = 1; + oneof response_type { + string json = 2; + ResourceTable resource_table = 3; + } +} + +message TerminateRequest { + message ActorTerminate { + bytes handle = 1; + bool no_restart = 2; + } + message TaskObjectTerminate { + bytes handle = 1; + bool force = 2; + bool recursive = 3; + } + + oneof terminate_type { + ActorTerminate actor = 1; + TaskObjectTerminate task_object = 2; + } +} + +message TerminateResponse { + bool ok = 1; } service RayletDriver { @@ -83,4 +136,8 @@ service RayletDriver { } rpc Schedule(ClientTask) returns (ClientTaskTicket) { } + rpc Terminate(TerminateRequest) returns (TerminateResponse) { + } + rpc ClusterInfo(ClusterInfoRequest) returns (ClusterInfoResponse) { + } }