[ray_client]: Implement function calls (#11922)

This commit is contained in:
Barak Michener
2020-11-12 16:49:34 -08:00
committed by GitHub
parent a6a8e777f3
commit 272edcca94
10 changed files with 378 additions and 87 deletions
+42 -26
View File
@@ -1,44 +1,60 @@
import ray
from ray.experimental.client.worker import Worker
from ray.experimental.client.api import ClientAPI
from ray.experimental.client.api import APIImpl
from typing import Optional
_client_worker: Optional[Worker] = None
_in_cluster: bool = True
import logging
logger = logging.getLogger(__name__)
_client_api: Optional[APIImpl] = None
def check_client_api():
    """Make sure a client API implementation is installed.

    When the module-level ``_client_api`` is still ``None``, log an
    informational message and install the core-Ray implementation by calling
    ``set_client_api_as_ray``. Otherwise do nothing.
    """
    if _client_api is not None:
        return
    logger.info(
        "No client API initialized: probably a worker, using core ray")
    # Deferred import: core_ray_api is only loaded when the fallback is used.
    from ray.experimental.client.core_ray_api import set_client_api_as_ray
    set_client_api_as_ray()
# Module-level `get` entry point; forwards *args / **kwargs unchanged.
# NOTE(review): indentation is missing here and the body appears to hold two
# variants back to back -- one dispatching on `_in_cluster` to `ray.get` or
# `_client_worker.get`, and one calling `check_client_api()` and then
# `_client_api.get`. Presumably only one of them is meant to be live; confirm.
def get(*args, **kwargs):
global _client_worker
global _in_cluster
if _in_cluster:
return ray.get(*args, **kwargs)
else:
return _client_worker.get(*args, **kwargs)
global _client_api
check_client_api()
return _client_api.get(*args, **kwargs)
# Module-level `put` entry point; forwards *args / **kwargs unchanged.
# NOTE(review): indentation is missing here and the body appears to hold two
# variants back to back -- one dispatching on `_in_cluster` to `ray.put` or
# `_client_worker.put`, and one calling `check_client_api()` and then
# `_client_api.put`. Presumably only one of them is meant to be live; confirm.
def put(*args, **kwargs):
global _client_worker
global _in_cluster
if _in_cluster:
return ray.put(*args, **kwargs)
else:
return _client_worker.put(*args, **kwargs)
global _client_api
check_client_api()
return _client_api.put(*args, **kwargs)
# Module-level `remote` entry point; forwards *args / **kwargs to
# `_client_api.remote` after `check_client_api()`.
# NOTE(review): the leading `pass` looks like a leftover placeholder body
# sitting in front of the forwarding implementation -- confirm it is not
# meant to be the whole function.
def remote(*args, **kwargs):
pass
global _client_api
check_client_api()
return _client_api.remote(*args, **kwargs)
def call_remote(f, *args, **kwargs):
    """Invoke ``f`` through the active client API.

    Ensures an API implementation is installed (``check_client_api``) and
    then returns whatever ``_client_api.call_remote`` returns for ``f`` and
    the given positional / keyword arguments.
    """
    check_client_api()
    return _client_api.call_remote(f, *args, **kwargs)
# Connect to a client server: builds a `Worker` from `conn_str`, wraps it in
# a `ClientAPI`, and stores both in module-level state.
# NOTE(review): indentation is missing and the `_in_cluster` /
# `_client_worker` globals appear alongside the `_client_api` one, as if two
# variants of this function were merged -- confirm which globals are still
# meant to be maintained.
def connect(conn_str):
global _in_cluster
global _client_worker
_in_cluster = False
global _client_api
# Worker is imported at call time here, in addition to any module-level import.
from ray.experimental.client.worker import Worker
_client_worker = Worker(conn_str)
_client_api = ClientAPI(_client_worker)
def disconnect():
    """Close any active client connection and reset module-level state.

    Closes ``_client_worker`` and ``_client_api`` when they are set, resets
    ``_in_cluster`` to ``True`` and both handles to ``None``. Calling this
    when nothing is connected is a no-op rather than an error.
    """
    global _in_cluster
    global _client_worker
    global _client_api
    if _client_worker is not None:
        _client_worker.close()
    _in_cluster = True
    _client_worker = None
    # `_client_api` starts out as None and is reset to None below, so the
    # close() call must be guarded: disconnect() before connect(), or two
    # disconnect() calls in a row, would otherwise raise AttributeError.
    if _client_api is not None:
        _client_api.close()
    _client_api = None
def _set_client_api(api: Optional[APIImpl]):
    """Replace the module-level ``_client_api`` with ``api``.

    ``api`` may be ``None``, which clears the currently installed
    implementation.
    """
    global _client_api
    _client_api = api
+55
View File
@@ -0,0 +1,55 @@
# This file defines an interface and client-side API stub
# for referring either to the core Ray API or the same interface
# from the Ray client.
#
# In tandem with __init__.py, we want to expose an API that's
# close to `python/ray/__init__.py` but with more than one implementation.
# The stubs in __init__ should call into a well-defined interface.
# Only the core Ray API implementation should actually `import ray`
# (and thus import all the raylet worker C bindings and such).
# But to make sure that we're matching these calls, we define this API.
from abc import ABC
from abc import abstractmethod
class APIImpl(ABC):
    """Abstract interface shared by every client API implementation.

    Concrete subclasses must provide ``get``, ``put``, ``remote``,
    ``call_remote`` and ``close``; the class cannot be instantiated until all
    five are overridden. The base implementations do nothing and return
    ``None``.
    """

    @abstractmethod
    def get(self, *args, **kwargs):
        """Fetch value(s); arguments are defined by the implementation."""

    @abstractmethod
    def put(self, *args, **kwargs):
        """Store value(s); arguments are defined by the implementation."""

    @abstractmethod
    def remote(self, *args, **kwargs):
        """Declare something remote; arguments are implementation-defined."""

    @abstractmethod
    def call_remote(self, f, *args, **kwargs):
        """Invoke ``f`` remotely with the given arguments."""

    @abstractmethod
    def close(self, *args, **kwargs):
        """Release whatever the implementation holds open."""
class ClientAPI(APIImpl):
    """APIImpl that forwards every call to a wrapped worker object."""

    def __init__(self, worker):
        """Remember ``worker``; all other methods delegate to it."""
        self.worker = worker

    def get(self, *args, **kwargs):
        """Return ``worker.get(*args, **kwargs)``."""
        target = self.worker
        return target.get(*args, **kwargs)

    def put(self, *args, **kwargs):
        """Return ``worker.put(*args, **kwargs)``."""
        target = self.worker
        return target.put(*args, **kwargs)

    def remote(self, *args, **kwargs):
        """Return ``worker.remote(*args, **kwargs)``."""
        target = self.worker
        return target.remote(*args, **kwargs)

    def call_remote(self, f, *args, **kwargs):
        """Return ``worker.call_remote(f, *args, **kwargs)``."""
        target = self.worker
        return target.call_remote(f, *args, **kwargs)

    def close(self, *args, **kwargs):
        """Close the worker; any arguments passed in are ignored."""
        target = self.worker
        return target.close()
@@ -2,7 +2,44 @@ import ray.experimental.client as ray
ray.connect("localhost:50051")
@ray.remote
def plus2(x):
    """Return ``x`` plus two."""
    increment = 2
    return x + increment
@ray.remote
def fact(x):
    """Compute ``x`` factorial by recursively scheduling ``fact`` remotely.

    Prints ``x`` and the type of ``fact`` on every invocation. Returns 1 for
    ``x <= 0``.
    """
    print(x, type(fact))
    if x > 0:
        # This hits the "nested tasks" issue
        # https://github.com/ray-project/ray/issues/3644
        # So we're on the right track!
        previous = ray.get(fact.remote(x - 1))
        return previous * x
    return 1
# --- Round-trip a plain value through put / get ---
objectref = ray.put("hello world")
# Expected output: `ClientObjectRef(...)`
print(objectref)
# Expected output: `hello world`
print(ray.get(objectref))
# --- Call a simple remote function ---
ref2 = plus2.remote(234)
# Expected output: `ClientObjectRef(...)`
print(ref2)
# Expected output: `236` (234 + 2)
print(ray.get(ref2))
# --- Call a remote function that itself schedules remote calls ---
ref3 = fact.remote(20)
# Expected output: `ClientObjectRef(...)`
print(ref3)
# Expected output: `2432902008176640000` (20 factorial)
print(ray.get(ref3))
# Reuse the cached ClientRemoteFunc object
ref4 = fact.remote(5)
# Expected output: `120` (5 factorial)
print(ray.get(ref4))
+55
View File
@@ -0,0 +1,55 @@
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.experimental.client import call_remote
from typing import Any
from ray import cloudpickle
class ClientObjectRef:
    """Client-side handle that wraps the raw id of an object."""

    def __init__(self, id):
        # The parameter keeps the name `id` because it is part of the
        # constructor's interface; `hex()` is called on it in __repr__.
        self.id = id

    def __repr__(self):
        """Render as ``ClientObjectRef(<hex of id>)``."""
        return f"ClientObjectRef({self.id.hex()})"
class ClientRemoteFunc:
    """Stub standing in for a function that has been declared remote.

    The stub records the wrapped function and its ``__name__``. It cannot be
    called directly; use ``remote``. Once ``set_remote_func`` has installed a
    real remote function, ``remote`` delegates to it; until then ``remote``
    goes through the module-level ``call_remote``.
    """

    def __init__(self, f):
        """Wrap ``f``; ``id`` and the delegate remote function start as None."""
        self._func = f
        self._name = f.__name__
        self.id = None
        self._raylet_remote_func = None

    def __call__(self, *args, **kwargs):
        """Always raise ``TypeError``: direct calls are not supported."""
        # The `f` prefix must be on the fragment that contains the
        # placeholder; adjacent string literals are prefixed independently,
        # so prefixing only the first one left `{self._name}` un-interpolated.
        raise TypeError("Remote function cannot be called directly. "
                        f"Use {self._name}.remote method instead")

    def remote(self, *args, **kwargs):
        """Schedule the function and return whatever the backend returns."""
        if self._raylet_remote_func is not None:
            return self._raylet_remote_func.remote(*args, **kwargs)
        return call_remote(self, *args, **kwargs)

    def __repr__(self):
        """Render as ``ClientRemoteFunc(<name>, <id>)``."""
        return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)

    def set_remote_func(self, func):
        """Install ``func`` as the object ``remote`` delegates to."""
        self._raylet_remote_func = func
