diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py index 6842c2d78..8b135f60e 100644 --- a/python/ray/experimental/client/__init__.py +++ b/python/ray/experimental/client/__init__.py @@ -14,6 +14,18 @@ logger = logging.getLogger(__name__) _client_api: Optional[APIImpl] = None +def stash_api() -> Optional[APIImpl]: + global _client_api + a = _client_api + _client_api = None + return a + + +def restore_api(api: Optional[APIImpl]): + global _client_api + _client_api = api + + class RayAPIStub: def connect(self, conn_str): global _client_api diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index eb721fbbd..d44df6413 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -27,16 +27,11 @@ class ClientRemoteFunc: "Use {self._name}.remote method instead") def remote(self, *args, **kwargs): - if self._raylet_remote_func is not None: - return self._raylet_remote_func.remote(*args, **kwargs) return ray.call_remote(self, *args, **kwargs) def __repr__(self): return "ClientRemoteFunc(%s, %s)" % (self._name, self.id) - def set_remote_func(self, func): - self._raylet_remote_func = func - def convert_from_arg(pb) -> Any: if pb.local == ray_client_pb2.Arg.Locality.REFERENCE: diff --git a/python/ray/experimental/client/server/__init__.py b/python/ray/experimental/client/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/experimental/client/server/core_ray_api.py b/python/ray/experimental/client/server/core_ray_api.py index 564a2ade3..32d3b8ccf 100644 --- a/python/ray/experimental/client/server/core_ray_api.py +++ b/python/ray/experimental/client/server/core_ray_api.py @@ -10,6 +10,7 @@ import ray from ray.experimental.client.api import APIImpl +from ray.experimental.client.common import ClientRemoteFunc class CoreRayAPI(APIImpl): @@ -25,8 +26,10 @@ class CoreRayAPI(APIImpl): def remote(self, *args, **kwargs): return ray.remote(*args, **kwargs) - def call_remote(self, f, *args, **kwargs): - return f.remote(*args, **kwargs) + def call_remote(self, f: ClientRemoteFunc, *args, **kwargs): + if f._raylet_remote_func is None: + f._raylet_remote_func = ray.remote(f._func) + return f._raylet_remote_func.remote(*args, **kwargs) def close(self, *args, **kwargs): return None diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 89284ecf8..b0b68221e 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -6,6 +6,7 @@ import ray import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import time +import ray.experimental.client as client_init from ray.experimental.client.common import convert_from_arg from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientRemoteFunc @@ -14,9 +15,10 @@ logger = logging.getLogger(__name__) class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): - def __init__(self): + def __init__(self, test_mode=False): self.object_refs = {} self.function_refs = {} + self._test_mode = test_mode def GetObject(self, request, context=None): if request.id not in self.object_refs: @@ -73,12 +75,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): if not isinstance(func, ClientRemoteFunc): raise Exception("Attempting to schedule something that " "isn't a ClientRemoteFunc") - ray_remote = ray.remote(func._func) - func.set_remote_func(ray_remote) self.function_refs[task.payload_id] = func remote_func = self.function_refs[task.payload_id] arglist = _convert_args(task.args) + # Prepare call if we're in a test + api = None + if self._test_mode: + api = client_init.stash_api() output = remote_func.remote(*arglist) + if self._test_mode: + client_init.restore_api(api) self.object_refs[output.binary()] = output return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) @@ -94,9 +100,9 @@ def _convert_args(arg_list): return out -def serve(connection_str): +def serve(connection_str, test_mode=False): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - task_servicer = RayletServicer() + task_servicer = RayletServicer(test_mode=test_mode) ray_client_pb2_grpc.add_RayletDriverServicer_to_server( task_servicer, server) server.add_insecure_port(connection_str) diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index 7067ffdbf..03e375eee 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -5,7 +5,7 @@ from ray.experimental.client.common import ClientObjectRef def test_real_ray_fallback(ray_start_regular_shared): - server = ray_client_server.serve("localhost:50051") + server = ray_client_server.serve("localhost:50051", test_mode=True) ray.connect("localhost:50051") @ray.remote @@ -26,11 +26,12 @@ def test_real_ray_fallback(ray_start_regular_shared): with pytest.raises(NotImplementedError): print(ray.nodes()) + ray.disconnect() server.stop(0) def test_nested_function(ray_start_regular_shared): - server = ray_client_server.serve("localhost:50051") + server = ray_client_server.serve("localhost:50051", test_mode=True) ray.connect("localhost:50051") @ray.remote @@ -42,11 +43,12 @@ def test_nested_function(ray_start_regular_shared): return ray.get(f.remote()) assert ray.get(g.remote()) == "OK" + ray.disconnect() server.stop(0) def test_put_get(ray_start_regular_shared): - server = ray_client_server.serve("localhost:50051") + server = ray_client_server.serve("localhost:50051", test_mode=True) ray.connect("localhost:50051") objectref = ray.put("hello world") @@ -59,7 +61,7 @@ def test_put_get(ray_start_regular_shared): def test_wait(ray_start_regular_shared): - server = ray_client_server.serve("localhost:50051") + server = ray_client_server.serve("localhost:50051", test_mode=True) ray.connect("localhost:50051") objectref = ray.put("hello world") @@ -92,7 +94,7 @@ def test_wait(ray_start_regular_shared): def test_remote_functions(ray_start_regular_shared): - server = ray_client_server.serve("localhost:50051") + server = ray_client_server.serve("localhost:50051", test_mode=True) ray.connect("localhost:50051") @ray.remote @@ -138,6 +140,25 @@ def test_remote_functions(ray_start_regular_shared): server.stop(0) +def test_function_calling_function(ray_start_regular_shared): + server = ray_client_server.serve("localhost:50051", test_mode=True) + ray.connect("localhost:50051") + + @ray.remote + def g(): + return "OK" + + @ray.remote + def f(): + print(f, f._name, g._name, g) + return ray.get(g.remote()) + + print(f, type(f)) + assert ray.get(f.remote()) == "OK" + ray.disconnect() + server.stop(0) + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__]))