mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 11:37:28 +08:00
[ray_client]: Implement function calls (#11922)
This commit is contained in:
@@ -1,44 +1,60 @@
|
||||
import ray
|
||||
from ray.experimental.client.worker import Worker
|
||||
from ray.experimental.client.api import ClientAPI
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from typing import Optional
|
||||
|
||||
_client_worker: Optional[Worker] = None
|
||||
_in_cluster: bool = True
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_client_api: Optional[APIImpl] = None
|
||||
|
||||
|
||||
def check_client_api():
    """Make sure a client API implementation is installed.

    If ``_client_api`` is still ``None``, log a message and install the
    core Ray implementation by calling ``set_client_api_as_ray()``.
    Otherwise do nothing.
    """
    global _client_api
    if _client_api is not None:
        return
    logger.info(
        "No client API initialized: probably a worker, using core ray")
    # Imported at the point of use so ``core_ray_api`` is only loaded
    # when the fallback is actually needed.
    from ray.experimental.client.core_ray_api import set_client_api_as_ray
    set_client_api_as_ray()
|
||||
|
||||
|
||||
def get(*args, **kwargs):
|
||||
global _client_worker
|
||||
global _in_cluster
|
||||
if _in_cluster:
|
||||
return ray.get(*args, **kwargs)
|
||||
else:
|
||||
return _client_worker.get(*args, **kwargs)
|
||||
global _client_api
|
||||
check_client_api()
|
||||
return _client_api.get(*args, **kwargs)
|
||||
|
||||
|
||||
def put(*args, **kwargs):
|
||||
global _client_worker
|
||||
global _in_cluster
|
||||
if _in_cluster:
|
||||
return ray.put(*args, **kwargs)
|
||||
else:
|
||||
return _client_worker.put(*args, **kwargs)
|
||||
global _client_api
|
||||
check_client_api()
|
||||
return _client_api.put(*args, **kwargs)
|
||||
|
||||
|
||||
def remote(*args, **kwargs):
|
||||
pass
|
||||
global _client_api
|
||||
check_client_api()
|
||||
return _client_api.remote(*args, **kwargs)
|
||||
|
||||
|
||||
def call_remote(f, *args, **kwargs):
    """Invoke ``f`` through the active client API.

    Runs ``check_client_api()`` first, then forwards ``f`` together with
    all positional and keyword arguments to ``_client_api.call_remote``
    and returns whatever that call returns.
    """
    global _client_api
    check_client_api()
    api = _client_api
    return api.call_remote(f, *args, **kwargs)
|
||||
|
||||
|
||||
def connect(conn_str):
|
||||
global _in_cluster
|
||||
global _client_worker
|
||||
_in_cluster = False
|
||||
global _client_api
|
||||
from ray.experimental.client.worker import Worker
|
||||
_client_worker = Worker(conn_str)
|
||||
_client_api = ClientAPI(_client_worker)
|
||||
|
||||
|
||||
def disconnect():
    """Close the active client API, if there is one, and clear it.

    Calling this when no API is installed (``_client_api`` is ``None``) is
    a no-op rather than an ``AttributeError`` on ``None``.
    """
    global _client_api
    if _client_api is None:
        return
    try:
        _client_api.close()
    finally:
        # Always drop the reference, even if ``close()`` raised, so a later
        # call does not try to reuse a half-closed API object.
        _client_api = None
|
||||
|
||||
|
||||
def _set_client_api(api: Optional[APIImpl]) -> None:
    """Install ``api`` as the module-wide ``_client_api``.

    Passing ``None`` clears the currently installed implementation.
    """
    global _client_api
    _client_api = api
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
# This file defines an interface and client-side API stub
|
||||
# for referring either to the core Ray API or the same interface
|
||||
# from the Ray client.
|
||||
#
|
||||
# In tandem with __init__.py, we want to expose an API that's
|
||||
# close to `python/ray/__init__.py` but with more than one implementation.
|
||||
# The stubs in __init__ should call into a well-defined interface.
|
||||
# Only the core Ray API implementation should actually `import ray`
|
||||
# (and thus import all the raylet worker C bindings and such).
|
||||
# But to make sure that we're matching these calls, we define this API.
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class APIImpl(ABC):
    """Abstract interface shared by all client API implementations.

    Every method takes arbitrary arguments; what happens to them is decided
    by the concrete subclass. A subclass must override all five methods
    before it can be instantiated.
    """

    @abstractmethod
    def get(self, *args, **kwargs):
        """Counterpart of ``ray.get`` for this implementation."""

    @abstractmethod
    def put(self, *args, **kwargs):
        """Counterpart of ``ray.put`` for this implementation."""

    @abstractmethod
    def remote(self, *args, **kwargs):
        """Counterpart of ``ray.remote`` for this implementation."""

    @abstractmethod
    def call_remote(self, f, *args, **kwargs):
        """Invoke the remote function ``f`` with the given arguments."""

    @abstractmethod
    def close(self, *args, **kwargs):
        """Release whatever resources the implementation holds."""
|
||||
|
||||
|
||||
class ClientAPI(APIImpl):
    """``APIImpl`` that delegates every call to a client ``worker``."""

    def __init__(self, worker):
        # Object whose get/put/remote/call_remote/close methods do the work.
        self.worker = worker

    def get(self, *args, **kwargs):
        """Forward to ``worker.get`` and return its result."""
        worker = self.worker
        return worker.get(*args, **kwargs)

    def put(self, *args, **kwargs):
        """Forward to ``worker.put`` and return its result."""
        worker = self.worker
        return worker.put(*args, **kwargs)

    def remote(self, *args, **kwargs):
        """Forward to ``worker.remote`` and return its result."""
        worker = self.worker
        return worker.remote(*args, **kwargs)

    def call_remote(self, f, *args, **kwargs):
        """Forward ``f`` and its arguments to ``worker.call_remote``."""
        worker = self.worker
        return worker.call_remote(f, *args, **kwargs)

    def close(self, *args, **kwargs):
        """Close the worker.

        Arguments are accepted to match the interface but are not passed
        on to ``worker.close()``.
        """
        worker = self.worker
        return worker.close()
|
||||
@@ -2,7 +2,44 @@ import ray.experimental.client as ray
|
||||
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
|
||||
@ray.remote
|
||||
def plus2(x):
|
||||
return x + 2
|
||||
|
||||
|
||||
@ray.remote
|
||||
def fact(x):
|
||||
print(x, type(fact))
|
||||
if x <= 0:
|
||||
return 1
|
||||
# This hits the "nested tasks" issue
|
||||
# https://github.com/ray-project/ray/issues/3644
|
||||
# So we're on the right track!
|
||||
return ray.get(fact.remote(x - 1)) * x
|
||||
|
||||
|
||||
objectref = ray.put("hello world")
|
||||
|
||||
# `ClientObjectRef(...)`
|
||||
print(objectref)
|
||||
|
||||
# `hello world`
|
||||
print(ray.get(objectref))
|
||||
|
||||
ref2 = plus2.remote(234)
|
||||
# `ClientObjectRef(...)`
|
||||
print(ref2)
|
||||
# `236`
|
||||
print(ray.get(ref2))
|
||||
|
||||
ref3 = fact.remote(20)
|
||||
# `ClientObjectRef(...)`
|
||||
print(ref3)
|
||||
# `2432902008176640000`
|
||||
print(ray.get(ref3))
|
||||
|
||||
# Reuse the cached ClientRemoteFunc object
|
||||
ref4 = fact.remote(5)
|
||||
# `120`
|
||||
print(ray.get(ref4))
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray.experimental.client import call_remote
|
||||
from typing import Any
|
||||
from ray import cloudpickle
|
||||
|
||||
|
||||
class ClientObjectRef:
    """Client-side handle that wraps an object id."""

    def __init__(self, id):
        # ``id`` must provide ``.hex()``; ``__repr__`` relies on it.
        self.id = id

    def __repr__(self):
        hex_id = self.id.hex()
        return f"ClientObjectRef({hex_id})"
|
||||
|
||||
|
||||
class ClientRemoteFunc:
    """Stub standing in for a function that is meant to be run remotely.

    The wrapped function is only ever invoked through ``remote()``;
    calling the stub directly raises ``TypeError``.
    """

    def __init__(self, f):
        self._func = f
        self._name = f.__name__
        self.id = None
        # Set through ``set_remote_func``; when present, ``remote()`` uses
        # it directly instead of going through ``call_remote``.
        self._raylet_remote_func = None

    def __call__(self, *args, **kwargs):
        """Always raise: remote functions cannot be called directly.

        Raises:
            TypeError: on every call, naming the ``.remote`` alternative.
        """
        # Both literals need the ``f`` prefix; otherwise ``{self._name}``
        # would be emitted verbatim instead of the function name.
        raise TypeError(f"Remote function cannot be called directly. "
                        f"Use {self._name}.remote method instead")

    def remote(self, *args, **kwargs):
        """Run the function remotely and return the result of that call.

        Uses the attached remote function when one has been set via
        ``set_remote_func``; otherwise dispatches through ``call_remote``.
        """
        if self._raylet_remote_func is not None:
            return self._raylet_remote_func.remote(*args, **kwargs)
        return call_remote(self, *args, **kwargs)

    def __repr__(self):
        return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)

    def set_remote_func(self, func):
        """Attach ``func`` as the object whose ``.remote`` is called."""
        self._raylet_remote_func = func