def convert_from_arg(pb) -> Any:
    """Turn a protobuf ``Arg`` into a Python value.

    A ``REFERENCE`` arg becomes a ``ClientObjectRef`` around
    ``pb.reference_id``; an ``INTERNED`` arg is unpickled from ``pb.data``.

    Raises:
        Exception: if ``pb.local`` is neither of those two localities.
    """
    locality = pb.local
    if locality == ray_client_pb2.Arg.Locality.REFERENCE:
        return ClientObjectRef(pb.reference_id)
    if locality == ray_client_pb2.Arg.Locality.INTERNED:
        # NOTE(review): unpickling executes code embedded in the payload;
        # confirm `pb.data` only ever comes from a trusted peer.
        return cloudpickle.loads(pb.data)
    raise Exception("convert_from_arg: Uncovered locality enum")
def convert_to_arg(val):
    """Build a protobuf ``Arg`` describing ``val``.

    A ``ClientObjectRef`` is sent by reference (its ``id``); anything else is
    pickled with cloudpickle and sent inline.
    """
    locality = ray_client_pb2.Arg.Locality
    arg = ray_client_pb2.Arg()
    if not isinstance(val, ClientObjectRef):
        arg.local = locality.INTERNED
        arg.data = cloudpickle.dumps(val)
        return arg
    arg.local = locality.REFERENCE
    arg.reference_id = val.id
    return arg
