diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py
index b849b3424..738c5356b 100644
--- a/python/ray/experimental/client/__init__.py
+++ b/python/ray/experimental/client/__init__.py
@@ -30,6 +30,12 @@ def put(*args, **kwargs):
     return _client_api.put(*args, **kwargs)
 
 
+def wait(*args, **kwargs):
+    global _client_api
+    check_client_api()
+    return _client_api.wait(*args, **kwargs)
+
+
 def remote(*args, **kwargs):
     global _client_api
     check_client_api()
diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py
index 14f3705fc..97cb6944f 100644
--- a/python/ray/experimental/client/api.py
+++ b/python/ray/experimental/client/api.py
@@ -22,6 +22,10 @@ class APIImpl(ABC):
     def put(self, *args, **kwargs):
         pass
 
+    @abstractmethod
+    def wait(self, *args, **kwargs):
+        pass
+
     @abstractmethod
     def remote(self, *args, **kwargs):
         pass
@@ -45,6 +49,9 @@ class ClientAPI(APIImpl):
     def put(self, *args, **kwargs):
         return self.worker.put(*args, **kwargs)
 
+    def wait(self, *args, **kwargs):
+        return self.worker.wait(*args, **kwargs)
+
     def remote(self, *args, **kwargs):
         return self.worker.remote(*args, **kwargs)
 
diff --git a/python/ray/experimental/client/client_app.py b/python/ray/experimental/client/client_app.py
index b0f6e6e21..f6eedeee3 100644
--- a/python/ray/experimental/client/client_app.py
+++ b/python/ray/experimental/client/client_app.py
@@ -43,3 +43,18 @@ print(ray.get(ref3))
 ref4 = fact.remote(5)
 # `120`
 print(ray.get(ref4))
+
+ref5 = fact.remote(10)
+
+print([ref2, ref3, ref4, ref5])
+# should return ref2, ref3, ref4
+res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
+print(res)
+assert [ref2, ref3, ref4] == res[0]
+assert [ref5] == res[1]
+
+# should return ref2, ref3, ref4, ref5
+res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
+print(res)
+assert [ref2, ref3, ref4, ref5] == res[0]
+assert [] == res[1]
diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py
index 044898fb2..17456a4be 100644
--- a/python/ray/experimental/client/common.py
+++ b/python/ray/experimental/client/common.py
@@ -11,6 +11,18 @@ class ClientObjectRef:
     def __repr__(self):
         return "ClientObjectRef(%s)" % self.id.hex()
 
+    def __eq__(self, other):
+        # Only compare ids with other ClientObjectRefs; for anything else
+        # defer to Python so `ref == "x"` is False, not an AttributeError.
+        if not isinstance(other, ClientObjectRef):
+            return NotImplemented
+        return self.id == other.id
+
+    def __hash__(self):
+        # Defining __eq__ alone sets __hash__ to None; hash on the same
+        # field __eq__ compares so refs stay usable in sets and dict keys.
+        return hash(self.id)
+
 
 class ClientRemoteFunc:
     def __init__(self, f):
diff --git a/python/ray/experimental/client/core_ray_api.py b/python/ray/experimental/client/core_ray_api.py
index 52aa009db..2738e7a4c 100644
--- a/python/ray/experimental/client/core_ray_api.py
+++ b/python/ray/experimental/client/core_ray_api.py
@@ -19,6 +19,9 @@ class CoreRayAPI(APIImpl):
     def put(self, *args, **kwargs):
         return ray.put(*args, **kwargs)
 
+    def wait(self, *args, **kwargs):
+        return ray.wait(*args, **kwargs)
+
     def remote(self, *args, **kwargs):
         return ray.remote(*args, **kwargs)
 
diff --git a/python/ray/experimental/client/server.py b/python/ray/experimental/client/server.py
index 489794e4a..f70e86e3c 100644
--- a/python/ray/experimental/client/server.py
+++ b/python/ray/experimental/client/server.py
@@ -35,6 +35,46 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
         logger.info("put: %s" % objectref)
         return ray_client_pb2.PutResponse(id=objectref.binary())
 
