Support ray.* in remote functions for Ray client (#12177)

This commit is contained in:
Eric Liang
2020-11-20 13:28:46 -08:00
committed by GitHub
parent 48042be8bb
commit 839517743d
6 changed files with 66 additions and 2 deletions
+1 -1
View File
@@ -30,7 +30,7 @@ class RayAPIStub:
def __getattr__(self, key: str):
global _client_api
self.__check_client_api()
return _client_api.__getattribute__(key)
return getattr(_client_api, key)
def __check_client_api(self):
global _client_api
+8
View File
@@ -60,3 +60,11 @@ class ClientAPI(APIImpl):
def close(self, *args, **kwargs):
return self.worker.close()
def __getattr__(self, key: str):
if not key.startswith("_"):
raise NotImplementedError(
"Not available in Ray client: `ray.{}`. This method is only "
"available within Ray remote functions and is not yet "
"implemented in the client API.".format(key))
return self.__getattribute__(key)
@@ -19,6 +19,14 @@ def fact(x):
return ray.get(fact.remote(x - 1)) * x
@ray.remote
def get_nodes():
return ray.nodes() # Can access the full Ray API in remote methods.
print("Cluster nodes", ray.get(get_nodes.remote()))
print(ray.nodes())
objectref = ray.put("hello world")
# `ClientObjectRef(...)`
@@ -30,3 +30,9 @@ class CoreRayAPI(APIImpl):
def close(self, *args, **kwargs):
return None
# Allow for generic fallback to ray.* in remote methods. This allows calls
# like ray.nodes() to be run in remote functions even though the client
# doesn't currently support them.
def __getattr__(self, key: str):
return getattr(ray, key)
+1 -1
View File
@@ -4,9 +4,9 @@ to the server.
"""
from typing import List
import ray.cloudpickle as cloudpickle
import grpc
from ray import cloudpickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.experimental.client.common import convert_to_arg
@@ -4,6 +4,47 @@ from ray.experimental.client import ray
from ray.experimental.client.common import ClientObjectRef
def test_real_ray_fallback(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051")
ray.connect("localhost:50051")
@ray.remote
def get_nodes_real():
import ray as real_ray
return real_ray.nodes()
nodes = ray.get(get_nodes_real.remote())
assert len(nodes) == 1, nodes
@ray.remote
def get_nodes():
return ray.nodes() # Can access the full Ray API in remote methods.
nodes = ray.get(get_nodes.remote())
assert len(nodes) == 1, nodes
with pytest.raises(NotImplementedError):
print(ray.nodes())
server.stop(0)
def test_nested_function(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051")
ray.connect("localhost:50051")
@ray.remote
def g():
@ray.remote
def f():
return "OK"
return ray.get(f.remote())
assert ray.get(g.remote()) == "OK"
server.stop(0)
def test_put_get(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051")
ray.connect("localhost:50051")
@@ -52,6 +93,7 @@ def test_wait(ray_start_regular_shared):
def test_remote_functions(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051")
ray.connect("localhost:50051")
@ray.remote
def plus2(x):