@@ -0,0 +1,34 @@
# Along with `api.py` this is the stub that interfaces with
# the real (C-binding, raylet) ray core.
#
# Ideally, the first import line is the only time we actually
# import ray in this library (excluding the main function for the server)
#
# While the stub is trivial, it allows us to check that the calls we're
# making into the core-ray module are contained and well-defined.
import ray
from ray.experimental.client.api import APIImpl
class CoreRayAPI(APIImpl):
    """APIImpl that forwards directly to the core ``ray`` module."""

    def get(self, *args, **kwargs):
        """Return ``ray.get(*args, **kwargs)``."""
        result = ray.get(*args, **kwargs)
        return result

    def put(self, *args, **kwargs):
        """Return ``ray.put(*args, **kwargs)``."""
        result = ray.put(*args, **kwargs)
        return result

    def remote(self, *args, **kwargs):
        """Return ``ray.remote(*args, **kwargs)``."""
        result = ray.remote(*args, **kwargs)
        return result

    def call_remote(self, f, *args, **kwargs):
        """Return ``f.remote(*args, **kwargs)``."""
        result = f.remote(*args, **kwargs)
        return result

    def close(self, *args, **kwargs):
        """Do nothing and return ``None``; arguments are ignored."""
        return None
def set_client_api_as_ray():
    """Install a fresh ``CoreRayAPI`` as the active client API.

    The instance is handed to ``ray.experimental.client._set_client_api``.
    """
    ray.experimental.client._set_client_api(CoreRayAPI())