|
||||
|
||||
|
||||
def convert_from_arg(pb) -> Any:
    """Decode a ``ray_client_pb2.Arg`` message into a Python value.

    A ``REFERENCE`` arg becomes a ``ClientObjectRef`` wrapping
    ``pb.reference_id``; an ``INTERNED`` arg is unpickled from ``pb.data``.

    Raises:
        Exception: if ``pb.local`` is neither of those two localities.
    """
    locality = pb.local
    if locality == ray_client_pb2.Arg.Locality.REFERENCE:
        return ClientObjectRef(pb.reference_id)
    if locality == ray_client_pb2.Arg.Locality.INTERNED:
        # NOTE(review): this unpickles bytes taken from the message; that
        # is only safe when the sender is trusted.
        return cloudpickle.loads(pb.data)
    raise Exception("convert_from_arg: Uncovered locality enum")
|
||||
|
||||
|
||||
def convert_to_arg(val):
    """Encode ``val`` as a ``ray_client_pb2.Arg`` message.

    A ``ClientObjectRef`` is sent by reference (its ``id``); any other
    value is pickled with ``cloudpickle`` and sent inline.
    """
    arg = ray_client_pb2.Arg()
    if not isinstance(val, ClientObjectRef):
        arg.local = ray_client_pb2.Arg.Locality.INTERNED
        arg.data = cloudpickle.dumps(val)
        return arg
    arg.local = ray_client_pb2.Arg.Locality.REFERENCE
    arg.reference_id = val.id
    return arg
