From dc4b5c7aa3ffadf71e2ba15b04c83c23b203e585 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Tue, 8 Dec 2020 21:54:55 -0800 Subject: [PATCH] [ray_client] Passing actors to actors (#12585) * start building tests around passing handles to handles Change-Id: Ie8c3de5c8ce789c3ec8d29f0702df80ba598279f * clean up the switch statements by moving to a method, implement state transfer, extend test Change-Id: Ie7b6493db3a6c203d3a0b262b8fbacb90e5cdbc5 * passing Change-Id: Id88dc0a41da1c9d5ba68f754c5b57141aae47beb * flush out tests Change-Id: If77c0f586e9e99449d494be4e85f854e4a7a4952 * formatting Change-Id: I497c07cee70b52453b221ed4393f04f6f560061e * fix python3.6 and other attributes Change-Id: I5a2c5231e8a021184d9dfc3e346df7f71fc93257 * address documentation Change-Id: I049d841ed1f85b7350c17c05da4a4d81d5cb03df * formatting Change-Id: I6a2b32a2466ffc9f03fc91ac17901b9c1a49505c * use the pickled handle as the id bytes for actors Change-Id: I9ddcb41d614de65d42d6f0382fe0faa7ad2c2ade * pydoc Change-Id: I9b32a0f383d5ff5ac052e61929b7ae3e42a89fc5 * format Change-Id: Iac0010bb990a4025a98139ab88700030b2e9e7f5 * todos Change-Id: I7b550800cf7499403e8a17b77484bc46f20f0afc * tests Change-Id: If8ebf6a335baeb113c1332acc930c41a6b4f5384 * fix lint Change-Id: I019f41e0ec341d39bbbbd39aa43d9fb5f8b57cf0 * nits Change-Id: I2e6813d8db34f4ce008326faa095d414c10eee95 * add some tricky, python3.6-troublesome type checking Change-Id: Ib887fc943a6e7084002bc13dfbe113b69b4d9317 --- python/ray/experimental/client/__init__.py | 97 ++++++--- python/ray/experimental/client/api.py | 84 +++++++- python/ray/experimental/client/common.py | 190 ++++++++++++++++-- .../client/server/core_ray_api.py | 60 +++++- .../ray/experimental/client/server/server.py | 68 ++++--- python/ray/experimental/client/worker.py | 55 +---- python/ray/tests/test_experimental_client.py | 67 +++++- 7 files changed, 491 insertions(+), 130 deletions(-) diff --git a/python/ray/experimental/client/__init__.py
b/python/ray/experimental/client/__init__.py index 36d19ba56..a6bba39ed 100644 --- a/python/ray/experimental/client/__init__.py +++ b/python/ray/experimental/client/__init__.py @@ -7,34 +7,88 @@ import logging logger = logging.getLogger(__name__) -# _client_api has to be external to the API stub, below. -# Otherwise, ray.remote() that contains ray.remote() -# contains a reference to the RayAPIStub, therefore a -# reference to the _client_api, and then tries to pickle -# the thing. +# About these global variables: Ray 1.0 uses exported module functions to +# provide its API, and we need to match that. However, we want different +# behaviors depending on where, exactly, in the client stack this is running. +# +# The reason for these differences depends on what's being pickled and passed +# to functions, or functions inside functions. So there are three cases to care +# about +# +# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process) +# +# * _client_api should be set if we're inside the client +# * _server_api should be set if we're inside the clientserver +# * Both will be set if we're running both (as in a test) +# * Neither should be set if we're inside the raylet (but we still need to shim +# from the client API surface to the Ray API) +# +# RayAPIStub (below) delegates to the appropriate one of these +# depending on what's set or not. Then, all users importing the ray object +# from this package get the stub which routes them to the appropriate APIImpl. _client_api: Optional[APIImpl] = None +_server_api: Optional[APIImpl] = None + +# The reason for _is_server is a hack around the above comment while running +# tests. If we have both a client and a server trying to control these static +# variables then we need a way to decide which to use. In this case, both +# _client_api and _server_api are set.
+# This boolean flips between the two +_is_server: bool = False @contextmanager def stash_api_for_tests(in_test: bool): - api = None + global _is_server + is_server = _is_server if in_test: - api = stash_api() - yield api + _is_server = True + yield _server_api if in_test: - restore_api(api) + _is_server = is_server -def stash_api() -> Optional[APIImpl]: +def _set_client_api(val: Optional[APIImpl]): global _client_api - a = _client_api + global _is_server + if _client_api is not None: + raise Exception("Trying to set more than one client API") + _client_api = val + _is_server = False + + +def _set_server_api(val: Optional[APIImpl]): + global _server_api + global _is_server + if _server_api is not None: + raise Exception("Trying to set more than one server API") + _server_api = val + _is_server = True + + +def reset_api(): + global _client_api + global _server_api + global _is_server _client_api = None - return a + _server_api = None + _is_server = False -def restore_api(api: Optional[APIImpl]): +def _get_client_api() -> APIImpl: global _client_api - _client_api = api + global _server_api + global _is_server + api = None + if _is_server: + api = _server_api + else: + api = _client_api + if api is None: + # We're inside a raylet worker + from ray.experimental.client.server.core_ray_api import CoreRayAPI + return CoreRayAPI() + return api class RayAPIStub: @@ -43,11 +97,10 @@ class RayAPIStub: secure: bool = False, metadata: List[Tuple[str, str]] = None, stub=None): - global _client_api from ray.experimental.client.worker import Worker _client_worker = Worker( conn_str, secure=secure, metadata=metadata, stub=stub) - _client_api = ClientAPI(_client_worker) + _set_client_api(ClientAPI(_client_worker)) def disconnect(self): global _client_api @@ -56,15 +109,9 @@ class RayAPIStub: _client_api = None def __getattr__(self, key: str): - global _client_api - self.__check_client_api() - return getattr(_client_api, key) - - def __check_client_api(self): - global _client_api - 
if _client_api is None: - from ray.experimental.client.server.core_ray_api import CoreRayAPI - _client_api = CoreRayAPI() + global _get_client_api + api = _get_client_api() + return getattr(api, key) ray = RayAPIStub() diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py index 17d0d6a97..66ec61c17 100644 --- a/python/ray/experimental/client/api.py +++ b/python/ray/experimental/client/api.py @@ -11,35 +11,105 @@ from abc import ABC from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Union +if TYPE_CHECKING: + from ray.experimental.client.common import ClientStub + from ray.experimental.client.common import ClientObjectRef + from ray._raylet import ObjectRef + + # Use the imports for type checking. This is a python 3.6 limitation. + # See https://www.python.org/dev/peps/pep-0563/ + PutType = Union[ClientObjectRef, ObjectRef] class APIImpl(ABC): + """ + APIImpl is the interface to implement for whichever version of the core + Ray API that needs abstracting when run in client mode. + """ + @abstractmethod - def get(self, *args, **kwargs): + def get(self, *args, **kwargs) -> Any: + """ + get is the hook stub passed on to replace `ray.get` + + Args: + args: opaque arguments + kwargs: opaque keyword arguments + """ pass @abstractmethod - def put(self, *args, **kwargs): + def put(self, vals: Any, *args, + **kwargs) -> Union["ClientObjectRef", "ObjectRef"]: + """ + put is the hook stub passed on to replace `ray.put` + + Args: + vals: The value or list of values to `put`. + args: opaque arguments + kwargs: opaque keyword arguments + """ pass @abstractmethod def wait(self, *args, **kwargs): + """ + wait is the hook stub passed on to replace `ray.wait` + + Args: + args: opaque arguments + kwargs: opaque keyword arguments + """ pass @abstractmethod def remote(self, *args, **kwargs): + """ + remote is the hook stub passed on to replace `ray.remote`. 
+ + This sets up remote functions or actors, as the decorator, + but does not execute them. + + Args: + args: opaque arguments + kwargs: opaque keyword arguments + """ pass @abstractmethod - def call_remote(self, f, kind, *args, **kwargs): + def call_remote(self, instance: "ClientStub", *args, **kwargs): + """ + call_remote is called by stub objects to execute them remotely. + + This is used by stub objects in situations where they're called + with .remote, eg, `f.remote()` or `actor_cls.remote()`. + This allows the client stub objects to delegate execution to be + implemented in the most effective way whether it's in the client, + clientserver, or raylet worker. + + Args: + instance: The Client-side stub reference to a remote object + args: opaque arguments + kwargs: opaque keyword arguments + """ pass @abstractmethod - def close(self, *args, **kwargs): + def close(self) -> None: + """ + close cleans up an API connection by closing any channels or + shutting down any servers gracefully. + """ pass class ClientAPI(APIImpl): + """ + The Client-side methods corresponding to the ray API. Delegates + to the Client Worker that contains the connection to the ClientServer. 
+ """ + def __init__(self, worker): self.worker = worker @@ -55,10 +125,10 @@ class ClientAPI(APIImpl): def remote(self, *args, **kwargs): return self.worker.remote(*args, **kwargs) - def call_remote(self, f, kind, *args, **kwargs): - return self.worker.call_remote(f, kind, *args, **kwargs) + def call_remote(self, instance: "ClientStub", *args, **kwargs): + return self.worker.call_remote(instance, *args, **kwargs) - def close(self, *args, **kwargs): + def close(self) -> None: return self.worker.close() def __getattr__(self, key: str): diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index d2ec7e041..cea5825e3 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -1,6 +1,7 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.experimental.client import ray from typing import Any +from typing import Dict from ray import cloudpickle @@ -17,6 +18,9 @@ class ClientBaseRef: def __eq__(self, other): return self.id == other.id + def binary(self): + return self.id + class ClientObjectRef(ClientBaseRef): pass @@ -26,74 +30,222 @@ class ClientActorRef(ClientBaseRef): pass -class ClientRemoteFunc: +class ClientStub: + pass + + +class ClientRemoteFunc(ClientStub): + """ + A stub created on the Ray Client to represent a remote + function that can be executed on the cluster. + + This class is allowed to be passed around between remote functions. + + Args: + _func: The actual function to execute remotely + _name: The original name of the function + _ref: The ClientObjectRef of the pickled code of the function, _func + _raylet_remote: The Raylet-side ray.remote_function.RemoteFunction + for this object + """ + def __init__(self, f): self._func = f self._name = f.__name__ self.id = None - self._raylet_remote_func = None + + # self._ref can be lazily instantiated.
Rather than eagerly creating + # function data objects in the server we can put them just before we + # execute the function, especially in cases where many @ray.remote + # functions exist in a library and only a handful are ever executed by + # a user of the library. + # + # TODO(barakmich): This ref might actually be better as a serialized + # ObjectRef. This requires being able to serialize the ref without + # pinning it (as the lifetime of the ref is tied with the server, not + # the client) + self._ref = None + self._raylet_remote = None def __call__(self, *args, **kwargs): raise TypeError(f"Remote function cannot be called directly. " "Use {self._name}.remote method instead") def remote(self, *args, **kwargs): - return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args, - **kwargs) + return ray.call_remote(self, *args, **kwargs) + + def _get_ray_remote_impl(self): + if self._raylet_remote is None: + self._raylet_remote = ray.remote(self._func) + return self._raylet_remote def __repr__(self): - return "ClientRemoteFunc(%s, %s)" % (self._name, self.id) + return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref) + + def _prepare_client_task(self) -> ray_client_pb2.ClientTask: + if self._ref is None: + self._ref = ray.put(self._func) + task = ray_client_pb2.ClientTask() + task.type = ray_client_pb2.ClientTask.FUNCTION + task.name = self._name + task.payload_id = self._ref.id + return task -class ClientActorClass: +class ClientActorClass(ClientStub): + """ A stub created on the Ray Client to represent an actor class. + + It is wrapped by ray.remote and can be executed on the cluster. 
+ + Args: + actor_cls: The actual class to execute remotely + _name: The original name of the class + _ref: The ClientObjectRef of the pickled `actor_cls` + _raylet_remote: The Raylet-side ray.ActorClass for this object + """ + def __init__(self, actor_cls): self.actor_cls = actor_cls self._name = actor_cls.__name__ + self._ref = None + self._raylet_remote = None def __call__(self, *args, **kwargs): raise TypeError(f"Remote actor cannot be instantiated directly. " "Use {self._name}.remote() instead") + def __getstate__(self) -> Dict: + state = { + "actor_cls": self.actor_cls, + "_name": self._name, + "_ref": self._ref, + } + return state + + def __setstate__(self, state: Dict) -> None: + self.actor_cls = state["actor_cls"] + self._name = state["_name"] + self._ref = state["_ref"] + def remote(self, *args, **kwargs): # Actually instantiate the actor - ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args, - **kwargs) - return ClientActorHandle(ref, self) + ref = ray.call_remote(self, *args, **kwargs) + return ClientActorHandle(ClientActorRef(ref.id), self) def __repr__(self): - return "ClientRemoteActor(%s, %s)" % (self._name, self.id) + return "ClientRemoteActor(%s, %s)" % (self._name, self._ref) def __getattr__(self, key): + if key not in self.__dict__: + raise AttributeError("Not a class attribute") raise NotImplementedError("static methods") + def _prepare_client_task(self) -> ray_client_pb2.ClientTask: + if self._ref is None: + self._ref = ray.put(self.actor_cls) + task = ray_client_pb2.ClientTask() + task.type = ray_client_pb2.ClientTask.ACTOR + task.name = self._name + task.payload_id = self._ref.id + return task -class ClientActorHandle: - def __init__(self, actor_id: ClientActorRef, + +class ClientActorHandle(ClientStub): + """Client-side stub for instantiated actor. + + A stub created on the Ray Client to represent a remote actor that + has been started on the cluster. This class is allowed to be passed + around between remote functions. 
+ + Args: + actor_ref: A reference to the running actor given to the client. This + is a serialized version of the actual handle as an opaque token. + actor_class: A reference to the ClientActorClass that this actor was + instantiated from. + _real_actor_handle: Cached copy of the Raylet-side + ray.actor.ActorHandle contained in the actor_ref. + """ + + def __init__(self, actor_ref: ClientActorRef, actor_class: ClientActorClass): - self.actor_id = actor_id + self.actor_ref = actor_ref self.actor_class = actor_class + self._real_actor_handle = None + + def _get_ray_remote_impl(self): + if self._real_actor_handle is None: + self._real_actor_handle = cloudpickle.loads(self.actor_ref.id) + return self._real_actor_handle + + def __getstate__(self) -> Dict: + state = { + "actor_ref": self.actor_ref, + "actor_class": self.actor_class, + "_real_actor_handle": self._real_actor_handle, + } + return state + + def __setstate__(self, state: Dict) -> None: + self.actor_ref = state["actor_ref"] + self.actor_class = state["actor_class"] + self._real_actor_handle = state["_real_actor_handle"] def __getattr__(self, key): return ClientRemoteMethod(self, key) + def __repr__(self): + return "ClientActorHandle(%s)" % (self.actor_ref.id.hex()) + + +class ClientRemoteMethod(ClientStub): + """A stub for a method on a remote actor. + + Can be annotated with execution options. + + Args: + actor_handle: A reference to the ClientActorHandle that generated + this method and will have this method called upon it. + method_name: The name of this method + """ -class ClientRemoteMethod: def __init__(self, actor_handle: ClientActorHandle, method_name: str): self.actor_handle = actor_handle self.method_name = method_name - self._name = "%s.%s" % (self.actor_handle.actor_class._name, - self.method_name) def __call__(self, *args, **kwargs): raise TypeError(f"Remote method cannot be called directly.
" "Use {self._name}.remote() instead") + def _get_ray_remote_impl(self): + return getattr(self.actor_handle._get_ray_remote_impl(), + self.method_name) + + def __getstate__(self) -> Dict: + state = { + "actor_handle": self.actor_handle, + "method_name": self.method_name, + } + return state + + def __setstate__(self, state: Dict) -> None: + self.actor_handle = state["actor_handle"] + self.method_name = state["method_name"] + def remote(self, *args, **kwargs): - return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args, - **kwargs) + return ray.call_remote(self, *args, **kwargs) def __repr__(self): - return "ClientRemoteMethod(%s, %s)" % (self._name, self.actor_id) + name = "%s.%s" % (self.actor_handle.actor_class._name, + self.method_name) + return "ClientRemoteMethod(%s, %s)" % (name, + self.actor_handle.actor_id) + + def _prepare_client_task(self) -> ray_client_pb2.ClientTask: + task = ray_client_pb2.ClientTask() + task.type = ray_client_pb2.ClientTask.METHOD + task.name = self.method_name + task.payload_id = self.actor_handle.actor_ref.id + return task def convert_from_arg(pb) -> Any: diff --git a/python/ray/experimental/client/server/core_ray_api.py b/python/ray/experimental/client/server/core_ray_api.py index 3ebb36c32..83cbc36c0 100644 --- a/python/ray/experimental/client/server/core_ray_api.py +++ b/python/ray/experimental/client/server/core_ray_api.py @@ -7,18 +7,29 @@ # While the stub is trivial, it allows us to check that the calls we're # making into the core-ray module are contained and well-defined. +from typing import Any +from typing import Union + import ray from ray.experimental.client.api import APIImpl -from ray.experimental.client.common import ClientRemoteFunc +from ray.experimental.client.common import ClientObjectRef +from ray.experimental.client.common import ClientStub class CoreRayAPI(APIImpl): + """ + Implements the equivalent client-side Ray API by simply passing along to + the Core Ray API. 
Primarily used inside of Ray Workers as a trampoline back + to core ray when passed client stubs. + """ + def get(self, *args, **kwargs): return ray.get(*args, **kwargs) - def put(self, *args, **kwargs): - return ray.put(*args, **kwargs) + def put(self, vals: Any, *args, + **kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]: + return ray.put(vals, *args, **kwargs) def wait(self, *args, **kwargs): return ray.wait(*args, **kwargs) @@ -26,12 +37,10 @@ class CoreRayAPI(APIImpl): def remote(self, *args, **kwargs): return ray.remote(*args, **kwargs) - def call_remote(self, f: ClientRemoteFunc, kind: int, *args, **kwargs): - if f._raylet_remote_func is None: - f._raylet_remote_func = ray.remote(f._func) - return f._raylet_remote_func.remote(*args, **kwargs) + def call_remote(self, instance: ClientStub, *args, **kwargs): + return instance._get_ray_remote_impl().remote(*args, **kwargs) - def close(self, *args, **kwargs): + def close(self) -> None: return None # Allow for generic fallback to ray.* in remote methods. This allows calls @@ -39,3 +48,38 @@ class CoreRayAPI(APIImpl): # doesn't currently support them. def __getattr__(self, key: str): return getattr(ray, key) + + +class RayServerAPI(CoreRayAPI): + """ + Ray Client server-side API shim. By default, simply calls the default Core + Ray API calls, but also accepts scheduling calls from functions running + inside of other remote functions that need to create more work. + """ + + def __init__(self, server_instance): + self.server = server_instance + + # Wrap single item into list if needed before calling server put. 
+ def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef: + to_put = [] + single = False + if isinstance(vals, list): + to_put = vals + else: + single = True + to_put.append(vals) + + out = [self._put(x) for x in to_put] + if single: + out = out[0] + return out + + def _put(self, val: Any): + resp = self.server._put_and_retain_obj(val) + return ClientObjectRef(resp.id) + + def call_remote(self, instance: ClientStub, *args, **kwargs): + task = instance._prepare_client_task() + ticket = self.server.Schedule(task, prepared_args=args) + return ClientObjectRef(ticket.return_id) diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index e42ea8db4..a2958f6d1 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -7,10 +7,10 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import time import inspect -from ray.experimental.client import stash_api_for_tests +from ray.experimental.client import stash_api_for_tests, _set_server_api from ray.experimental.client.common import convert_from_arg from ray.experimental.client.common import ClientObjectRef -from ray.experimental.client.common import ClientRemoteFunc +from ray.experimental.client.server.core_ray_api import RayServerAPI logger = logging.getLogger(__name__) @@ -32,12 +32,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): item_ser = cloudpickle.dumps(item) return ray_client_pb2.GetResponse(valid=True, data=item_ser) - def PutObject(self, request, context=None): + def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse: obj = cloudpickle.loads(request.data) + objectref = self._put_and_retain_obj(obj) + return ray_client_pb2.PutResponse(id=objectref.binary()) + + def _put_and_retain_obj(self, obj) -> ray.ObjectRef: objectref = ray.put(obj) self.object_refs[objectref.binary()] = 
objectref logger.info("put: %s" % objectref) - return ray_client_pb2.PutResponse(id=objectref.binary()) + return objectref def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse: object_refs = [cloudpickle.loads(o) for o in request.object_refs] @@ -70,70 +74,83 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): ready_object_ids=ready_object_ids, remaining_object_ids=remaining_object_ids) - def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket: + def Schedule(self, task, context=None, + prepared_args=None) -> ray_client_pb2.ClientTaskTicket: logger.info("schedule: %s %s" % (task.name, ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))) if task.type == ray_client_pb2.ClientTask.FUNCTION: - return self._schedule_function(task, context) + return self._schedule_function(task, context, prepared_args) elif task.type == ray_client_pb2.ClientTask.ACTOR: - return self._schedule_actor(task, context) + return self._schedule_actor(task, context, prepared_args) elif task.type == ray_client_pb2.ClientTask.METHOD: - return self._schedule_method(task, context) + return self._schedule_method(task, context, prepared_args) else: raise NotImplementedError( "Unimplemented Schedule task type: %s" % ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)) - def _schedule_method(self, task: ray_client_pb2.ClientTask, - context=None) -> ray_client_pb2.ClientTaskTicket: + def _schedule_method( + self, + task: ray_client_pb2.ClientTask, + context=None, + prepared_args=None) -> ray_client_pb2.ClientTaskTicket: actor_handle = self.actor_refs.get(task.payload_id) if actor_handle is None: raise Exception( "Can't run an actor the server doesn't have a handle for") - arglist = _convert_args(task.args) + arglist = _convert_args(task.args, prepared_args) with stash_api_for_tests(self._test_mode): output = getattr(actor_handle, task.name).remote(*arglist) self.object_refs[output.binary()] = output return 
ray_client_pb2.ClientTaskTicket(return_id=output.binary()) - def _schedule_actor(self, task: ray_client_pb2.ClientTask, - context=None) -> ray_client_pb2.ClientTaskTicket: + def _schedule_actor(self, + task: ray_client_pb2.ClientTask, + context=None, + prepared_args=None) -> ray_client_pb2.ClientTaskTicket: with stash_api_for_tests(self._test_mode): if task.payload_id not in self.registered_actor_classes: actor_class_ref = self.object_refs[task.payload_id] actor_class = ray.get(actor_class_ref) if not inspect.isclass(actor_class): raise Exception("Attempting to schedule actor that " - "isn't a ClientActorClass.") + "isn't a class.") reg_class = ray.remote(actor_class) self.registered_actor_classes[task.payload_id] = reg_class remote_class = self.registered_actor_classes[task.payload_id] - arglist = _convert_args(task.args) + arglist = _convert_args(task.args, prepared_args) actor = remote_class.remote(*arglist) - actor_ref = actor._actor_id - self.actor_refs[actor_ref.binary()] = actor - return ray_client_pb2.ClientTaskTicket(return_id=actor_ref.binary()) + actorhandle = cloudpickle.dumps(actor) + self.actor_refs[actorhandle] = actor + return ray_client_pb2.ClientTaskTicket(return_id=actorhandle) - def _schedule_function(self, task: ray_client_pb2.ClientTask, - context=None) -> ray_client_pb2.ClientTaskTicket: + def _schedule_function( + self, + task: ray_client_pb2.ClientTask, + context=None, + prepared_args=None) -> ray_client_pb2.ClientTaskTicket: if task.payload_id not in self.function_refs: funcref = self.object_refs[task.payload_id] func = ray.get(funcref) - if not isinstance(func, ClientRemoteFunc): + if not inspect.isfunction(func): raise Exception("Attempting to schedule function that " - "isn't a ClientRemoteFunc.") - self.function_refs[task.payload_id] = func + "isn't a function.") + self.function_refs[task.payload_id] = ray.remote(func) remote_func = self.function_refs[task.payload_id] - arglist = _convert_args(task.args) + arglist = 
_convert_args(task.args, prepared_args) # Prepare call if we're in a test with stash_api_for_tests(self._test_mode): output = remote_func.remote(*arglist) + if output.binary() in self.object_refs: + raise Exception("already found it") self.object_refs[output.binary()] = output return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) -def _convert_args(arg_list): +def _convert_args(arg_list, prepared_args=None): + if prepared_args is not None: + return prepared_args out = [] for arg in arg_list: t = convert_from_arg(arg) @@ -147,6 +164,7 @@ def _convert_args(arg_list): def serve(connection_str, test_mode=False): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) task_servicer = RayletServicer(test_mode=test_mode) + _set_server_api(RayServerAPI(task_servicer)) ray_client_pb2_grpc.add_RayletDriverServicer_to_server( task_servicer, server) server.add_insecure_port(connection_str) diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index 8c01bea34..87e5f6897 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -3,6 +3,7 @@ It implements the Ray API functions that are forwarded through grpc calls to the server. 
""" import inspect +import logging from typing import List from typing import Tuple @@ -14,11 +15,11 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.experimental.client.common import convert_to_arg from ray.experimental.client.common import ClientObjectRef -from ray.experimental.client.common import ClientActorRef from ray.experimental.client.common import ClientActorClass -from ray.experimental.client.common import ClientRemoteMethod from ray.experimental.client.common import ClientRemoteFunc +logger = logging.getLogger(__name__) + class Worker: def __init__(self, @@ -130,50 +131,14 @@ class Worker: raise TypeError("The @ray.remote decorator must be applied to " "either a function or to a class.") - def call_remote(self, instance, kind, *args, **kwargs): - ticket = None - if kind == ray_client_pb2.ClientTask.FUNCTION: - ticket = self._put_and_schedule(instance, kind, *args, **kwargs) - elif kind == ray_client_pb2.ClientTask.ACTOR: - ticket = self._put_and_schedule(instance, kind, *args, **kwargs) - return ClientActorRef(ticket.return_id) - elif kind == ray_client_pb2.ClientTask.METHOD: - ticket = self._call_method(instance, *args, **kwargs) - - if ticket is None: - raise Exception( - "Couldn't call_remote on %s for type %s" % (instance, kind)) + def call_remote(self, instance, *args, **kwargs): + task = instance._prepare_client_task() + for arg in args: + pb_arg = convert_to_arg(arg) + task.args.append(pb_arg) + logging.debug("Scheduling %s" % task) + ticket = self.server.Schedule(task, metadata=self.metadata) return ClientObjectRef(ticket.return_id) - def _call_method(self, instance: ClientRemoteMethod, *args, **kwargs): - if not isinstance(instance, ClientRemoteMethod): - raise TypeError("Client not passing a ClientRemoteMethod stub") - task = ray_client_pb2.ClientTask() - task.type = ray_client_pb2.ClientTask.METHOD - task.name = instance.method_name - task.payload_id = 
instance.actor_handle.actor_id.id - for arg in args: - pb_arg = convert_to_arg(arg) - task.args.append(pb_arg) - ticket = self.server.Schedule(task, metadata=self.metadata) - return ticket - - def _put_and_schedule(self, item, task_type, *args, **kwargs): - if isinstance(item, ClientRemoteFunc): - ref = self._put(item) - elif isinstance(item, ClientActorClass): - ref = self._put(item.actor_cls) - else: - raise TypeError("Client not passing a ClientRemoteFunc stub") - task = ray_client_pb2.ClientTask() - task.type = task_type - task.name = item._name - task.payload_id = ref.id - for arg in args: - pb_arg = convert_to_arg(arg) - task.args.append(pb_arg) - ticket = self.server.Schedule(task, metadata=self.metadata) - return ticket - def close(self): self.channel.close() diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index 8fc07590e..430574dd2 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -2,7 +2,7 @@ import pytest from contextlib import contextmanager import ray.experimental.client.server.server as ray_client_server -from ray.experimental.client import ray +from ray.experimental.client import ray, reset_api from ray.experimental.client.common import ClientObjectRef @@ -13,6 +13,7 @@ def ray_start_client_server(): yield ray ray.disconnect() server.stop(0) + reset_api() def test_real_ray_fallback(ray_start_regular_shared): @@ -170,6 +171,70 @@ def test_basic_actor(ray_start_regular_shared): assert count == 2 +def test_pass_handles(ray_start_regular_shared): + """ + Test that passing client handles to actors and functions to remote actors + in functions (on the server or raylet side) works transparently to the + caller. 
+ """ + with ray_start_client_server() as ray: + + @ray.remote + class ExecActor: + def exec(self, f, x): + return ray.get(f.remote(x)) + + def exec_exec(self, actor, f, x): + return ray.get(actor.exec.remote(f, x)) + + @ray.remote + def fact(x): + out = 1 + while x > 0: + out = out * x + x -= 1 + return out + + @ray.remote + def func_exec(f, x): + return ray.get(f.remote(x)) + + @ray.remote + def func_actor_exec(actor, f, x): + return ray.get(actor.exec.remote(f, x)) + + @ray.remote + def sneaky_func_exec(obj, x): + return ray.get(obj["f"].remote(x)) + + @ray.remote + def sneaky_actor_exec(obj, x): + return ray.get(obj["actor"].exec.remote(obj["f"], x)) + + def local_fact(x): + if x <= 0: + return 1 + return x * local_fact(x - 1) + + assert ray.get(fact.remote(7)) == local_fact(7) + assert ray.get(func_exec.remote(fact, 8)) == local_fact(8) + test_obj = {} + test_obj["f"] = fact + assert ray.get(sneaky_func_exec.remote(test_obj, 5)) == local_fact(5) + actor_handle = ExecActor.remote() + assert ray.get(actor_handle.exec.remote(fact, 7)) == local_fact(7) + assert ray.get(func_actor_exec.remote(actor_handle, fact, + 10)) == local_fact(10) + second_actor = ExecActor.remote() + assert ray.get(actor_handle.exec_exec.remote(second_actor, fact, + 9)) == local_fact(9) + test_actor_obj = {} + test_actor_obj["actor"] = second_actor + test_actor_obj["f"] = fact + assert ray.get(sneaky_actor_exec.remote(test_actor_obj, + 4)) == local_fact(4) + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__]))