+    def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
+        """Serve a client ray.wait() call.
+
+        Responds with valid=False if any ref is unknown to this server or
+        if ray.wait() raises.
+        """
+        # NOTE(security): unpickling client-supplied bytes can run arbitrary
+        # code; this is only acceptable for trusted clients.
+        object_refs = [cloudpickle.loads(o) for o in request.object_refs]
+        num_returns = request.num_returns
+        timeout = request.timeout
+        object_refs_ids = []
+        for object_ref in object_refs:
+            if object_ref.id not in self.object_refs:
+                return ray_client_pb2.WaitResponse(valid=False)
+            object_refs_ids.append(self.object_refs[object_ref.id])
+        try:
+            ready_object_refs, remaining_object_refs = ray.wait(
+                object_refs_ids,
+                num_returns=num_returns,
+                # The client encodes "no timeout" as -1.
+                timeout=timeout if timeout != -1 else None)
+        except Exception:
+            # TODO(ameer): improve exception messages.
+            logger.exception("wait failed")
+            return ray_client_pb2.WaitResponse(valid=False)
+        logger.info("wait: %s %s" % (str(ready_object_refs),
+                                     str(remaining_object_refs)))
+        ready_object_ids = [
+            ready_object_ref.binary() for ready_object_ref in ready_object_refs
+        ]
+        remaining_object_ids = [
+            remaining_object_ref.binary()
+            for remaining_object_ref in remaining_object_refs
+        ]
+        return ray_client_pb2.WaitResponse(
+            valid=True,
+            ready_object_ids=ready_object_ids,
+            remaining_object_ids=remaining_object_ids)
+
     def Schedule(self, task, context=None):
         logger.info("schedule: %s" % task)
         if task.payload_id not in self.function_refs:
diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py
index 2615b86e6..e1542440e 100644
--- a/python/ray/experimental/client/worker.py
+++ b/python/ray/experimental/client/worker.py
@@ -1,5 +1,12 @@
-from ray import cloudpickle
+"""This file includes the Worker class which sits on the client side.
+It implements the Ray API functions that are forwarded through grpc calls
+to the server.
+"""
+from typing import List, Optional, Tuple
+
 import grpc
+
+from ray import cloudpickle
 import ray.core.generated.ray_client_pb2 as ray_client_pb2
 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
 from ray.experimental.client.common import convert_to_arg
@@ -59,6 +66,54 @@ class Worker:
         resp = self.server.PutObject(req)
         return ClientObjectRef(resp.id)
 