|
||||
@@ -0,0 +1,34 @@
|
||||
# Along with `api.py` this is the stub that interfaces with
|
||||
# the real (C-binding, raylet) ray core.
|
||||
#
|
||||
# Ideally, the first import line is the only time we actually
|
||||
# import ray in this library (excluding the main function for the server)
|
||||
#
|
||||
# While the stub is trivial, it allows us to check that the calls we're
|
||||
# making into the core-ray module are contained and well-defined.
|
||||
|
||||
import ray
|
||||
|
||||
from ray.experimental.client.api import APIImpl
|
||||
|
||||
|
||||
class CoreRayAPI(APIImpl):
    """``APIImpl`` backed directly by the core ``ray`` module."""

    def get(self, *args, **kwargs):
        """Return ``ray.get`` called with the given arguments."""
        result = ray.get(*args, **kwargs)
        return result

    def put(self, *args, **kwargs):
        """Return ``ray.put`` called with the given arguments."""
        result = ray.put(*args, **kwargs)
        return result

    def remote(self, *args, **kwargs):
        """Return ``ray.remote`` called with the given arguments."""
        result = ray.remote(*args, **kwargs)
        return result

    def call_remote(self, f, *args, **kwargs):
        """Invoke ``f.remote`` with the given arguments."""
        result = f.remote(*args, **kwargs)
        return result

    def close(self, *args, **kwargs):
        """Do nothing; this implementation has nothing to close."""
        return None
|
||||
|
||||
|
||||
def set_client_api_as_ray():
    """Install a new ``CoreRayAPI`` instance as the client API.

    The instance is handed to ``ray.experimental.client._set_client_api``.
    """
    ray.experimental.client._set_client_api(CoreRayAPI())
