mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 17:49:47 +08:00
[ray_client] Support calling functions from other functions and correct the tests (#12141)
* Add test mode and fix f calling g * formatting * remove unused functions * fix tests -- which will be better in actor PR
This commit is contained in:
@@ -14,6 +14,18 @@ logger = logging.getLogger(__name__)
|
||||
_client_api: Optional[APIImpl] = None
|
||||
|
||||
|
||||
def stash_api() -> Optional[APIImpl]:
|
||||
global _client_api
|
||||
a = _client_api
|
||||
_client_api = None
|
||||
return a
|
||||
|
||||
|
||||
def restore_api(api: Optional[APIImpl]):
|
||||
global _client_api
|
||||
_client_api = api
|
||||
|
||||
|
||||
class RayAPIStub:
|
||||
def connect(self, conn_str):
|
||||
global _client_api
|
||||
|
||||
@@ -27,16 +27,11 @@ class ClientRemoteFunc:
|
||||
"Use {self._name}.remote method instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
if self._raylet_remote_func is not None:
|
||||
return self._raylet_remote_func.remote(*args, **kwargs)
|
||||
return ray.call_remote(self, *args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)
|
||||
|
||||
def set_remote_func(self, func):
|
||||
self._raylet_remote_func = func
|
||||
|
||||
|
||||
def convert_from_arg(pb) -> Any:
|
||||
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
import ray
|
||||
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
|
||||
|
||||
class CoreRayAPI(APIImpl):
|
||||
@@ -25,8 +26,10 @@ class CoreRayAPI(APIImpl):
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.remote(*args, **kwargs)
|
||||
|
||||
def call_remote(self, f, *args, **kwargs):
|
||||
return f.remote(*args, **kwargs)
|
||||
def call_remote(self, f: ClientRemoteFunc, *args, **kwargs):
|
||||
if f._raylet_remote_func is None:
|
||||
f._raylet_remote_func = ray.remote(f._func)
|
||||
return f._raylet_remote_func.remote(*args, **kwargs)
|
||||
|
||||
def close(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
@@ -6,6 +6,7 @@ import ray
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
import time
|
||||
import ray.experimental.client as client_init
|
||||
from ray.experimental.client.common import convert_from_arg
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
@@ -14,9 +15,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
def __init__(self):
|
||||
def __init__(self, test_mode=False):
|
||||
self.object_refs = {}
|
||||
self.function_refs = {}
|
||||
self._test_mode = test_mode
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
if request.id not in self.object_refs:
|
||||
@@ -73,12 +75,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
if not isinstance(func, ClientRemoteFunc):
|
||||
raise Exception("Attempting to schedule something that "
|
||||
"isn't a ClientRemoteFunc")
|
||||
ray_remote = ray.remote(func._func)
|
||||
func.set_remote_func(ray_remote)
|
||||
self.function_refs[task.payload_id] = func
|
||||
remote_func = self.function_refs[task.payload_id]
|
||||
arglist = _convert_args(task.args)
|
||||
# Prepare call if we're in a test
|
||||
api = None
|
||||
if self._test_mode:
|
||||
api = client_init.stash_api()
|
||||
output = remote_func.remote(*arglist)
|
||||
if self._test_mode:
|
||||
client_init.restore_api(api)
|
||||
self.object_refs[output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
|
||||
@@ -94,9 +100,9 @@ def _convert_args(arg_list):
|
||||
return out
|
||||
|
||||
|
||||
def serve(connection_str):
|
||||
def serve(connection_str, test_mode=False):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
task_servicer = RayletServicer()
|
||||
task_servicer = RayletServicer(test_mode=test_mode)
|
||||
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
||||
task_servicer, server)
|
||||
server.add_insecure_port(connection_str)
|
||||
|
||||
@@ -5,7 +5,7 @@ from ray.experimental.client.common import ClientObjectRef
|
||||
|
||||
|
||||
def test_real_ray_fallback(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
@ray.remote
|
||||
@@ -26,11 +26,12 @@ def test_real_ray_fallback(ray_start_regular_shared):
|
||||
with pytest.raises(NotImplementedError):
|
||||
print(ray.nodes())
|
||||
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
|
||||
|
||||
def test_nested_function(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
@ray.remote
|
||||
@@ -42,11 +43,12 @@ def test_nested_function(ray_start_regular_shared):
|
||||
return ray.get(f.remote())
|
||||
|
||||
assert ray.get(g.remote()) == "OK"
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
|
||||
|
||||
def test_put_get(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
objectref = ray.put("hello world")
|
||||
@@ -59,7 +61,7 @@ def test_put_get(ray_start_regular_shared):
|
||||
|
||||
|
||||
def test_wait(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
objectref = ray.put("hello world")
|
||||
@@ -92,7 +94,7 @@ def test_wait(ray_start_regular_shared):
|
||||
|
||||
|
||||
def test_remote_functions(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
@ray.remote
|
||||
@@ -138,6 +140,25 @@ def test_remote_functions(ray_start_regular_shared):
|
||||
server.stop(0)
|
||||
|
||||
|
||||
def test_function_calling_function(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
@ray.remote
|
||||
def g():
|
||||
return "OK"
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
print(f, f._name, g._name, g)
|
||||
return ray.get(g.remote())
|
||||
|
||||
print(f, type(f))
|
||||
assert ray.get(f.remote()) == "OK"
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
Reference in New Issue
Block a user