+    def wait(self,
+             object_refs: List[ClientObjectRef],
+             *,
+             num_returns: int = 1,
+             timeout: Optional[float] = None
+             ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
+        """Forward a ray.wait() call to the server.
+
+        Args:
+            object_refs: List of ClientObjectRefs to wait on.
+            num_returns: Passed through to the server-side ray.wait().
+            timeout: Passed through to the server-side ray.wait(); None
+                means no timeout. 0 is sent as 0, not as "no timeout".
+
+        Returns:
+            A (ready, remaining) pair of ClientObjectRef lists.
+
+        Raises:
+            AssertionError: If object_refs is not a list of
+                ClientObjectRefs.
+            Exception: If the server reports the request as invalid.
+        """
+        assert isinstance(object_refs, list)
+        for ref in object_refs:
+            assert isinstance(ref, ClientObjectRef)
+        data = {
+            "object_refs": [
+                cloudpickle.dumps(object_ref) for object_ref in object_refs
+            ],
+            "num_returns": num_returns,
+            # -1 is the wire sentinel for "no timeout". Test against None
+            # explicitly: a truthiness test would also swallow timeout=0.
+            "timeout": timeout if timeout is not None else -1
+        }
+        req = ray_client_pb2.WaitRequest(**data)
+        resp = self.server.WaitObject(req)
+        if not resp.valid:
+            # TODO(ameer): improve error/exceptions messages.
+            raise Exception("Client Wait request failed. Reference invalid?")
+        client_ready_object_ids = [
+            ClientObjectRef(ref_id) for ref_id in resp.ready_object_ids
+        ]
+        client_remaining_object_ids = [
+            ClientObjectRef(ref_id) for ref_id in resp.remaining_object_ids
+        ]
+
+        return (client_ready_object_ids, client_remaining_object_ids)
+
     def remote(self, func):
         return ClientRemoteFunc(func)
 
diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py
index c4e5fbd2d..1a2ac0b60 100644
--- a/python/ray/tests/test_experimental_client.py
+++ b/python/ray/tests/test_experimental_client.py
@@ -1,6 +1,7 @@
 import pytest
 import ray.experimental.client.server as ray_client_server
 import ray.experimental.client as ray
+from ray.experimental.client.common import ClientObjectRef
 
 
 def test_put_get(ray_start_regular_shared):
@@ -16,6 +17,39 @@ def test_put_get(ray_start_regular_shared):
     server.stop(0)
 
 
+def test_wait(ray_start_regular_shared):
+    server = ray_client_server.serve("localhost:50051")
+    ray.connect("localhost:50051")
+
+    objectref = ray.put("hello world")
+    ready, remaining = ray.wait([objectref])
+    assert remaining == []
+    retval = ray.get(ready[0])
+    assert retval == "hello world"
+
+    objectref2 = ray.put(5)
+    ready, remaining = ray.wait([objectref, objectref2])
+    assert (ready, remaining) == ([objectref], [objectref2]) or \
+        (ready, remaining) == ([objectref2], [objectref])
+    ready_retval = ray.get(ready[0])
+    remaining_retval = ray.get(remaining[0])
+    assert (ready_retval, remaining_retval) == ("hello world", 5) \
+        or (ready_retval, remaining_retval) == (5, "hello world")
+
+    with pytest.raises(Exception):
+        # Reference not in the object store.
+        ray.wait([ClientObjectRef("blabla")])
+    with pytest.raises(AssertionError):
+        ray.wait("blabla")
+    with pytest.raises(AssertionError):
+        ray.wait(ClientObjectRef("blabla"))
+    with pytest.raises(AssertionError):
+        ray.wait(["blabla"])
+
+    ray.disconnect()
+    server.stop(0)
+
+
 def test_remote_functions(ray_start_regular_shared):
     server = ray_client_server.serve("localhost:50051")
 
@@ -45,6 +79,19 @@ def test_remote_functions(ray_start_regular_shared):
     ref4 = fact.remote(5)
     assert ray.get(ref4) == 120
 
+    # Test ray.wait()
+    ref5 = fact.remote(10)
+    # should return ref2, ref3, ref4
+    res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
+    assert [ref2, ref3, ref4] == res[0]
+    assert [ref5] == res[1]
+    assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120]
+    # should return ref2, ref3, ref4, ref5
+    res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
+    assert [ref2, ref3, ref4, ref5] == res[0]
+    assert [] == res[1]
+    assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120, 3628800]
+
     ray.disconnect()
     server.stop(0)
 
diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto
index dc04a5927..c4fd65555 100644
--- a/src/ray/protobuf/ray_client.proto
+++ b/src/ray/protobuf/ray_client.proto
@@ -56,12 +56,25 @@ message GetResponse {
   bool valid = 1;
   bytes data = 2;
 }
+message WaitRequest {
+  repeated bytes object_refs = 1;
+  int64 num_returns = 2;
+  double timeout = 3;
+}
+
+message WaitResponse {
+  bool valid = 1;
+  repeated bytes ready_object_ids = 2;
+  repeated bytes remaining_object_ids = 3;
+}
 
 service RayletDriver {
   rpc GetObject(GetRequest) returns (GetResponse) {
   }
   rpc PutObject(PutRequest) returns (PutResponse) {
   }
+  rpc WaitObject(WaitRequest) returns (WaitResponse) {
+  }
   rpc Schedule(ClientTask) returns (ClientTaskTicket) {
   }
 }