From 839517743d49cfc9b49eac76b4ea4d623aca2f8d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 20 Nov 2020 13:28:46 -0800 Subject: [PATCH] Support ray.* in remote functions for Ray client (#12177) --- python/ray/experimental/client/__init__.py | 2 +- python/ray/experimental/client/api.py | 8 ++++ python/ray/experimental/client/client_app.py | 8 ++++ .../client/server/core_ray_api.py | 6 +++ python/ray/experimental/client/worker.py | 2 +- python/ray/tests/test_experimental_client.py | 42 +++++++++++++++++++ 6 files changed, 66 insertions(+), 2 deletions(-) diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py index 76a2ea91f..6842c2d78 100644 --- a/python/ray/experimental/client/__init__.py +++ b/python/ray/experimental/client/__init__.py @@ -30,7 +30,7 @@ class RayAPIStub: def __getattr__(self, key: str): global _client_api self.__check_client_api() - return _client_api.__getattribute__(key) + return getattr(_client_api, key) def __check_client_api(self): global _client_api diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py index 97cb6944f..a91111bde 100644 --- a/python/ray/experimental/client/api.py +++ b/python/ray/experimental/client/api.py @@ -60,3 +60,11 @@ class ClientAPI(APIImpl): def close(self, *args, **kwargs): return self.worker.close() + + def __getattr__(self, key: str): + if not key.startswith("_"): + raise NotImplementedError( + "Not available in Ray client: `ray.{}`. This method is only " + "available within Ray remote functions and is not yet " + "implemented in the client API.".format(key)) + return self.__getattribute__(key) diff --git a/python/ray/experimental/client/client_app.py b/python/ray/experimental/client/client_app.py index a87223d13..fe30e17ff 100644 --- a/python/ray/experimental/client/client_app.py +++ b/python/ray/experimental/client/client_app.py @@ -19,6 +19,14 @@ def fact(x): return ray.get(fact.remote(x - 1)) * x +@ray.remote +def get_nodes(): + return ray.nodes() # Can access the full Ray API in remote methods. + + +print("Cluster nodes", ray.get(get_nodes.remote())) +print(ray.nodes()) + objectref = ray.put("hello world") # `ClientObjectRef(...)` diff --git a/python/ray/experimental/client/server/core_ray_api.py b/python/ray/experimental/client/server/core_ray_api.py index 4a58af49d..564a2ade3 100644 --- a/python/ray/experimental/client/server/core_ray_api.py +++ b/python/ray/experimental/client/server/core_ray_api.py @@ -30,3 +30,9 @@ class CoreRayAPI(APIImpl): def close(self, *args, **kwargs): return None + + # Allow for generic fallback to ray.* in remote methods. This allows calls + # like ray.nodes() to be run in remote functions even though the client + # doesn't currently support them. + def __getattr__(self, key: str): + return getattr(ray, key) diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index e1542440e..f17959d3f 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -4,9 +4,9 @@ to the server. """ from typing import List +import ray.cloudpickle as cloudpickle import grpc -from ray import cloudpickle import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.experimental.client.common import convert_to_arg diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index 647b97578..7067ffdbf 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -4,6 +4,47 @@ from ray.experimental.client import ray from ray.experimental.client.common import ClientObjectRef +def test_real_ray_fallback(ray_start_regular_shared): + server = ray_client_server.serve("localhost:50051") + ray.connect("localhost:50051") + + @ray.remote + def get_nodes_real(): + import ray as real_ray + return real_ray.nodes() + + nodes = ray.get(get_nodes_real.remote()) + assert len(nodes) == 1, nodes + + @ray.remote + def get_nodes(): + return ray.nodes() # Can access the full Ray API in remote methods. + + nodes = ray.get(get_nodes.remote()) + assert len(nodes) == 1, nodes + + with pytest.raises(NotImplementedError): + print(ray.nodes()) + + server.stop(0) + + +def test_nested_function(ray_start_regular_shared): + server = ray_client_server.serve("localhost:50051") + ray.connect("localhost:50051") + + @ray.remote + def g(): + @ray.remote + def f(): + return "OK" + + return ray.get(f.remote()) + + assert ray.get(g.remote()) == "OK" + server.stop(0) + + def test_put_get(ray_start_regular_shared): server = ray_client_server.serve("localhost:50051") ray.connect("localhost:50051") @@ -52,6 +93,7 @@ def test_wait(ray_start_regular_shared): def test_remote_functions(ray_start_regular_shared): server = ray_client_server.serve("localhost:50051") + ray.connect("localhost:50051") @ray.remote def plus2(x):