mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:53:20 +08:00
Support ray.* in remote functions for Ray client (#12177)
This commit is contained in:
@@ -30,7 +30,7 @@ class RayAPIStub:
|
||||
def __getattr__(self, key: str):
|
||||
global _client_api
|
||||
self.__check_client_api()
|
||||
return _client_api.__getattribute__(key)
|
||||
return getattr(_client_api, key)
|
||||
|
||||
def __check_client_api(self):
|
||||
global _client_api
|
||||
|
||||
@@ -60,3 +60,11 @@ class ClientAPI(APIImpl):
|
||||
|
||||
def close(self, *args, **kwargs):
|
||||
return self.worker.close()
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
if not key.startswith("_"):
|
||||
raise NotImplementedError(
|
||||
"Not available in Ray client: `ray.{}`. This method is only "
|
||||
"available within Ray remote functions and is not yet "
|
||||
"implemented in the client API.".format(key))
|
||||
return self.__getattribute__(key)
|
||||
|
||||
@@ -19,6 +19,14 @@ def fact(x):
|
||||
return ray.get(fact.remote(x - 1)) * x
|
||||
|
||||
|
||||
@ray.remote
|
||||
def get_nodes():
|
||||
return ray.nodes() # Can access the full Ray API in remote methods.
|
||||
|
||||
|
||||
print("Cluster nodes", ray.get(get_nodes.remote()))
|
||||
print(ray.nodes())
|
||||
|
||||
objectref = ray.put("hello world")
|
||||
|
||||
# `ClientObjectRef(...)`
|
||||
|
||||
@@ -30,3 +30,9 @@ class CoreRayAPI(APIImpl):
|
||||
|
||||
def close(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
# Allow for generic fallback to ray.* in remote methods. This allows calls
|
||||
# like ray.nodes() to be run in remote functions even though the client
|
||||
# doesn't currently support them.
|
||||
def __getattr__(self, key: str):
|
||||
return getattr(ray, key)
|
||||
|
||||
@@ -4,9 +4,9 @@ to the server.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
import grpc
|
||||
|
||||
from ray import cloudpickle
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.experimental.client.common import convert_to_arg
|
||||
|
||||
@@ -4,6 +4,47 @@ from ray.experimental.client import ray
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
|
||||
|
||||
def test_real_ray_fallback(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
@ray.remote
|
||||
def get_nodes_real():
|
||||
import ray as real_ray
|
||||
return real_ray.nodes()
|
||||
|
||||
nodes = ray.get(get_nodes_real.remote())
|
||||
assert len(nodes) == 1, nodes
|
||||
|
||||
@ray.remote
|
||||
def get_nodes():
|
||||
return ray.nodes() # Can access the full Ray API in remote methods.
|
||||
|
||||
nodes = ray.get(get_nodes.remote())
|
||||
assert len(nodes) == 1, nodes
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
print(ray.nodes())
|
||||
|
||||
server.stop(0)
|
||||
|
||||
|
||||
def test_nested_function(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
@ray.remote
|
||||
def g():
|
||||
@ray.remote
|
||||
def f():
|
||||
return "OK"
|
||||
|
||||
return ray.get(f.remote())
|
||||
|
||||
assert ray.get(g.remote()) == "OK"
|
||||
server.stop(0)
|
||||
|
||||
|
||||
def test_put_get(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
ray.connect("localhost:50051")
|
||||
@@ -52,6 +93,7 @@ def test_wait(ray_start_regular_shared):
|
||||
|
||||
def test_remote_functions(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
@ray.remote
|
||||
def plus2(x):
|
||||
|
||||
Reference in New Issue
Block a user