+47 -15
View File
@@ -6,32 +6,62 @@ import ray
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import time
from ray.experimental.client.core_ray_api import set_client_api_as_ray
from ray.experimental.client.common import convert_from_arg
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientRemoteFunc
logger = logging.getLogger(__name__)
# gRPC servicer for the Ray client: keeps dicts of refs on the instance and
# serves GetObject / PutObject / Schedule.
# NOTE(review): indentation is missing and each method below appears to carry
# two variants of its body back to back (one using `self.realref` and
# `print`, one using `self.object_refs` / `self.function_refs` and `logger`).
# Presumably only one variant per method is meant to be live -- confirm
# before relying on any single line here.
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self):
self.realref = {}
# Keyed by `objectref.binary()` (see PutObject / Schedule below).
self.object_refs = {}
# Keyed by `task.payload_id` (see Schedule below).
self.function_refs = {}
# Look up the ref stored under `request.id`, fetch its value with `ray.get`
# and return it inside a GetResponse.
def GetObject(self, request, context=None):
objectref = self.realref[request.id]
print("get: %s" % objectref)
item = ray.get(objectref)
if item is None:
if request.id not in self.object_refs:
return ray_client_pb2.GetResponse(valid=False)
data = cloudpickle.loads(item)
return ray_client_pb2.GetResponse(valid=True, data=data)
objectref = self.object_refs[request.id]
logger.info("get: %s" % objectref)
item = ray.get(objectref)
item_ser = cloudpickle.dumps(item)
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
# Store the request payload with `ray.put`, remember the resulting ref under
# its binary id, and return that id in a PutResponse.
# NOTE(review): `cloudpickle.loads(request.data)` unpickles bytes received in
# the request; unpickling can execute arbitrary code, so confirm that only
# trusted clients can reach this servicer.
def PutObject(self, request, context=None):
data = cloudpickle.dumps(request.data)
objectref = ray.put(data)
self.realref[objectref.binary()] = objectref
print("put: %s" % objectref)
obj = cloudpickle.loads(request.data)
objectref = ray.put(obj)
self.object_refs[objectref.binary()] = objectref
logger.info("put: %s" % objectref)
return ray_client_pb2.PutResponse(id=objectref.binary())
# Run the function stored under `task.payload_id` with the converted task
# args and return the id of the resulting ref in a ClientTaskTicket.
# NOTE(review): `context` defaults to None, yet the first statements call
# `context.set_code(...)` -- those lines look like the older "unimplemented"
# variant of this method; confirm.
def Schedule(self, task, context=None):
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Unimplemented")
return ray_client_pb2.TaskTicket()
logger.info("schedule: %s" % task)
# First use of this payload id: fetch the stub, wrap its function with
# `ray.remote`, and cache the stub in `self.function_refs`.
if task.payload_id not in self.function_refs:
# Plain dict indexing: an unknown payload id raises KeyError here.
funcref = self.object_refs[task.payload_id]
func = ray.get(funcref)
if not isinstance(func, ClientRemoteFunc):
raise Exception("Attempting to schedule something that "
"isn't a ClientRemoteFunc")
ray_remote = ray.remote(func._func)
func.set_remote_func(ray_remote)
self.function_refs[task.payload_id] = func
remote_func = self.function_refs[task.payload_id]
arglist = _convert_args(task.args)
output = remote_func.remote(*arglist)
# Remember the output ref so a later GetObject can find it by id.
self.object_refs[output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _convert_args(arg_list):
out = []
for arg in arg_list:
t = convert_from_arg(arg)
if isinstance(t, ClientObjectRef):
out.append(ray.ObjectRef(t.id))
else:
out.append(t)
return out
def serve(connection_str):
@@ -45,8 +75,10 @@ def serve(connection_str):
if __name__ == "__main__":
logging.basicConfig()
logging.basicConfig(level="INFO")
# TODO(barakmich): Perhaps wrap ray init
ray.init()
set_client_api_as_ray()
server = serve("0.0.0.0:50051")
try:
while True:
+26
View File
@@ -0,0 +1,26 @@
from ray import cloudpickle
from ray.experimental.client.worker import ClientObjectRef
import ray
import ray.core.generated.ray_client_pb2 as ray_client_pb2
def dump_args_proto(arg):
    """Resolve a protobuf ``Arg`` to the Python value it stands for.

    Despite the name, this reads a proto: an ``INTERNED`` arg is unpickled
    from ``arg.data``; any other locality is treated as a reference and
    fetched with ``ray.get``.
    """
    if arg.local != ray_client_pb2.Arg.Locality.INTERNED:
        # TODO(barakmich): This is a dirty hack that assumes the
        # server maintains a reference to the ID we've been given
        ref = ray.ObjectRef(arg.reference_id)
        return ray.get(ref)
    return cloudpickle.loads(arg.data)
def load_args_proto(thing):
    """Build a protobuf ``Arg`` for ``thing``.

    Despite the name, this writes a proto: a ``ClientObjectRef`` is encoded
    as a reference to its ``id``; anything else is pickled into ``data``.
    """
    locality = ray_client_pb2.Arg.Locality
    proto = ray_client_pb2.Arg()
    if not isinstance(thing, ClientObjectRef):
        proto.local = locality.INTERNED
        proto.data = cloudpickle.dumps(thing)
        return proto
    proto.local = locality.REFERENCE
    proto.reference_id = thing.id
    return proto
+24 -33
View File
@@ -1,15 +1,10 @@
import cloudpickle
from ray import cloudpickle
import grpc
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
class ObjectID:
    """Thin wrapper around a raw object id."""

    def __init__(self, id):
        # `id` keeps its name because it is the constructor's interface.
        self.id = id

    def __repr__(self):
        """Render as ``ObjectID(<hex of id>)``."""
        return f"ObjectID({self.id.hex()})"
from ray.experimental.client.common import convert_to_arg
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientRemoteFunc
class Worker:
@@ -25,13 +20,12 @@ class Worker:
single = False
if isinstance(ids, list):
to_get = [x.id for x in ids]
elif isinstance(ids, ObjectID):
elif isinstance(ids, ClientObjectRef):
to_get = [ids.id]
single = True
else:
raise Exception(
"Can't get something that's not a list of IDs or just an ID")
raise Exception("Can't get something that's not a "
"list of IDs or just an ID: %s" % type(ids))
out = [self._get(x) for x in to_get]
if single:
out = out[0]
@@ -40,6 +34,9 @@ class Worker:
def _get(self, id: bytes):
    """Fetch and unpickle the object stored under ``id`` on the server.

    Raises:
        Exception: if the server's response is not marked valid.
    """
    response = self.server.GetObject(ray_client_pb2.GetRequest(id=id))
    if response.valid:
        return cloudpickle.loads(response.data)
    raise Exception(
        "Client GetObject returned invalid data: id invalid?")
def put(self, vals):
@@ -60,29 +57,23 @@ class Worker:
data = cloudpickle.dumps(val)
req = ray_client_pb2.PutRequest(data=data)
resp = self.server.PutObject(req)
return ObjectID(resp.id)
return ClientObjectRef(resp.id)
# Wrap `func` in a stub object and return it.
# NOTE(review): two return statements follow each other here; with normal
# indentation the second (`ClientRemoteFunc`) would be unreachable. This looks
# like an older and a newer variant merged together -- confirm which one is
# intended.
def remote(self, func):
return RemoteFunc(self, func)
return ClientRemoteFunc(func)
def schedule(self, task):
    """Send ``task`` to the server's ``Schedule`` call and return its reply."""
    reply = self.server.Schedule(task)
    return reply
def call_remote(self, func, *args, **kwargs):
    """Upload ``func`` and schedule one invocation of it on the server.

    The stub is stored with ``self._put`` and its id becomes the task's
    payload id; each positional argument is encoded with ``convert_to_arg``.
    Returns a ``ClientObjectRef`` for the ticket's ``return_id``.

    Raises:
        TypeError: if ``func`` is not a ``ClientRemoteFunc``.
    """
    # NOTE(review): `kwargs` is accepted but never read below, so keyword
    # arguments are not sent with the task -- confirm this is intended.
    if not isinstance(func, ClientRemoteFunc):
        raise TypeError("Client not passing a ClientRemoteFunc stub")
    payload_ref = self._put(func)
    task = ray_client_pb2.ClientTask()
    task.name = func._name
    task.payload_id = payload_ref.id
    for positional in args:
        task.args.append(convert_to_arg(positional))
    ticket = self.server.Schedule(task)
    return ClientObjectRef(ticket.return_id)
def close(self):
    """Close the channel held by this worker."""
    channel = self.channel
    channel.close()
class RemoteFunc:
    """Placeholder stub for a remote function.

    Records the wrapped function and its name; ``remote`` does nothing and
    direct calls always raise.
    """

    def __init__(self, worker, f):
        # `worker` is accepted but not stored.
        self._func = f
        self._name = f.__name__
        self.id = None

    def __call__(self, *args, **kwargs):
        """Always raise ``Exception``."""
        raise Exception("Matching the old API")

    def remote(self, *args):
        """Do nothing and return ``None``."""
        return None

    def __repr__(self):
        """Render as ``RemoteFunc(<name>, <id>)``."""
        return f"RemoteFunc({self._name}, {self.id})"
@@ -16,6 +16,39 @@ def test_put_get(ray_start_regular_shared):
server.stop(0)
def test_remote_functions(ray_start_regular_shared):
    """Remote functions, including self-recursive ones, return correct values.

    A server is started on ``localhost:50051`` for the duration of the test.
    Cleanup runs in ``finally`` so that a failing assertion does not leave
    the server running for whatever test runs next.
    """
    server = ray_client_server.serve("localhost:50051")
    # NOTE(review): this test calls `ray.disconnect()` but no
    # `ray.connect(...)` is visible in it -- confirm whether a connect to the
    # served address is intended here.
    try:
        @ray.remote
        def plus2(x):
            return x + 2

        @ray.remote
        def fact(x):
            print(x, type(fact))
            if x <= 0:
                return 1
            # This hits the "nested tasks" issue
            # https://github.com/ray-project/ray/issues/3644
            # So we're on the right track!
            return ray.get(fact.remote(x - 1)) * x

        ref2 = plus2.remote(234)
        # `236`
        assert ray.get(ref2) == 236

        ref3 = fact.remote(20)
        # `2432902008176640000`
        assert ray.get(ref3) == 2_432_902_008_176_640_000

        # Reuse the cached ClientRemoteFunc object
        ref4 = fact.remote(5)
        assert ray.get(ref4) == 120
    finally:
        # Same order as before (disconnect, then stop); the inner `finally`
        # guarantees the server is stopped even if disconnect itself fails.
        try:
            ray.disconnect()
        finally:
            server.stop(0)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))