diff --git a/BUILD.bazel b/BUILD.bazel index 0cc65ef8f..9f5d256f5 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1954,24 +1954,13 @@ copy_to_workspace( ) genrule( - name = "ray_pkg", + name = "install_py_proto", srcs = [ - ":cp_raylet_so", - ":cp_streaming", - ":python_sources", ":cp_all_py_proto", - ":cp_redis", - ":cp_libray_redis_module", - ":cp_raylet", - ":cp_gcs_server", - ":cp_plasma_store_server", "//streaming:copy_streaming_py_proto", ], - outs = ["ray_pkg.out"], + outs = ["install_py_proto.out"], cmd = """ - if [ "$${OSTYPE-}" = "msys" ]; then - ln -P -f -- python/ray/_raylet.so python/ray/_raylet.pyd - fi # NOTE(hchen): Protobuf doesn't allow specifying Python package name. So we use this `sed` # command to change the import path in the generated file. # shellcheck disable=SC2006 @@ -1983,3 +1972,26 @@ genrule( """, local = 1, ) + +genrule( + name = "ray_pkg", + srcs = [ + ":cp_raylet_so", + ":cp_streaming", + ":python_sources", + ":install_py_proto", + ":cp_redis", + ":cp_libray_redis_module", + ":cp_raylet", + ":cp_gcs_server", + ":cp_plasma_store_server", + ], + outs = ["ray_pkg.out"], + cmd = """ + if [ "$${OSTYPE-}" = "msys" ]; then + ln -P -f -- python/ray/_raylet.so python/ray/_raylet.pyd + fi + echo "$${PWD}" > $@ + """, + local = 1, +) diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py index 12b7322a6..b849b3424 100644 --- a/python/ray/experimental/client/__init__.py +++ b/python/ray/experimental/client/__init__.py @@ -1,44 +1,60 @@ -import ray -from ray.experimental.client.worker import Worker +from ray.experimental.client.api import ClientAPI +from ray.experimental.client.api import APIImpl from typing import Optional -_client_worker: Optional[Worker] = None -_in_cluster: bool = True +import logging + +logger = logging.getLogger(__name__) + +_client_api: Optional[APIImpl] = None + + +def check_client_api(): + global _client_api + if _client_api is None: + logger.info( + "No client API initialized: probably a worker, using core ray") + from ray.experimental.client.core_ray_api import set_client_api_as_ray + set_client_api_as_ray() def get(*args, **kwargs): - global _client_worker - global _in_cluster - if _in_cluster: - return ray.get(*args, **kwargs) - else: - return _client_worker.get(*args, **kwargs) + global _client_api + check_client_api() + return _client_api.get(*args, **kwargs) def put(*args, **kwargs): - global _client_worker - global _in_cluster - if _in_cluster: - return ray.put(*args, **kwargs) - else: - return _client_worker.put(*args, **kwargs) + global _client_api + check_client_api() + return _client_api.put(*args, **kwargs) def remote(*args, **kwargs): - pass + global _client_api + check_client_api() + return _client_api.remote(*args, **kwargs) + + +def call_remote(f, *args, **kwargs): + global _client_api + check_client_api() + return _client_api.call_remote(f, *args, **kwargs) def connect(conn_str): - global _in_cluster - global _client_worker - _in_cluster = False + global _client_api + from ray.experimental.client.worker import Worker _client_worker = Worker(conn_str) + _client_api = ClientAPI(_client_worker) def disconnect(): - global _in_cluster - global _client_worker - if _client_worker is not None: - _client_worker.close() - _in_cluster = True - _client_worker = None + global _client_api + _client_api.close() + _client_api = None + + +def _set_client_api(api: Optional[APIImpl]): + global _client_api + _client_api = api diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py new file mode 100644 index 000000000..14f3705fc --- /dev/null +++ b/python/ray/experimental/client/api.py @@ -0,0 +1,55 @@ +# This file defines an interface and client-side API stub +# for referring either to the core Ray API or the same interface +# from the Ray client. +# +# In tandem with __init__.py, we want to expose an API that's +# close to `python/ray/__init__.py` but with more than one implementation. +# The stubs in __init__ should call into a well-defined interface. +# Only the core Ray API implementation should actually `import ray` +# (and thus import all the raylet worker C bindings and such). +# But to make sure that we're matching these calls, we define this API. + +from abc import ABC +from abc import abstractmethod + + +class APIImpl(ABC): + @abstractmethod + def get(self, *args, **kwargs): + pass + + @abstractmethod + def put(self, *args, **kwargs): + pass + + @abstractmethod + def remote(self, *args, **kwargs): + pass + + @abstractmethod + def call_remote(self, f, *args, **kwargs): + pass + + @abstractmethod + def close(self, *args, **kwargs): + pass + + +class ClientAPI(APIImpl): + def __init__(self, worker): + self.worker = worker + + def get(self, *args, **kwargs): + return self.worker.get(*args, **kwargs) + + def put(self, *args, **kwargs): + return self.worker.put(*args, **kwargs) + + def remote(self, *args, **kwargs): + return self.worker.remote(*args, **kwargs) + + def call_remote(self, f, *args, **kwargs): + return self.worker.call_remote(f, *args, **kwargs) + + def close(self, *args, **kwargs): + return self.worker.close() diff --git a/python/ray/experimental/client/client_app.py b/python/ray/experimental/client/client_app.py index f78317493..b0f6e6e21 100644 --- a/python/ray/experimental/client/client_app.py +++ b/python/ray/experimental/client/client_app.py @@ -2,7 +2,44 @@ import ray.experimental.client as ray ray.connect("localhost:50051") + +@ray.remote +def plus2(x): + return x + 2 + + +@ray.remote +def fact(x): + print(x, type(fact)) + if x <= 0: + return 1 + # This hits the "nested tasks" issue + # https://github.com/ray-project/ray/issues/3644 + # So we're on the right track! + return ray.get(fact.remote(x - 1)) * x + + objectref = ray.put("hello world") + +# `ClientObjectRef(...)` print(objectref) +# `hello world` print(ray.get(objectref)) + +ref2 = plus2.remote(234) +# `ClientObjectRef(...)` +print(ref2) +# `236` +print(ray.get(ref2)) + +ref3 = fact.remote(20) +# `ClientObjectRef(...)` +print(ref3) +# `2432902008176640000` +print(ray.get(ref3)) + +# Reuse the cached ClientRemoteFunc object +ref4 = fact.remote(5) +# `120` +print(ray.get(ref4)) diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py new file mode 100644 index 000000000..044898fb2 --- /dev/null +++ b/python/ray/experimental/client/common.py @@ -0,0 +1,55 @@ +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +from ray.experimental.client import call_remote +from typing import Any +from ray import cloudpickle + + +class ClientObjectRef: + def __init__(self, id): + self.id = id + + def __repr__(self): + return "ClientObjectRef(%s)" % self.id.hex() + + +class ClientRemoteFunc: + def __init__(self, f): + self._func = f + self._name = f.__name__ + self.id = None + self._raylet_remote_func = None + + def __call__(self, *args, **kwargs): + raise TypeError(f"Remote function cannot be called directly. " + "Use {self._name}.remote method instead") + + def remote(self, *args, **kwargs): + if self._raylet_remote_func is not None: + return self._raylet_remote_func.remote(*args, **kwargs) + return call_remote(self, *args, **kwargs) + + def __repr__(self): + return "ClientRemoteFunc(%s, %s)" % (self._name, self.id) + + def set_remote_func(self, func): + self._raylet_remote_func = func + + +def convert_from_arg(pb) -> Any: + if pb.local == ray_client_pb2.Arg.Locality.REFERENCE: + return ClientObjectRef(pb.reference_id) + elif pb.local == ray_client_pb2.Arg.Locality.INTERNED: + return cloudpickle.loads(pb.data) + + raise Exception("convert_from_arg: Uncovered locality enum") + + +def convert_to_arg(val): + out = ray_client_pb2.Arg() + if isinstance(val, ClientObjectRef): + out.local = ray_client_pb2.Arg.Locality.REFERENCE + out.reference_id = val.id + else: + out.local = ray_client_pb2.Arg.Locality.INTERNED + out.data = cloudpickle.dumps(val) + return out diff --git a/python/ray/experimental/client/core_ray_api.py b/python/ray/experimental/client/core_ray_api.py new file mode 100644 index 000000000..52aa009db --- /dev/null +++ b/python/ray/experimental/client/core_ray_api.py @@ -0,0 +1,34 @@ +# Along with `api.py` this is the stub that interfaces with +# the real (C-binding, raylet) ray core. +# +# Ideally, the first import line is the only time we actually +# import ray in this library (excluding the main function for the server) +# +# While the stub is trivial, it allows us to check that the calls we're +# making into the core-ray module are contained and well-defined. + +import ray + +from ray.experimental.client.api import APIImpl + + +class CoreRayAPI(APIImpl): + def get(self, *args, **kwargs): + return ray.get(*args, **kwargs) + + def put(self, *args, **kwargs): + return ray.put(*args, **kwargs) + + def remote(self, *args, **kwargs): + return ray.remote(*args, **kwargs) + + def call_remote(self, f, *args, **kwargs): + return f.remote(*args, **kwargs) + + def close(self, *args, **kwargs): + return None + + +def set_client_api_as_ray(): + ray_api = CoreRayAPI() + ray.experimental.client._set_client_api(ray_api) diff --git a/python/ray/experimental/client/server.py b/python/ray/experimental/client/server.py index 804fc834d..489794e4a 100644 --- a/python/ray/experimental/client/server.py +++ b/python/ray/experimental/client/server.py @@ -6,32 +6,62 @@ import ray import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import time +from ray.experimental.client.core_ray_api import set_client_api_as_ray +from ray.experimental.client.common import convert_from_arg +from ray.experimental.client.common import ClientObjectRef +from ray.experimental.client.common import ClientRemoteFunc + +logger = logging.getLogger(__name__) class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): def __init__(self): - self.realref = {} + self.object_refs = {} + self.function_refs = {} def GetObject(self, request, context=None): - objectref = self.realref[request.id] - print("get: %s" % objectref) - item = ray.get(objectref) - if item is None: + if request.id not in self.object_refs: return ray_client_pb2.GetResponse(valid=False) - data = cloudpickle.loads(item) - return ray_client_pb2.GetResponse(valid=True, data=data) + objectref = self.object_refs[request.id] + logger.info("get: %s" % objectref) + item = ray.get(objectref) + item_ser = cloudpickle.dumps(item) + return ray_client_pb2.GetResponse(valid=True, data=item_ser) def PutObject(self, request, context=None): - data = cloudpickle.dumps(request.data) - objectref = ray.put(data) - self.realref[objectref.binary()] = objectref - print("put: %s" % objectref) + obj = cloudpickle.loads(request.data) + objectref = ray.put(obj) + self.object_refs[objectref.binary()] = objectref + logger.info("put: %s" % objectref) return ray_client_pb2.PutResponse(id=objectref.binary()) def Schedule(self, task, context=None): - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Unimplemented") - return ray_client_pb2.TaskTicket() + logger.info("schedule: %s" % task) + if task.payload_id not in self.function_refs: + funcref = self.object_refs[task.payload_id] + func = ray.get(funcref) + if not isinstance(func, ClientRemoteFunc): + raise Exception("Attempting to schedule something that " + "isn't a ClientRemoteFunc") + ray_remote = ray.remote(func._func) + func.set_remote_func(ray_remote) + self.function_refs[task.payload_id] = func + remote_func = self.function_refs[task.payload_id] + arglist = _convert_args(task.args) + output = remote_func.remote(*arglist) + self.object_refs[output.binary()] = output + return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) + + +def _convert_args(arg_list): + out = [] + for arg in arg_list: + t = convert_from_arg(arg) + if isinstance(t, ClientObjectRef): + out.append(ray.ObjectRef(t.id)) + else: + out.append(t) + return out def serve(connection_str): @@ -45,8 +75,10 @@ def serve(connection_str): if __name__ == "__main__": - logging.basicConfig() + logging.basicConfig(level="INFO") + # TODO(barakmich): Perhaps wrap ray init ray.init() + set_client_api_as_ray() server = serve("0.0.0.0:50051") try: while True: diff --git a/python/ray/experimental/client/util.py b/python/ray/experimental/client/util.py new file mode 100644 index 000000000..f571b4a6d --- /dev/null +++ b/python/ray/experimental/client/util.py @@ -0,0 +1,26 @@ +from ray import cloudpickle +from ray.experimental.client.worker import ClientObjectRef + +import ray +import ray.core.generated.ray_client_pb2 as ray_client_pb2 + + +def dump_args_proto(arg): + if arg.local == ray_client_pb2.Arg.Locality.INTERNED: + return cloudpickle.loads(arg.data) + else: + # TODO(barakmich): This is a dirty hack that assumes the + # server maintains a reference to the ID we've been given + ref = ray.ObjectRef(arg.reference_id) + return ray.get(ref) + + +def load_args_proto(thing): + arg = ray_client_pb2.Arg() + if isinstance(thing, ClientObjectRef): + arg.local = ray_client_pb2.Arg.Locality.REFERENCE + arg.reference_id = thing.id + else: + arg.local = ray_client_pb2.Arg.Locality.INTERNED + arg.data = cloudpickle.dumps(thing) + return arg diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index c69fd4185..2615b86e6 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -1,15 +1,10 @@ -import cloudpickle +from ray import cloudpickle import grpc import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc - - -class ObjectID: - def __init__(self, id): - self.id = id - - def __repr__(self): - return "ObjectID(%s)" % self.id.hex() +from ray.experimental.client.common import convert_to_arg +from ray.experimental.client.common import ClientObjectRef +from ray.experimental.client.common import ClientRemoteFunc class Worker: @@ -25,13 +20,12 @@ class Worker: single = False if isinstance(ids, list): to_get = [x.id for x in ids] - elif isinstance(ids, ObjectID): + elif isinstance(ids, ClientObjectRef): to_get = [ids.id] single = True else: - raise Exception( - "Can't get something that's not a list of IDs or just an ID") - + raise Exception("Can't get something that's not a " + "list of IDs or just an ID: %s" % type(ids)) out = [self._get(x) for x in to_get] if single: out = out[0] @@ -40,6 +34,9 @@ class Worker: def _get(self, id: bytes): req = ray_client_pb2.GetRequest(id=id) data = self.server.GetObject(req) + if not data.valid: + raise Exception( + "Client GetObject returned invalid data: id invalid?") return cloudpickle.loads(data.data) def put(self, vals): @@ -60,29 +57,23 @@ class Worker: data = cloudpickle.dumps(val) req = ray_client_pb2.PutRequest(data=data) resp = self.server.PutObject(req) - return ObjectID(resp.id) + return ClientObjectRef(resp.id) def remote(self, func): - return RemoteFunc(self, func) + return ClientRemoteFunc(func) - def schedule(self, task): - return self.server.Schedule(task) + def call_remote(self, func, *args, **kwargs): + if not isinstance(func, ClientRemoteFunc): + raise TypeError("Client not passing a ClientRemoteFunc stub") + func_ref = self._put(func) + task = ray_client_pb2.ClientTask() + task.name = func._name + task.payload_id = func_ref.id + for arg in args: + pb_arg = convert_to_arg(arg) + task.args.append(pb_arg) + ticket = self.server.Schedule(task) + return ClientObjectRef(ticket.return_id) def close(self): self.channel.close() - - -class RemoteFunc: - def __init__(self, worker, f): - self._func = f - self._name = f.__name__ - self.id = None - - def __call__(self, *args, **kwargs): - raise Exception("Matching the old API") - - def remote(self, *args): - pass - - def __repr__(self): - return "RemoteFunc(%s, %s)" % (self._name, self.id) diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index 7877b57e1..c4e5fbd2d 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -16,6 +16,39 @@ def test_put_get(ray_start_regular_shared): server.stop(0) +def test_remote_functions(ray_start_regular_shared): + server = ray_client_server.serve("localhost:50051") + + @ray.remote + def plus2(x): + return x + 2 + + @ray.remote + def fact(x): + print(x, type(fact)) + if x <= 0: + return 1 + # This hits the "nested tasks" issue + # https://github.com/ray-project/ray/issues/3644 + # So we're on the right track! + return ray.get(fact.remote(x - 1)) * x + + ref2 = plus2.remote(234) + # `236` + assert ray.get(ref2) == 236 + + ref3 = fact.remote(20) + # `2432902008176640000` + assert ray.get(ref3) == 2_432_902_008_176_640_000 + + # Reuse the cached ClientRemoteFunc object + ref4 = fact.remote(5) + assert ray.get(ref4) == 120 + + ray.disconnect() + server.stop(0) + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__]))