diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py index a6bba39ed..2af86d023 100644 --- a/python/ray/experimental/client/__init__.py +++ b/python/ray/experimental/client/__init__.py @@ -91,15 +91,22 @@ def _get_client_api() -> APIImpl: return api +def _get_server_instance(): + """Used inside tests to inspect the running server. + """ + global _server_api + if _server_api is not None: + return _server_api.server + + class RayAPIStub: def connect(self, conn_str: str, secure: bool = False, metadata: List[Tuple[str, str]] = None, - stub=None): + stub=None) -> None: from ray.experimental.client.worker import Worker - _client_worker = Worker( - conn_str, secure=secure, metadata=metadata, stub=stub) + _client_worker = Worker(conn_str, secure=secure, metadata=metadata) _set_client_api(ClientAPI(_client_worker)) def disconnect(self): @@ -113,6 +120,10 @@ class RayAPIStub: api = _get_client_api() return getattr(api, key) + def is_connected(self) -> bool: + global _client_api + return _client_api is not None + ray = RayAPIStub() diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py index 304cc4467..5167e5988 100644 --- a/python/ray/experimental/client/api.py +++ b/python/ray/experimental/client/api.py @@ -138,6 +138,31 @@ class APIImpl(ABC): """ pass + @abstractmethod + def call_release(self, id: bytes) -> None: + """ + Attempts to release an object reference. + + When client references are destructed, they release their reference, + which can opportunistically send a notification through the datachannel + to release the reference being held for that object on the server. + + Args: + id: The id of the reference to release on the server side. + """ + + @abstractmethod + def call_retain(self, id: bytes) -> None: + """ + Attempts to retain a client object reference. 
+ + Increments the reference count on the client side, to prevent + the client worker from attempting to release the server reference. + + Args: + id: The id of the reference to retain on the client side. + """ + class ClientAPI(APIImpl): """ @@ -163,6 +188,12 @@ class ClientAPI(APIImpl): def call_remote(self, instance: "ClientStub", *args, **kwargs): return self.worker.call_remote(instance, *args, **kwargs) + def call_release(self, id: bytes) -> None: + return self.worker.call_release(id) + + def call_retain(self, id: bytes) -> None: + return self.worker.call_retain(id) + def close(self) -> None: return self.worker.close() diff --git a/python/ray/experimental/client/client_pickler.py b/python/ray/experimental/client/client_pickler.py new file mode 100644 index 000000000..73df31c0e --- /dev/null +++ b/python/ray/experimental/client/client_pickler.py @@ -0,0 +1,123 @@ +""" +Implements the client side of the client/server pickling protocol. + +All ray client client/server data transfer happens through this pickling +protocol. The model is as follows: + + * All Client objects (eg ClientObjectRef) always live on the client and + are never represented in the server + * All Ray objects (eg, ray.ObjectRef) always live on the server and are + never returned to the client + * In order to translate between these two references, PickleStub tuples + are generated as persistent ids in the data blobs during the pickling + and unpickling of these objects. + +The PickleStubs have just enough information to find or generate their +associated partner object on either side. + +This also has the advantage of avoiding predefined pickle behavior for ray +objects, which may include ray internal reference counting. + +ClientPickler dumps things from the client into the appropriate stubs +ServerUnpickler loads stubs from the server into their client counterparts. 
+""" + +import cloudpickle +import io +import sys + +from typing import NamedTuple +from typing import Any + +from ray.experimental.client.common import ClientObjectRef +from ray.experimental.client.common import ClientActorHandle +from ray.experimental.client.common import ClientActorRef +from ray.experimental.client.common import ClientRemoteFunc +from ray.experimental.client.common import SelfReferenceSentinel +import ray.core.generated.ray_client_pb2 as ray_client_pb2 + +if sys.version_info < (3, 8): + try: + import pickle5 as pickle # noqa: F401 + except ImportError: + import pickle # noqa: F401 +else: + import pickle # noqa: F401 + +PickleStub = NamedTuple("PickleStub", [("type", str), ("client_id", str), + ("ref_id", bytes)]) + + +class ClientPickler(cloudpickle.CloudPickler): + def __init__(self, client_id, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client_id = client_id + + def persistent_id(self, obj): + if isinstance(obj, ClientObjectRef): + return PickleStub( + type="Object", + client_id=self.client_id, + ref_id=obj.id, + ) + elif isinstance(obj, ClientActorHandle): + return PickleStub( + type="Actor", + client_id=self.client_id, + ref_id=obj._actor_id, + ) + elif isinstance(obj, ClientRemoteFunc): + # TODO(barakmich): This is going to have trouble with mutually + # recursive functions that haven't, as yet, been executed. It's + # relatively doable (keep track of intermediate refs in progress + # with ensure_ref and return appropriately) But punting for now. 
+ if obj._ref is None: + obj._ensure_ref() + if type(obj._ref) == SelfReferenceSentinel: + return PickleStub( + type="RemoteFuncSelfReference", + client_id=self.client_id, + ref_id=b"") + return PickleStub( + type="RemoteFunc", + client_id=self.client_id, + ref_id=obj._ref.id) + return None + + +class ServerUnpickler(pickle.Unpickler): + def persistent_load(self, pid): + assert isinstance(pid, PickleStub) + if pid.type == "Object": + return ClientObjectRef(id=pid.ref_id) + elif pid.type == "Actor": + return ClientActorHandle(ClientActorRef(id=pid.ref_id)) + else: + raise NotImplementedError("Being passed back an unknown stub") + + +def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes: + with io.BytesIO() as file: + cp = ClientPickler(client_id, file, protocol=protocol) + cp.dump(obj) + return file.getvalue() + + +def loads_from_server(data: bytes, + *, + fix_imports=True, + encoding="ASCII", + errors="strict") -> Any: + if isinstance(data, str): + raise TypeError("Can't load pickle from unicode string") + file = io.BytesIO(data) + return ServerUnpickler( + file, fix_imports=fix_imports, encoding=encoding, + errors=errors).load() + + +def convert_to_arg(val: Any, client_id: str) -> ray_client_pb2.Arg: + out = ray_client_pb2.Arg() + out.local = ray_client_pb2.Arg.Locality.INTERNED + out.data = dumps_from_client(val, client_id) + return out diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index 24b012790..74f11c2c2 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -1,16 +1,12 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.experimental.client import ray -from typing import Any from typing import Dict -from ray import cloudpickle - -import base64 class ClientBaseRef: - def __init__(self, id, handle=None): - self.id = id - self.handle = handle + def __init__(self, id: bytes): + self.id: bytes = id + ray.call_retain(id) def 
__repr__(self): return "%s(%s)" % ( @@ -24,14 +20,13 @@ class ClientBaseRef: def binary(self): return self.id - @classmethod - def from_remote_ref(cls, ref: ray_client_pb2.RemoteRef): - return cls(id=ref.id, handle=ref.handle) + def __del__(self): + if ray.is_connected(): + ray.call_release(self.id) class ClientObjectRef(ClientBaseRef): - def _unpack_ref(self): - return cloudpickle.loads(self.handle) + pass class ClientActorRef(ClientBaseRef): @@ -53,50 +48,42 @@ class ClientRemoteFunc(ClientStub): _func: The actual function to execute remotely _name: The original name of the function _ref: The ClientObjectRef of the pickled code of the function, _func - _raylet_remote: The Raylet-side ray.remote_function.RemoteFunction - for this object """ def __init__(self, f): self._func = f self._name = f.__name__ - self.id = None - - # self._ref can be lazily instantiated. Rather than eagerly creating - # function data objects in the server we can put them just before we - # execute the function, especially in cases where many @ray.remote - # functions exist in a library and only a handful are ever executed by - # a user of the library. - # - # TODO(barakmich): This ref might actually be better as a serialized - # ObjectRef. This requires being able to serialize the ref without - # pinning it (as the lifetime of the ref is tied with the server, not - # the client) self._ref = None - self._raylet_remote = None def __call__(self, *args, **kwargs): raise TypeError(f"Remote function cannot be called directly. 
" "Use {self._name}.remote method instead") def remote(self, *args, **kwargs): - return ray.call_remote(self, *args, **kwargs) - - def _get_ray_remote_impl(self): - if self._raylet_remote is None: - self._raylet_remote = ray.remote(self._func) - return self._raylet_remote + return ClientObjectRef(ray.call_remote(self, *args, **kwargs)) def __repr__(self): return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref) - def _prepare_client_task(self) -> ray_client_pb2.ClientTask: + def _ensure_ref(self): if self._ref is None: + # While calling ray.put() on our function, if + # our function is recursive, it will attempt to + # encode the ClientRemoteFunc -- itself -- and + # infinitely recurse on _ensure_ref. + # + # So we set the state of the reference to be an + # in-progress self reference value, which + # the encoding can detect and handle correctly. + self._ref = SelfReferenceSentinel() self._ref = ray.put(self._func) + + def _prepare_client_task(self) -> ray_client_pb2.ClientTask: + self._ensure_ref() task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.FUNCTION task.name = self._name - task.payload_id = self._ref.handle + task.payload_id = self._ref.id return task @@ -109,14 +96,12 @@ class ClientActorClass(ClientStub): actor_cls: The actual class to execute remotely _name: The original name of the class _ref: The ClientObjectRef of the pickled `actor_cls` - _raylet_remote: The Raylet-side ray.ActorClass for this object """ def __init__(self, actor_cls): self.actor_cls = actor_cls self._name = actor_cls.__name__ self._ref = None - self._raylet_remote = None def __call__(self, *args, **kwargs): raise TypeError(f"Remote actor cannot be instantiated directly. 
" @@ -135,10 +120,10 @@ class ClientActorClass(ClientStub): self._name = state["_name"] self._ref = state["_ref"] - def remote(self, *args, **kwargs): + def remote(self, *args, **kwargs) -> "ClientActorHandle": # Actually instantiate the actor - ref = ray.call_remote(self, *args, **kwargs) - return ClientActorHandle(ClientActorRef(ref.id, ref.handle), self) + ref_id = ray.call_remote(self, *args, **kwargs) + return ClientActorHandle(ClientActorRef(ref_id), self) def __repr__(self): return "ClientRemoteActor(%s, %s)" % (self._name, self._ref) @@ -154,7 +139,7 @@ class ClientActorClass(ClientStub): task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.ACTOR task.name = self._name - task.payload_id = self._ref.handle + task.payload_id = self._ref.id return task @@ -177,26 +162,9 @@ class ClientActorHandle(ClientStub): def __init__(self, actor_ref: ClientActorRef, actor_class: ClientActorClass): self.actor_ref = actor_ref - self.actor_class = actor_class - self._real_actor_handle = None - def _get_ray_remote_impl(self): - if self._real_actor_handle is None: - self._real_actor_handle = cloudpickle.loads(self.actor_ref.handle) - return self._real_actor_handle - - def __getstate__(self) -> Dict: - state = { - "actor_ref": self.actor_ref, - "actor_class": self.actor_class, - "_real_actor_handle": self._real_actor_handle, - } - return state - - def __setstate__(self, state: Dict) -> None: - self.actor_ref = state["actor_ref"] - self.actor_class = state["actor_class"] - self._real_actor_handle = state["_real_actor_handle"] + def __del__(self) -> None: + ray.call_release(self.actor_ref.id) @property def _actor_id(self): @@ -226,65 +194,27 @@ class ClientRemoteMethod(ClientStub): def __call__(self, *args, **kwargs): raise TypeError(f"Remote method cannot be called directly. 
" - "Use {self._name}.remote() instead") - - def _get_ray_remote_impl(self): - return getattr(self.actor_handle._get_ray_remote_impl(), - self.method_name) - - def __getstate__(self) -> Dict: - state = { - "actor_handle": self.actor_handle, - "method_name": self.method_name, - } - return state - - def __setstate__(self, state: Dict) -> None: - self.actor_handle = state["actor_handle"] - self.method_name = state["method_name"] + f"Use {self._name}.remote() instead") def remote(self, *args, **kwargs): - return ray.call_remote(self, *args, **kwargs) + return ClientObjectRef(ray.call_remote(self, *args, **kwargs)) def __repr__(self): - name = "%s.%s" % (self.actor_handle.actor_class._name, - self.method_name) - return "ClientRemoteMethod(%s, %s)" % (name, - self.actor_handle.actor_id) + return "ClientRemoteMethod(%s, %s)" % (self.method_name, + self.actor_handle) def _prepare_client_task(self) -> ray_client_pb2.ClientTask: task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.METHOD task.name = self.method_name - task.payload_id = self.actor_handle.actor_ref.handle + task.payload_id = self.actor_handle.actor_ref.id return task -def convert_from_arg(pb) -> Any: - if pb.local == ray_client_pb2.Arg.Locality.REFERENCE: - return ClientObjectRef(pb.reference_id) - elif pb.local == ray_client_pb2.Arg.Locality.INTERNED: - return cloudpickle.loads(pb.data) - - raise Exception("convert_from_arg: Uncovered locality enum") +class DataEncodingSentinel: + def __repr__(self) -> str: + return self.__class__.__name__ -def convert_to_arg(val): - out = ray_client_pb2.Arg() - if isinstance(val, ClientObjectRef): - out.local = ray_client_pb2.Arg.Locality.REFERENCE - out.reference_id = val.id - else: - out.local = ray_client_pb2.Arg.Locality.INTERNED - out.data = cloudpickle.dumps(val) - return out - - -def encode_exception(exception) -> str: - data = cloudpickle.dumps(exception) - return base64.standard_b64encode(data).decode() - - -def decode_exception(data) -> 
Exception: - data = base64.standard_b64decode(data) - return cloudpickle.loads(data) +class SelfReferenceSentinel(DataEncodingSentinel): + pass diff --git a/python/ray/experimental/client/dataclient.py b/python/ray/experimental/client/dataclient.py new file mode 100644 index 000000000..7e16c015b --- /dev/null +++ b/python/ray/experimental/client/dataclient.py @@ -0,0 +1,103 @@ +""" +This file implements a threaded stream controller to abstract a data stream +back to the ray clientserver. +""" +import logging +import queue +import threading +import grpc + +from typing import Any +from typing import Dict + +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc + +logger = logging.getLogger(__name__) + +# The maximum field value for request_id -- which is also the maximum +# number of simultaneous in-flight requests. +INT32_MAX = (2**31) - 1 + + +class DataClient: + def __init__(self, channel: "grpc._channel.Channel", client_id: str): + """Initializes a thread-safe datapath over a Ray Client gRPC channel. + + Args: + channel: connected gRPC channel + """ + self.channel = channel + self.request_queue = queue.Queue() + self.data_thread = self._start_datathread() + self.ready_data: Dict[int, Any] = {} + self.cv = threading.Condition() + self._req_id = 0 + self._client_id = client_id + self.data_thread.start() + + def _next_id(self) -> int: + self._req_id += 1 + if self._req_id > INT32_MAX: + self._req_id = 1 + # Responses that aren't tracked (like opportunistic releases) + # have req_id=0, so make sure we never mint such an id. 
+ assert self._req_id != 0 + return self._req_id + + def _start_datathread(self) -> threading.Thread: + return threading.Thread(target=self._data_main, args=(), daemon=True) + + def _data_main(self) -> None: + stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel) + resp_stream = stub.Datapath( + iter(self.request_queue.get, None), + metadata=(("client_id", self._client_id), )) + for response in resp_stream: + if response.req_id == 0: + # This is not being waited for. + logger.debug(f"Got unawaited response {response}") + continue + with self.cv: + self.ready_data[response.req_id] = response + self.cv.notify_all() + + def close(self, close_channel: bool = False) -> None: + if self.request_queue is not None: + self.request_queue.put(None) + self.request_queue = None + if self.data_thread is not None: + self.data_thread.join() + self.data_thread = None + if close_channel: + self.channel.close() + + def _blocking_send(self, req: ray_client_pb2.DataRequest + ) -> ray_client_pb2.DataResponse: + req_id = self._next_id() + req.req_id = req_id + self.request_queue.put(req) + data = None + with self.cv: + self.cv.wait_for(lambda: req_id in self.ready_data) + data = self.ready_data[req_id] + del self.ready_data[req_id] + return data + + def GetObject(self, request: ray_client_pb2.GetRequest, + context=None) -> ray_client_pb2.GetResponse: + datareq = ray_client_pb2.DataRequest(get=request, ) + resp = self._blocking_send(datareq) + return resp.get + + def PutObject(self, request: ray_client_pb2.PutRequest, + context=None) -> ray_client_pb2.PutResponse: + datareq = ray_client_pb2.DataRequest(put=request, ) + resp = self._blocking_send(datareq) + return resp.put + + def ReleaseObject(self, + request: ray_client_pb2.ReleaseRequest, + context=None) -> None: + datareq = ray_client_pb2.DataRequest(release=request, ) + self.request_queue.put(datareq) diff --git a/python/ray/experimental/client/server/core_ray_api.py b/python/ray/experimental/client/server/core_ray_api.py 
index 6513021a8..2d930f352 100644 --- a/python/ray/experimental/client/server/core_ray_api.py +++ b/python/ray/experimental/client/server/core_ray_api.py @@ -11,12 +11,15 @@ from typing import Any from typing import Optional from typing import Union +import logging import ray from ray.experimental.client.api import APIImpl from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientStub +logger = logging.getLogger(__name__) + class CoreRayAPI(APIImpl): """ @@ -26,12 +29,6 @@ class CoreRayAPI(APIImpl): """ def get(self, vals, *, timeout: Optional[float] = None) -> Any: - if isinstance(vals, list): - if isinstance(vals[0], ClientObjectRef): - return ray.get( - [val._unpack_ref() for val in vals], timeout=timeout) - elif isinstance(vals, ClientObjectRef): - return ray.get(vals._unpack_ref(), timeout=timeout) return ray.get(vals, timeout=timeout) def put(self, vals: Any, *args, @@ -45,7 +42,8 @@ class CoreRayAPI(APIImpl): return ray.remote(*args, **kwargs) def call_remote(self, instance: ClientStub, *args, **kwargs): - return instance._get_ray_remote_impl().remote(*args, **kwargs) + raise NotImplementedError( + "Should not attempt execution of a client stub inside the raylet") def close(self) -> None: return None @@ -59,6 +57,12 @@ class CoreRayAPI(APIImpl): def is_initialized(self) -> bool: return ray.is_initialized() + def call_release(self, id: bytes) -> None: + return None + + def call_retain(self, id: bytes) -> None: + return None + # Allow for generic fallback to ray.* in remote methods. This allows calls # like ray.nodes() to be run in remote functions even though the client # doesn't currently support them. @@ -76,26 +80,7 @@ class RayServerAPI(CoreRayAPI): def __init__(self, server_instance): self.server = server_instance - # Wrap single item into list if needed before calling server put. 
- def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef: - to_put = [] - single = False - if isinstance(vals, list): - to_put = vals - else: - single = True - to_put.append(vals) - - out = [self._put(x) for x in to_put] - if single: - out = out[0] - return out - - def _put(self, val: Any): - resp = self.server._put_and_retain_obj(val) - return ClientObjectRef(resp.id) - - def call_remote(self, instance: ClientStub, *args, **kwargs): + def call_remote(self, instance: ClientStub, *args, **kwargs) -> bytes: task = instance._prepare_client_task() ticket = self.server.Schedule(task, prepared_args=args) - return ClientObjectRef(ticket.return_id) + return ticket.return_id diff --git a/python/ray/experimental/client/server/dataservicer.py b/python/ray/experimental/client/server/dataservicer.py new file mode 100644 index 000000000..874e741d9 --- /dev/null +++ b/python/ray/experimental/client/server/dataservicer.py @@ -0,0 +1,54 @@ +import logging +import grpc + +from typing import TYPE_CHECKING + +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc + +if TYPE_CHECKING: + from ray.experimental.client.server.server import RayletServicer + +logger = logging.getLogger(__name__) + + +class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): + def __init__(self, basic_service: "RayletServicer"): + self.basic_service = basic_service + + def Datapath(self, request_iterator, context): + metadata = {k: v for k, v in context.invocation_metadata()} + client_id = metadata["client_id"] + if client_id == "": + logger.error("Client connecting with no client_id") + return + logger.info(f"New data connection from client {client_id}") + try: + for req in request_iterator: + resp = None + req_type = req.WhichOneof("type") + if req_type == "get": + get_resp = self.basic_service._get_object( + req.get, client_id) + resp = ray_client_pb2.DataResponse(get=get_resp) + elif req_type == "put": + put_resp = 
self.basic_service._put_object( + req.put, client_id) + resp = ray_client_pb2.DataResponse(put=put_resp) + elif req_type == "release": + released = [] + for rel_id in req.release.ids: + rel = self.basic_service.release(client_id, rel_id) + released.append(rel) + resp = ray_client_pb2.DataResponse( + release=ray_client_pb2.ReleaseResponse(ok=released)) + else: + raise Exception(f"Unreachable code: Request type " + f"{req_type} not handled in Datapath") + resp.req_id = req.req_id + yield resp + except grpc.RpcError as e: + logger.debug(f"Closing channel: {e}") + finally: + logger.info(f"Lost data connection from client {client_id}") + self.basic_service.release_all(client_id) diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 616e6e60d..0fd34eda4 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -1,6 +1,12 @@ import logging from concurrent import futures import grpc +import base64 +from collections import defaultdict + +from typing import Dict +from typing import Set + from ray import cloudpickle import ray import ray.state @@ -10,21 +16,26 @@ import time import inspect import json from ray.experimental.client import stash_api_for_tests, _set_server_api -from ray.experimental.client.common import convert_from_arg -from ray.experimental.client.common import encode_exception -from ray.experimental.client.common import ClientObjectRef +from ray.experimental.client.server.server_pickler import convert_from_arg +from ray.experimental.client.server.server_pickler import dumps_from_server +from ray.experimental.client.server.server_pickler import loads_from_client from ray.experimental.client.server.core_ray_api import RayServerAPI +from ray.experimental.client.server.dataservicer import DataServicer +from ray.experimental.client.server.server_stubs import current_func logger = logging.getLogger(__name__) class 
RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): def __init__(self, test_mode=False): - self.object_refs = {} + self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict( + dict) self.function_refs = {} - self.actor_refs = {} + self.actor_refs: Dict[bytes, ray.ActorHandle] = {} + self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set) self.registered_actor_classes = {} self._test_mode = test_mode + self._current_function_stub = None def ClusterInfo(self, request, context=None) -> ray_client_pb2.ClusterInfoResponse: @@ -61,20 +72,59 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): raise TypeError("Unsupported cluster info type") return json.dumps(data) - def Terminate(self, request, context=None): - if request.WhichOneof("terminate_type") == "task_object": + def release(self, client_id: str, id: bytes) -> bool: + if client_id in self.object_refs: + if id in self.object_refs[client_id]: + logger.debug(f"Releasing object {id.hex()} for {client_id}") + del self.object_refs[client_id][id] + return True + + if client_id in self.actor_owners: + if id in self.actor_owners[client_id]: + logger.debug(f"Releasing actor {id.hex()} for {client_id}") + del self.actor_refs[id] + self.actor_owners[client_id].remove(id) + return True + + return False + + def release_all(self, client_id): + self._release_objects(client_id) + self._release_actors(client_id) + + def _release_objects(self, client_id): + if client_id not in self.object_refs: + logger.debug(f"Releasing client with no references: {client_id}") + return + count = len(self.object_refs[client_id]) + del self.object_refs[client_id] + logger.debug(f"Released all {count} objects for client {client_id}") + + def _release_actors(self, client_id): + if client_id not in self.actor_owners: + logger.debug(f"Releasing client with no actors: {client_id}") + count = 0 + for id_bytes in self.actor_owners[client_id]: + count += 1 + del self.actor_refs[id_bytes] + del self.actor_owners[client_id] 
+ logger.debug(f"Released all {count} actors for client: {client_id}") + + def Terminate(self, req, context=None): + if req.WhichOneof("terminate_type") == "task_object": try: - object_ref = cloudpickle.loads(request.task_object.handle) + object_ref = \ + self.object_refs[req.client_id][req.task_object.id] ray.cancel( object_ref, - force=request.task_object.force, - recursive=request.task_object.recursive) + force=req.task_object.force, + recursive=req.task_object.recursive) except Exception as e: return_exception_in_context(e, context) - elif request.WhichOneof("terminate_type") == "actor": + elif req.WhichOneof("terminate_type") == "actor": try: - actor_ref = cloudpickle.loads(request.actor.handle) - ray.kill(actor_ref, no_restart=request.actor.no_restart) + actor_ref = self.actor_refs[req.actor.id] + ray.kill(actor_ref, no_restart=req.actor.no_restart) except Exception as e: return_exception_in_context(e, context) else: @@ -84,61 +134,71 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): return ray_client_pb2.TerminateResponse(ok=True) def GetObject(self, request, context=None): - request_ref = cloudpickle.loads(request.handle) - if request_ref.binary() not in self.object_refs: + return self._get_object(request, "", context) + + def _get_object(self, request, client_id: str, context=None): + if request.id not in self.object_refs[client_id]: return ray_client_pb2.GetResponse(valid=False) - objectref = self.object_refs[request_ref.binary()] - logger.info("get: %s" % objectref) + objectref = self.object_refs[client_id][request.id] + logger.debug("get: %s" % objectref) try: item = ray.get(objectref, timeout=request.timeout) except Exception as e: - return_exception_in_context(e, context) - item_ser = cloudpickle.dumps(item) + return ray_client_pb2.GetResponse( + valid=False, error=cloudpickle.dumps(e)) + item_ser = dumps_from_server(item, client_id, self) return ray_client_pb2.GetResponse(valid=True, data=item_ser) - def PutObject(self, request, 
context=None) -> ray_client_pb2.PutResponse: - obj = cloudpickle.loads(request.data) - objectref = self._put_and_retain_obj(obj) - pickled_ref = cloudpickle.dumps(objectref) - return ray_client_pb2.PutResponse( - ref=make_remote_ref(objectref.binary(), pickled_ref)) + def PutObject(self, request: ray_client_pb2.PutRequest, + context=None) -> ray_client_pb2.PutResponse: + """gRPC entrypoint for unary PutObject + """ + return self._put_object(request, "", context) - def _put_and_retain_obj(self, obj) -> ray.ObjectRef: + def _put_object(self, + request: ray_client_pb2.PutRequest, + client_id: str, + context=None): + """Put an object in the cluster with ray.put() via gRPC. + + Args: + request: PutRequest with pickled data. + client_id: The client who owns this data, for tracking when to + delete this reference. + context: gRPC context. + """ + obj = loads_from_client(request.data, self) objectref = ray.put(obj) - self.object_refs[objectref.binary()] = objectref - logger.info("put: %s" % objectref) - return objectref + self.object_refs[client_id][objectref.binary()] = objectref + logger.debug("put: %s" % objectref) + return ray_client_pb2.PutResponse(id=objectref.binary()) def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse: - object_refs = [cloudpickle.loads(o) for o in request.object_handles] + object_refs = [] + for id in request.object_ids: + if id not in self.object_refs[request.client_id]: + raise Exception( + "Asking for a ref not associated with this client: %s" % + str(id)) + object_refs.append(self.object_refs[request.client_id][id]) num_returns = request.num_returns timeout = request.timeout - object_refs_ids = [] - for object_ref in object_refs: - if object_ref.binary() not in self.object_refs: - return ray_client_pb2.WaitResponse(valid=False) - object_refs_ids.append(self.object_refs[object_ref.binary()]) try: ready_object_refs, remaining_object_refs = ray.wait( - object_refs_ids, + object_refs, num_returns=num_returns, 
timeout=timeout if timeout != -1 else None) except Exception: # TODO(ameer): improve exception messages. return ray_client_pb2.WaitResponse(valid=False) - logger.info("wait: %s %s" % (str(ready_object_refs), - str(remaining_object_refs))) + logger.debug("wait: %s %s" % (str(ready_object_refs), + str(remaining_object_refs))) ready_object_ids = [ - make_remote_ref( - id=ready_object_ref.binary(), - handle=cloudpickle.dumps(ready_object_ref), - ) for ready_object_ref in ready_object_refs + ready_object_ref.binary() for ready_object_ref in ready_object_refs ] remaining_object_ids = [ - make_remote_ref( - id=remaining_object_ref.binary(), - handle=cloudpickle.dumps(remaining_object_ref), - ) for remaining_object_ref in remaining_object_refs + remaining_object_ref.binary() + for remaining_object_ref in remaining_object_refs ] return ray_client_pb2.WaitResponse( valid=True, @@ -150,16 +210,17 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): logger.info("schedule: %s %s" % (task.name, ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))) - if task.type == ray_client_pb2.ClientTask.FUNCTION: - return self._schedule_function(task, context, prepared_args) - elif task.type == ray_client_pb2.ClientTask.ACTOR: - return self._schedule_actor(task, context, prepared_args) - elif task.type == ray_client_pb2.ClientTask.METHOD: - return self._schedule_method(task, context, prepared_args) - else: - raise NotImplementedError( - "Unimplemented Schedule task type: %s" % - ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)) + with stash_api_for_tests(self._test_mode): + if task.type == ray_client_pb2.ClientTask.FUNCTION: + return self._schedule_function(task, context, prepared_args) + elif task.type == ray_client_pb2.ClientTask.ACTOR: + return self._schedule_actor(task, context, prepared_args) + elif task.type == ray_client_pb2.ClientTask.METHOD: + return self._schedule_method(task, context, prepared_args) + else: + raise NotImplementedError( + "Unimplemented 
Schedule task type: %s" % + ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)) def _schedule_method( self, @@ -170,80 +231,67 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): if actor_handle is None: raise Exception( "Can't run an actor the server doesn't have a handle for") - arglist = _convert_args(task.args, prepared_args) - with stash_api_for_tests(self._test_mode): - output = getattr(actor_handle, task.name).remote(*arglist) - self.object_refs[output.binary()] = output - pickled_ref = cloudpickle.dumps(output) - return ray_client_pb2.ClientTaskTicket( - return_ref=make_remote_ref(output.binary(), pickled_ref)) + arglist = self._convert_args(task.args, prepared_args) + output = getattr(actor_handle, task.name).remote(*arglist) + self.object_refs[task.client_id][output.binary()] = output + return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) def _schedule_actor(self, task: ray_client_pb2.ClientTask, context=None, prepared_args=None) -> ray_client_pb2.ClientTaskTicket: - with stash_api_for_tests(self._test_mode): - payload_ref = cloudpickle.loads(task.payload_id) - if payload_ref.binary() not in self.registered_actor_classes: - actor_class_ref = self.object_refs[payload_ref.binary()] - actor_class = ray.get(actor_class_ref) - if not inspect.isclass(actor_class): - raise Exception("Attempting to schedule actor that " - "isn't a class.") - reg_class = ray.remote(actor_class) - self.registered_actor_classes[payload_ref.binary()] = reg_class - remote_class = self.registered_actor_classes[payload_ref.binary()] - arglist = _convert_args(task.args, prepared_args) - actor = remote_class.remote(*arglist) - actorhandle = cloudpickle.dumps(actor) - self.actor_refs[actorhandle] = actor + if task.payload_id not in self.registered_actor_classes: + actor_class_ref = \ + self.object_refs[task.client_id][task.payload_id] + actor_class = ray.get(actor_class_ref) + if not inspect.isclass(actor_class): + raise Exception("Attempting to schedule 
actor that " + "isn't a class.") + reg_class = ray.remote(actor_class) + self.registered_actor_classes[task.payload_id] = reg_class + remote_class = self.registered_actor_classes[task.payload_id] + arglist = self._convert_args(task.args, prepared_args) + actor = remote_class.remote(*arglist) + self.actor_refs[actor._actor_id.binary()] = actor + self.actor_owners[task.client_id].add(actor._actor_id.binary()) return ray_client_pb2.ClientTaskTicket( - return_ref=make_remote_ref(actor._actor_id.binary(), actorhandle)) + return_id=actor._actor_id.binary()) def _schedule_function( self, task: ray_client_pb2.ClientTask, context=None, prepared_args=None) -> ray_client_pb2.ClientTaskTicket: - payload_ref = cloudpickle.loads(task.payload_id) - if payload_ref.binary() not in self.function_refs: - funcref = self.object_refs[payload_ref.binary()] + remote_func = self.lookup_or_register_func(task.payload_id, + task.client_id) + arglist = self._convert_args(task.args, prepared_args) + # Prepare call if we're in a test + with current_func(remote_func): + output = remote_func.remote(*arglist) + if output.binary() in self.object_refs[task.client_id]: + raise Exception("already found it") + self.object_refs[task.client_id][output.binary()] = output + return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) + + def _convert_args(self, arg_list, prepared_args=None): + if prepared_args is not None: + return prepared_args + out = [] + for arg in arg_list: + t = convert_from_arg(arg, self) + out.append(t) + return out + + def lookup_or_register_func(self, id: bytes, client_id: str + ) -> ray.remote_function.RemoteFunction: + if id not in self.function_refs: + funcref = self.object_refs[client_id][id] func = ray.get(funcref) if not inspect.isfunction(func): - raise Exception("Attempting to schedule function that " + raise Exception("Attempting to register function that " "isn't a function.") - self.function_refs[payload_ref.binary()] = ray.remote(func) - remote_func = 
self.function_refs[payload_ref.binary()] - arglist = _convert_args(task.args, prepared_args) - # Prepare call if we're in a test - with stash_api_for_tests(self._test_mode): - output = remote_func.remote(*arglist) - if output.binary() in self.object_refs: - raise Exception("already found it") - self.object_refs[output.binary()] = output - pickled_output = cloudpickle.dumps(output) - return ray_client_pb2.ClientTaskTicket( - return_ref=make_remote_ref(output.binary(), pickled_output)) - - -def _convert_args(arg_list, prepared_args=None): - if prepared_args is not None: - return prepared_args - out = [] - for arg in arg_list: - t = convert_from_arg(arg) - if isinstance(t, ClientObjectRef): - out.append(t._unpack_ref()) - else: - out.append(t) - return out - - -def make_remote_ref(id: bytes, handle: bytes) -> ray_client_pb2.RemoteRef: - return ray_client_pb2.RemoteRef( - id=id, - handle=handle, - ) + self.function_refs[id] = ray.remote(func) + return self.function_refs[id] def return_exception_in_context(err, context): @@ -252,12 +300,20 @@ def return_exception_in_context(err, context): context.set_code(grpc.StatusCode.INTERNAL) +def encode_exception(exception) -> str: + data = cloudpickle.dumps(exception) + return base64.standard_b64encode(data).decode() + + def serve(connection_str, test_mode=False): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) task_servicer = RayletServicer(test_mode=test_mode) + data_servicer = DataServicer(task_servicer) _set_server_api(RayServerAPI(task_servicer)) ray_client_pb2_grpc.add_RayletDriverServicer_to_server( task_servicer, server) + ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server( + data_servicer, server) server.add_insecure_port(connection_str) server.start() return server diff --git a/python/ray/experimental/client/server/server_pickler.py b/python/ray/experimental/client/server/server_pickler.py new file mode 100644 index 000000000..ea6bd74d0 --- /dev/null +++ 
b/python/ray/experimental/client/server/server_pickler.py @@ -0,0 +1,119 @@ +""" +Implements the server side of the client/server pickling protocol. + +These picklers are aware of the server internals and can find the +references held for the client within the server. + +More discussion about the client/server pickling protocol can be found in: + + ray/experimental/client/client_pickler.py + +ServerPickler dumps ray objects from the server into the appropriate stubs. +ClientUnpickler loads stubs from the client and finds their associated handle +in the server instance. +""" +import cloudpickle +import io +import sys +import ray + +from typing import Any +from typing import TYPE_CHECKING + +from ray.experimental.client.client_pickler import PickleStub +from ray.experimental.client.server.server_stubs import ServerFunctionSentinel + +if TYPE_CHECKING: + from ray.experimental.client.server.server import RayletServicer + import ray.core.generated.ray_client_pb2 as ray_client_pb2 + +if sys.version_info < (3, 8): + try: + import pickle5 as pickle # noqa: F401 + except ImportError: + import pickle # noqa: F401 +else: + import pickle # noqa: F401 + + +class ServerPickler(cloudpickle.CloudPickler): + def __init__(self, client_id: str, server: "RayletServicer", *args, + **kwargs): + super().__init__(*args, **kwargs) + self.client_id = client_id + self.server = server + + def persistent_id(self, obj): + if isinstance(obj, ray.ObjectRef): + obj_id = obj.binary() + if obj_id not in self.server.object_refs[self.client_id]: + # We're passing back a reference, probably inside a reference. + # Let's hold onto it. + self.server.object_refs[self.client_id][obj_id] = obj + return PickleStub( + type="Object", + client_id=self.client_id, + ref_id=obj_id, + ) + elif isinstance(obj, ray.actor.ActorHandle): + actor_id = obj._actor_id.binary() + if actor_id not in self.server.actor_refs: + # We're passing back a handle, probably inside a reference.
+ self.server.actor_refs[actor_id] = obj + if actor_id not in self.server.actor_owners[self.client_id]: + self.server.actor_owners[self.client_id].add(actor_id) + return PickleStub( + type="Actor", + client_id=self.client_id, + ref_id=obj._actor_id.binary(), + ) + return None + + +class ClientUnpickler(pickle.Unpickler): + def __init__(self, server, *args, **kwargs): + super().__init__(*args, **kwargs) + self.server = server + + def persistent_load(self, pid): + assert isinstance(pid, PickleStub) + if pid.type == "Object": + return self.server.object_refs[pid.client_id][pid.ref_id] + elif pid.type == "Actor": + return self.server.actor_refs[pid.ref_id] + elif pid.type == "RemoteFuncSelfReference": + return ServerFunctionSentinel() + elif pid.type == "RemoteFunc": + return self.server.lookup_or_register_func(pid.ref_id, + pid.client_id) + else: + raise NotImplementedError("Uncovered client data type") + + +def dumps_from_server(obj: Any, + client_id: str, + server_instance: "RayletServicer", + protocol=None) -> bytes: + with io.BytesIO() as file: + sp = ServerPickler(client_id, server_instance, file, protocol=protocol) + sp.dump(obj) + return file.getvalue() + + +def loads_from_client(data: bytes, + server_instance: "RayletServicer", + *, + fix_imports=True, + encoding="ASCII", + errors="strict") -> Any: + if isinstance(data, str): + raise TypeError("Can't load pickle from unicode string") + file = io.BytesIO(data) + return ClientUnpickler( + server_instance, file, fix_imports=fix_imports, + encoding=encoding, errors=errors).load() + + +def convert_from_arg(pb: "ray_client_pb2.Arg", + server: "RayletServicer") -> Any: + return loads_from_client(pb.data, server) diff --git a/python/ray/experimental/client/server/server_stubs.py b/python/ray/experimental/client/server/server_stubs.py new file mode 100644 index 000000000..f55f64f25 --- /dev/null +++ b/python/ray/experimental/client/server/server_stubs.py @@ -0,0 +1,29 @@ +from contextlib import contextmanager + +_current_remote_func = None + +
+@contextmanager +def current_func(f): + global _current_remote_func + remote_func = _current_remote_func + _current_remote_func = f + try: + yield + finally: + _current_remote_func = remote_func + + +class ServerFunctionSentinel: + def __init__(self): + pass + + def __reduce__(self): + global _current_remote_func + if _current_remote_func is None: + return (ServerFunctionSentinel, tuple()) + return (identity, (_current_remote_func, )) + + +def identity(x): + return x diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index 0a108e4f2..54ac71711 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -2,27 +2,32 @@ It implements the Ray API functions that are forwarded through grpc calls to the server. """ +import base64 import inspect import json import logging +import uuid +from collections import defaultdict from typing import Any +from typing import Dict from typing import List from typing import Tuple from typing import Optional -import ray.cloudpickle as cloudpickle from ray.util.inspect import is_cython import grpc -from ray.exceptions import TaskCancelledError +import ray.cloudpickle as cloudpickle import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc -from ray.experimental.client.common import convert_to_arg -from ray.experimental.client.common import decode_exception +from ray.experimental.client.client_pickler import convert_to_arg +from ray.experimental.client.client_pickler import loads_from_server +from ray.experimental.client.client_pickler import dumps_from_client from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientRemoteFunc +from ray.experimental.client.dataclient import DataClient logger = 
logging.getLogger(__name__) @@ -31,34 +36,32 @@ class Worker: def __init__(self, conn_str: str = "", secure: bool = False, - metadata: List[Tuple[str, str]] = None, - stub=None): + metadata: List[Tuple[str, str]] = None): """Initializes the worker side grpc client. Args: - stub: custom grpc stub. secure: whether to use SSL secure channel or not. metadata: additional metadata passed in the grpc request headers. """ self.metadata = metadata self.channel = None - if stub is None: - if secure: - credentials = grpc.ssl_channel_credentials() - self.channel = grpc.secure_channel(conn_str, credentials) - else: - self.channel = grpc.insecure_channel(conn_str) - self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) + self._client_id = make_client_id() + if secure: + credentials = grpc.ssl_channel_credentials() + self.channel = grpc.secure_channel(conn_str, credentials) else: - self.server = stub + self.channel = grpc.insecure_channel(conn_str) + self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) + self.data_client = DataClient(self.channel, self._client_id) + self.reference_count: Dict[bytes, int] = defaultdict(int) def get(self, vals, *, timeout: Optional[float] = None) -> Any: to_get = [] single = False if isinstance(vals, list): - to_get = [x.handle for x in vals] + to_get = vals elif isinstance(vals, ClientObjectRef): - to_get = [vals.handle] + to_get = [vals] single = True else: raise Exception("Can't get something that's not a " @@ -70,15 +73,15 @@ class Worker: out = out[0] return out - def _get(self, handle: bytes, timeout: float): - req = ray_client_pb2.GetRequest(handle=handle, timeout=timeout) + def _get(self, ref: ClientObjectRef, timeout: float): + req = ray_client_pb2.GetRequest(id=ref.id, timeout=timeout) try: - data = self.server.GetObject(req, metadata=self.metadata) + data = self.data_client.GetObject(req) except grpc.RpcError as e: - raise decode_exception(e.details()) + raise decode_exception(e.details()) if not data.valid: - raise
TaskCancelledError(handle) - return cloudpickle.loads(data.data) + raise cloudpickle.loads(data.error) + return loads_from_server(data.data) def put(self, vals): to_put = [] @@ -95,10 +98,10 @@ class Worker: return out def _put(self, val): - data = cloudpickle.dumps(val) + data = dumps_from_client(val, self._client_id) req = ray_client_pb2.PutRequest(data=data) - resp = self.server.PutObject(req, metadata=self.metadata) - return ClientObjectRef.from_remote_ref(resp.ref) + resp = self.data_client.PutObject(req) + return ClientObjectRef(resp.id) def wait(self, object_refs: List[ClientObjectRef], @@ -110,11 +113,10 @@ class Worker: for ref in object_refs: assert isinstance(ref, ClientObjectRef) data = { - "object_handles": [ - object_ref.handle for object_ref in object_refs - ], + "object_ids": [object_ref.id for object_ref in object_refs], "num_returns": num_returns, - "timeout": timeout if timeout else -1 + "timeout": timeout if timeout else -1, + "client_id": self._client_id, } req = ray_client_pb2.WaitRequest(**data) resp = self.server.WaitObject(req, metadata=self.metadata) @@ -122,12 +124,10 @@ class Worker: # TODO(ameer): improve error/exceptions messages. raise Exception("Client Wait request failed. 
Reference invalid?") client_ready_object_ids = [ - ClientObjectRef.from_remote_ref(ref) - for ref in resp.ready_object_ids + ClientObjectRef(ref) for ref in resp.ready_object_ids ] client_remaining_object_ids = [ - ClientObjectRef.from_remote_ref(ref) - for ref in resp.remaining_object_ids + ClientObjectRef(ref) for ref in resp.remaining_object_ids ] return (client_ready_object_ids, client_remaining_object_ids) @@ -144,19 +144,38 @@ class Worker: raise TypeError("The @ray.remote decorator must be applied to " "either a function or to a class.") - def call_remote(self, instance, *args, **kwargs): + def call_remote(self, instance, *args, **kwargs) -> bytes: task = instance._prepare_client_task() for arg in args: - pb_arg = convert_to_arg(arg) + pb_arg = convert_to_arg(arg, self._client_id) task.args.append(pb_arg) - logging.debug("Scheduling %s" % task) + task.client_id = self._client_id + logger.debug("Scheduling %s" % task) ticket = self.server.Schedule(task, metadata=self.metadata) - return ClientObjectRef.from_remote_ref(ticket.return_ref) + return ticket.return_id + + def call_release(self, id: bytes) -> None: + self.reference_count[id] -= 1 + if self.reference_count[id] == 0: + self._release_server(id) + del self.reference_count[id] + + def _release_server(self, id: bytes) -> None: + if self.data_client is not None: + logger.debug(f"Releasing {id}") + self.data_client.ReleaseObject( + ray_client_pb2.ReleaseRequest(ids=[id])) + + def call_retain(self, id: bytes) -> None: + logger.debug(f"Retaining {id}") + self.reference_count[id] += 1 def close(self): + self.data_client.close() self.server = None if self.channel: self.channel.close() + self.channel = None def terminate_actor(self, actor: ClientActorHandle, no_restart: bool) -> None: @@ -164,10 +183,11 @@ class Worker: raise ValueError("ray.kill() only supported for actors. 
" "Got: {}.".format(type(actor))) term_actor = ray_client_pb2.TerminateRequest.ActorTerminate() - term_actor.handle = actor.actor_ref.handle + term_actor.id = actor.actor_ref.id term_actor.no_restart = no_restart try: term = ray_client_pb2.TerminateRequest(actor=term_actor) + term.client_id = self._client_id self.server.Terminate(term) except grpc.RpcError as e: raise decode_exception(e.details()) @@ -179,11 +199,12 @@ class Worker: "ray.cancel() only supported for non-actor object refs. " f"Got: {type(obj)}.") term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate() - term_object.handle = obj.handle + term_object.id = obj.id term_object.force = force term_object.recursive = recursive try: term = ray_client_pb2.TerminateRequest(task_object=term_object) + term.client_id = self._client_id self.server.Terminate(term) except grpc.RpcError as e: raise decode_exception(e.details()) @@ -201,3 +222,13 @@ class Worker: return self.get_cluster_info( ray_client_pb2.ClusterInfoType.IS_INITIALIZED) return False + + +def make_client_id() -> str: + id = uuid.uuid4() + return id.hex + + +def decode_exception(data) -> Exception: + data = base64.standard_b64decode(data) + return loads_from_server(data) diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 55fac64e5..588710e3a 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -96,6 +96,7 @@ py_test_module_list( "test_debug_tools.py", "test_experimental_client.py", "test_experimental_client_metadata.py", + "test_experimental_client_references.py", "test_experimental_client_terminate.py", "test_job.py", "test_memstat.py", diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index cbe52675f..1231b6730 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -142,7 +142,7 @@ def test_function_calling_function(ray_start_regular_shared): @ray.remote def f(): - print(f, f._name, g._name, g) + 
print(f, g) return ray.get(g.remote()) print(f, type(f)) diff --git a/python/ray/tests/test_experimental_client_references.py b/python/ray/tests/test_experimental_client_references.py new file mode 100644 index 000000000..9675b9c97 --- /dev/null +++ b/python/ray/tests/test_experimental_client_references.py @@ -0,0 +1,152 @@ +from ray.tests.test_experimental_client import ray_start_client_server +from ray.test_utils import wait_for_condition +import ray as real_ray +from ray.core.generated.gcs_pb2 import ActorTableData +from ray.experimental.client import _get_server_instance + + +def server_object_ref_count(n): + server = _get_server_instance() + assert server is not None + + def test_cond(): + if len(server.object_refs) == 0: + # No open clients + return n == 0 + client_id = list(server.object_refs.keys())[0] + return len(server.object_refs[client_id]) == n + + return test_cond + + +def server_actor_ref_count(n): + server = _get_server_instance() + assert server is not None + + def test_cond(): + if len(server.actor_refs) == 0: + # No running actors + return n == 0 + return len(server.actor_refs) == n + + return test_cond + + +def test_delete_refs_on_disconnect(ray_start_regular): + with ray_start_client_server() as ray: + + @ray.remote + def f(x): + return x + 2 + + thing1 = f.remote(6) # noqa + thing2 = ray.put("Hello World") # noqa + + # One put, one function -- the function result thing1 is + # in a different category, according to the raylet. 
+ assert len(real_ray.objects()) == 2 + # But we're maintaining the reference + assert server_object_ref_count(3)() + # And can get the data + assert ray.get(thing1) == 8 + + # Close the client + ray.close() + + wait_for_condition(server_object_ref_count(0), timeout=5) + + def test_cond(): + return len(real_ray.objects()) == 0 + + wait_for_condition(test_cond, timeout=5) + + +def test_delete_ref_on_object_deletion(ray_start_regular): + with ray_start_client_server() as ray: + vals = { + "ref": ray.put("Hello World"), + "ref2": ray.put("This value stays"), + } + + del vals["ref"] + + wait_for_condition(server_object_ref_count(1), timeout=5) + + +def test_delete_actor_on_disconnect(ray_start_regular): + with ray_start_client_server() as ray: + + @ray.remote + class Accumulator: + def __init__(self): + self.acc = 0 + + def inc(self): + self.acc += 1 + + def get(self): + return self.acc + + actor = Accumulator.remote() + actor.inc.remote() + + assert server_actor_ref_count(1)() + + assert ray.get(actor.get.remote()) == 1 + + ray.close() + + wait_for_condition(server_actor_ref_count(0), timeout=5) + + def test_cond(): + alive_actors = [ + v for v in real_ray.actors().values() + if v["State"] != ActorTableData.DEAD + ] + return len(alive_actors) == 0 + + wait_for_condition(test_cond, timeout=10) + + +def test_delete_actor(ray_start_regular): + with ray_start_client_server() as ray: + + @ray.remote + class Accumulator: + def __init__(self): + self.acc = 0 + + def inc(self): + self.acc += 1 + + actor = Accumulator.remote() + actor.inc.remote() + actor2 = Accumulator.remote() + actor2.inc.remote() + + assert server_actor_ref_count(2)() + + del actor + + wait_for_condition(server_actor_ref_count(1), timeout=5) + + +def test_simple_multiple_references(ray_start_regular): + with ray_start_client_server() as ray: + + @ray.remote + class A: + def __init__(self): + self.x = ray.put("hi") + + def get(self): + return [self.x] + + a = A.remote() + ref1 = ray.get(a.get.remote())[0] + 
ref2 = ray.get(a.get.remote())[0] + del a + assert ray.get(ref1) == "hi" + del ref1 + assert ray.get(ref2) == "hi" + del ref2 diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index d4c392321..cdc3ee8aa 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -18,17 +18,24 @@ package ray.rpc; enum Type { DEFAULT = 0; } +// An argument to a ClientTask. message Arg { enum Locality { INTERNED = 0; REFERENCE = 1; } + + // The type of argument this is -- whether a data blob or a reference. Locality local = 1; + // The reference id, if a reference. bytes reference_id = 2; + // A data blob, if passed in-band. bytes data = 3; + // How to decode this data blob. Type type = 4; } +// Represents one unit of work to be executed by the server. message ClientTask { enum RemoteExecType { FUNCTION = 0; @@ -36,49 +43,69 @@ message ClientTask { METHOD = 2; STATIC_METHOD = 3; } + // Which type of work this request represents. RemoteExecType type = 1; + // A name parameter, if the payload can be called in more than one way (like a method on + // a payload object). string name = 2; + // A reference to the payload. bytes payload_id = 3; + // The parameters to pass to this call. repeated Arg args = 4; -} - -message RemoteRef { - bytes id = 1; - bytes handle = 2; + // The ID of the client namespace associated with the Datapath stream making this + // request. + string client_id = 5; } message ClientTaskTicket { - RemoteRef return_ref = 1; + // A reference to the returned value from the execution. + bytes return_id = 1; } +// Delivers data to the server message PutRequest { + // The data blob for the server to store. bytes data = 1; } message PutResponse { - RemoteRef ref = 1; + // The reference ID for the data that the server has stored. + bytes id = 1; } +// Requests data from the server. 
message GetRequest { - bytes handle = 1; + // The reference ID for the requested object data + bytes id = 1; + // Length of time to wait for data to be available, in seconds. Zero is no timeout. float timeout = 2; } message GetResponse { + // Whether or not the data was successfully retrieved bool valid = 1; + // The data blob, on success bytes data = 2; + // An error blob (for example, an exception) on failure. + bytes error = 3; } +// Waits for data to be ready on the server, with a timeout. message WaitRequest { - repeated bytes object_handles = 1; + // The IDs of the data to wait for ready status. + repeated bytes object_ids = 1; + // How many of the above ids to wait for before returning. int64 num_returns = 2; + // How long to wait for these IDs to become ready. double timeout = 3; + // The Client namespace associated with the Datapath stream that holds these IDs. + string client_id = 4; } message WaitResponse { bool valid = 1; - repeated RemoteRef ready_object_ids = 2; - repeated RemoteRef remaining_object_ids = 3; + repeated bytes ready_object_ids = 2; + repeated bytes remaining_object_ids = 3; } message ClusterInfoType { @@ -108,18 +135,19 @@ message ClusterInfoResponse { message TerminateRequest { message ActorTerminate { - bytes handle = 1; + bytes id = 1; bool no_restart = 2; } message TaskObjectTerminate { - bytes handle = 1; + bytes id = 1; bool force = 2; bool recursive = 3; } + string client_id = 1; oneof terminate_type { - ActorTerminate actor = 1; - TaskObjectTerminate task_object = 2; + ActorTerminate actor = 2; + TaskObjectTerminate task_object = 3; } } @@ -141,3 +169,40 @@ service RayletDriver { rpc ClusterInfo(ClusterInfoRequest) returns (ClusterInfoResponse) { } } + +message ReleaseRequest { + // The IDs to release from the server; the client connected on this stream no + // longer holds a reference to them. + repeated bytes ids = 1; +} + +message ReleaseResponse { + // For each requested ID, whether or not it was released. 
+ repeated bool ok = 2; +} + +message DataRequest { + // An incrementing counter of request IDs on the Datapath, + // to match requests with responses asynchronously. + int32 req_id = 1; + oneof type { + GetRequest get = 2; + PutRequest put = 3; + ReleaseRequest release = 4; + } +} + +message DataResponse { + // The request id that this response matches with. + int32 req_id = 1; + oneof type { + GetResponse get = 2; + PutResponse put = 3; + ReleaseResponse release = 4; + } +} + +service RayletDataStreamer { + rpc Datapath(stream DataRequest) returns (stream DataResponse) { + } +}