|
||||
@@ -6,32 +6,62 @@ import ray
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
import time
|
||||
from ray.experimental.client.core_ray_api import set_client_api_as_ray
|
||||
from ray.experimental.client.common import convert_from_arg
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
def __init__(self):
|
||||
self.realref = {}
|
||||
self.object_refs = {}
|
||||
self.function_refs = {}
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
objectref = self.realref[request.id]
|
||||
print("get: %s" % objectref)
|
||||
item = ray.get(objectref)
|
||||
if item is None:
|
||||
if request.id not in self.object_refs:
|
||||
return ray_client_pb2.GetResponse(valid=False)
|
||||
data = cloudpickle.loads(item)
|
||||
return ray_client_pb2.GetResponse(valid=True, data=data)
|
||||
objectref = self.object_refs[request.id]
|
||||
logger.info("get: %s" % objectref)
|
||||
item = ray.get(objectref)
|
||||
item_ser = cloudpickle.dumps(item)
|
||||
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
|
||||
|
||||
def PutObject(self, request, context=None):
|
||||
data = cloudpickle.dumps(request.data)
|
||||
objectref = ray.put(data)
|
||||
self.realref[objectref.binary()] = objectref
|
||||
print("put: %s" % objectref)
|
||||
obj = cloudpickle.loads(request.data)
|
||||
objectref = ray.put(obj)
|
||||
self.object_refs[objectref.binary()] = objectref
|
||||
logger.info("put: %s" % objectref)
|
||||
return ray_client_pb2.PutResponse(id=objectref.binary())
|
||||
|
||||
def Schedule(self, task, context=None):
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details("Unimplemented")
|
||||
return ray_client_pb2.TaskTicket()
|
||||
logger.info("schedule: %s" % task)
|
||||
if task.payload_id not in self.function_refs:
|
||||
funcref = self.object_refs[task.payload_id]
|
||||
func = ray.get(funcref)
|
||||
if not isinstance(func, ClientRemoteFunc):
|
||||
raise Exception("Attempting to schedule something that "
|
||||
"isn't a ClientRemoteFunc")
|
||||
ray_remote = ray.remote(func._func)
|
||||
func.set_remote_func(ray_remote)
|
||||
self.function_refs[task.payload_id] = func
|
||||
remote_func = self.function_refs[task.payload_id]
|
||||
arglist = _convert_args(task.args)
|
||||
output = remote_func.remote(*arglist)
|
||||
self.object_refs[output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
|
||||
|
||||
def _convert_args(arg_list):
|
||||
out = []
|
||||
for arg in arg_list:
|
||||
t = convert_from_arg(arg)
|
||||
if isinstance(t, ClientObjectRef):
|
||||
out.append(ray.ObjectRef(t.id))
|
||||
else:
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
|
||||
def serve(connection_str):
|
||||
@@ -45,8 +75,10 @@ def serve(connection_str):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig()
|
||||
logging.basicConfig(level="INFO")
|
||||
# TODO(barakmich): Perhaps wrap ray init
|
||||
ray.init()
|
||||
set_client_api_as_ray()
|
||||
server = serve("0.0.0.0:50051")
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
from ray import cloudpickle
|
||||
from ray.experimental.client.worker import ClientObjectRef
|
||||
|
||||
import ray
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
|
||||
|
||||
def dump_args_proto(arg):
    """Resolve a ``ray_client_pb2.Arg`` message into its Python value.

    An ``INTERNED`` arg is unpickled from ``arg.data``; any other arg is
    treated as a reference and fetched with ``ray.get``.

    NOTE(review): despite the "dump" in its name, this function decodes.
    """
    if arg.local != ray_client_pb2.Arg.Locality.INTERNED:
        # TODO(barakmich): This is a dirty hack that assumes the
        # server maintains a reference to the ID we've been given
        return ray.get(ray.ObjectRef(arg.reference_id))
    return cloudpickle.loads(arg.data)
|
||||
|
||||
|
||||
def load_args_proto(thing):
    """Build a ``ray_client_pb2.Arg`` message describing ``thing``.

    A ``ClientObjectRef`` is encoded as a ``REFERENCE`` carrying its id;
    anything else is pickled with ``cloudpickle`` as ``INTERNED`` data.

    NOTE(review): despite the "load" in its name, this function encodes.
    """
    message = ray_client_pb2.Arg()
    by_reference = isinstance(thing, ClientObjectRef)
    if by_reference:
        message.local = ray_client_pb2.Arg.Locality.REFERENCE
        message.reference_id = thing.id
        return message
    message.local = ray_client_pb2.Arg.Locality.INTERNED
    message.data = cloudpickle.dumps(thing)
    return message
|
||||
@@ -1,15 +1,10 @@
|
||||
import cloudpickle
|
||||
from ray import cloudpickle
|
||||
import grpc
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
|
||||
class ObjectID:
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
|
||||
def __repr__(self):
|
||||
return "ObjectID(%s)" % self.id.hex()
|
||||
from ray.experimental.client.common import convert_to_arg
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
|
||||
|
||||
class Worker:
|
||||
@@ -25,13 +20,12 @@ class Worker:
|
||||
single = False
|
||||
if isinstance(ids, list):
|
||||
to_get = [x.id for x in ids]
|
||||
elif isinstance(ids, ObjectID):
|
||||
elif isinstance(ids, ClientObjectRef):
|
||||
to_get = [ids.id]
|
||||
single = True
|
||||
else:
|
||||
raise Exception(
|
||||
"Can't get something that's not a list of IDs or just an ID")
|
||||
|
||||
raise Exception("Can't get something that's not a "
|
||||
"list of IDs or just an ID: %s" % type(ids))
|
||||
out = [self._get(x) for x in to_get]
|
||||
if single:
|
||||
out = out[0]
|
||||
@@ -40,6 +34,9 @@ class Worker:
|
||||
def _get(self, id: bytes):
|
||||
req = ray_client_pb2.GetRequest(id=id)
|
||||
data = self.server.GetObject(req)
|
||||
if not data.valid:
|
||||
raise Exception(
|
||||
"Client GetObject returned invalid data: id invalid?")
|
||||
return cloudpickle.loads(data.data)
|
||||
|
||||
def put(self, vals):
|
||||
@@ -60,29 +57,23 @@ class Worker:
|
||||
data = cloudpickle.dumps(val)
|
||||
req = ray_client_pb2.PutRequest(data=data)
|
||||
resp = self.server.PutObject(req)
|
||||
return ObjectID(resp.id)
|
||||
return ClientObjectRef(resp.id)
|
||||
|
||||
def remote(self, func):
|
||||
return RemoteFunc(self, func)
|
||||
return ClientRemoteFunc(func)
|
||||
|
||||
def schedule(self, task):
|
||||
return self.server.Schedule(task)
|
||||
def call_remote(self, func, *args, **kwargs):
|
||||
if not isinstance(func, ClientRemoteFunc):
|
||||
raise TypeError("Client not passing a ClientRemoteFunc stub")
|
||||
func_ref = self._put(func)
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.name = func._name
|
||||
task.payload_id = func_ref.id
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg)
|
||||
task.args.append(pb_arg)
|
||||
ticket = self.server.Schedule(task)
|
||||
return ClientObjectRef(ticket.return_id)
|
||||
|
||||
def close(self):
|
||||
self.channel.close()
|
||||
|
||||
|
||||
class RemoteFunc:
|
||||
def __init__(self, worker, f):
|
||||
self._func = f
|
||||
self._name = f.__name__
|
||||
self.id = None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("Matching the old API")
|
||||
|
||||
def remote(self, *args):
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "RemoteFunc(%s, %s)" % (self._name, self.id)
|
||||
|
||||
@@ -16,6 +16,39 @@ def test_put_get(ray_start_regular_shared):
|
||||
server.stop(0)
|
||||
|
||||
|
||||
def test_remote_functions(ray_start_regular_shared):
    """Check remote function calls, including a recursive one.

    Starts a client server on ``localhost:50051`` and always tears it down
    again, so a failing assertion does not leave the port bound for the
    tests that follow.
    """
    server = ray_client_server.serve("localhost:50051")
    try:

        @ray.remote
        def plus2(x):
            return x + 2

        @ray.remote
        def fact(x):
            print(x, type(fact))
            if x <= 0:
                return 1
            # This hits the "nested tasks" issue
            # https://github.com/ray-project/ray/issues/3644
            # So we're on the right track!
            return ray.get(fact.remote(x - 1)) * x

        ref2 = plus2.remote(234)
        # `236`
        assert ray.get(ref2) == 236

        ref3 = fact.remote(20)
        # `2432902008176640000`
        assert ray.get(ref3) == 2_432_902_008_176_640_000

        # Reuse the cached ClientRemoteFunc object
        ref4 = fact.remote(5)
        assert ray.get(ref4) == 120
    finally:
        # Runs on success and on failure alike.
        ray.disconnect()
        server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
Reference in New Issue
Block a user