[ray_client] Support calling functions from other functions and correct the tests (#12141)

* Add test mode and fix f calling g

* formatting

* remove unused functions

* fix tests -- which will be better in actor PR
This commit is contained in:
Barak Michener
2020-11-24 22:19:20 -08:00
committed by GitHub
parent 4dd0aa7822
commit 4066056a0d
6 changed files with 54 additions and 17 deletions
@@ -14,6 +14,18 @@ logger = logging.getLogger(__name__)
_client_api: Optional[APIImpl] = None
def stash_api() -> Optional[APIImpl]:
global _client_api
a = _client_api
_client_api = None
return a
def restore_api(api: Optional[APIImpl]):
global _client_api
_client_api = api
class RayAPIStub:
def connect(self, conn_str):
global _client_api
-5
View File
@@ -27,16 +27,11 @@ class ClientRemoteFunc:
"Use {self._name}.remote method instead")
def remote(self, *args, **kwargs):
if self._raylet_remote_func is not None:
return self._raylet_remote_func.remote(*args, **kwargs)
return ray.call_remote(self, *args, **kwargs)
def __repr__(self):
return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)
def set_remote_func(self, func):
self._raylet_remote_func = func
def convert_from_arg(pb) -> Any:
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
@@ -10,6 +10,7 @@
import ray
from ray.experimental.client.api import APIImpl
from ray.experimental.client.common import ClientRemoteFunc
class CoreRayAPI(APIImpl):
@@ -25,8 +26,10 @@ class CoreRayAPI(APIImpl):
def remote(self, *args, **kwargs):
return ray.remote(*args, **kwargs)
def call_remote(self, f, *args, **kwargs):
return f.remote(*args, **kwargs)
def call_remote(self, f: ClientRemoteFunc, *args, **kwargs):
if f._raylet_remote_func is None:
f._raylet_remote_func = ray.remote(f._func)
return f._raylet_remote_func.remote(*args, **kwargs)
def close(self, *args, **kwargs):
return None
@@ -6,6 +6,7 @@ import ray
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import time
import ray.experimental.client as client_init
from ray.experimental.client.common import convert_from_arg
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientRemoteFunc
@@ -14,9 +15,10 @@ logger = logging.getLogger(__name__)
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self):
def __init__(self, test_mode=False):
self.object_refs = {}
self.function_refs = {}
self._test_mode = test_mode
def GetObject(self, request, context=None):
if request.id not in self.object_refs:
@@ -73,12 +75,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
if not isinstance(func, ClientRemoteFunc):
raise Exception("Attempting to schedule something that "
"isn't a ClientRemoteFunc")
ray_remote = ray.remote(func._func)
func.set_remote_func(ray_remote)
self.function_refs[task.payload_id] = func
remote_func = self.function_refs[task.payload_id]
arglist = _convert_args(task.args)
# Prepare call if we're in a test
api = None
if self._test_mode:
api = client_init.stash_api()
output = remote_func.remote(*arglist)
if self._test_mode:
client_init.restore_api(api)
self.object_refs[output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
@@ -94,9 +100,9 @@ def _convert_args(arg_list):
return out
def serve(connection_str):
def serve(connection_str, test_mode=False):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer()
task_servicer = RayletServicer(test_mode=test_mode)
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
server.add_insecure_port(connection_str)
+26 -5
View File
@@ -5,7 +5,7 @@ from ray.experimental.client.common import ClientObjectRef
def test_real_ray_fallback(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051")
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
@ray.remote
@@ -26,11 +26,12 @@ def test_real_ray_fallback(ray_start_regular_shared):
with pytest.raises(NotImplementedError):
print(ray.nodes())
ray.disconnect()
server.stop(0)
def test_nested_function(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051")
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
@ray.remote
@@ -42,11 +43,12 @@ def test_nested_function(ray_start_regular_shared):
return ray.get(f.remote())
assert ray.get(g.remote()) == "OK"
ray.disconnect()
server.stop(0)
def test_put_get(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051")
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
objectref = ray.put("hello world")
@@ -59,7 +61,7 @@ def test_put_get(ray_start_regular_shared):
def test_wait(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051")
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
objectref = ray.put("hello world")
@@ -92,7 +94,7 @@ def test_wait(ray_start_regular_shared):
def test_remote_functions(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051")
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
@ray.remote
@@ -138,6 +140,25 @@ def test_remote_functions(ray_start_regular_shared):
server.stop(0)
def test_function_calling_function(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
@ray.remote
def g():
return "OK"
@ray.remote
def f():
print(f, f._name, g._name, g)
return ray.get(g.remote())
print(f, type(f))
assert ray.get(f.remote()) == "OK"
ray.disconnect()
server.stop(0)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))