mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 07:27:35 +08:00
Basic protos for ray client (#11762)
This commit is contained in:
@@ -0,0 +1,44 @@
|
||||
import ray
|
||||
from ray.experimental.client.worker import Worker
|
||||
from typing import Optional
|
||||
|
||||
_client_worker: Optional[Worker] = None
|
||||
_in_cluster: bool = True
|
||||
|
||||
|
||||
def get(*args, **kwargs):
|
||||
global _client_worker
|
||||
global _in_cluster
|
||||
if _in_cluster:
|
||||
return ray.get(*args, **kwargs)
|
||||
else:
|
||||
return _client_worker.get(*args, **kwargs)
|
||||
|
||||
|
||||
def put(*args, **kwargs):
|
||||
global _client_worker
|
||||
global _in_cluster
|
||||
if _in_cluster:
|
||||
return ray.put(*args, **kwargs)
|
||||
else:
|
||||
return _client_worker.put(*args, **kwargs)
|
||||
|
||||
|
||||
def remote(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def connect(conn_str):
|
||||
global _in_cluster
|
||||
global _client_worker
|
||||
_in_cluster = False
|
||||
_client_worker = Worker(conn_str)
|
||||
|
||||
|
||||
def disconnect():
|
||||
global _in_cluster
|
||||
global _client_worker
|
||||
if _client_worker is not None:
|
||||
_client_worker.close()
|
||||
_in_cluster = True
|
||||
_client_worker = None
|
||||
@@ -0,0 +1,8 @@
|
||||
import ray.experimental.client as ray
|
||||
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
objectref = ray.put("hello world")
|
||||
print(objectref)
|
||||
|
||||
print(ray.get(objectref))
|
||||
@@ -0,0 +1,55 @@
|
||||
import logging
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
from ray import cloudpickle
|
||||
import ray
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
import time
|
||||
|
||||
|
||||
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
def __init__(self):
|
||||
self.realref = {}
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
objectref = self.realref[request.id]
|
||||
print("get: %s" % objectref)
|
||||
item = ray.get(objectref)
|
||||
if item is None:
|
||||
return ray_client_pb2.GetResponse(valid=False)
|
||||
data = cloudpickle.loads(item)
|
||||
return ray_client_pb2.GetResponse(valid=True, data=data)
|
||||
|
||||
def PutObject(self, request, context=None):
|
||||
data = cloudpickle.dumps(request.data)
|
||||
objectref = ray.put(data)
|
||||
self.realref[objectref.binary()] = objectref
|
||||
print("put: %s" % objectref)
|
||||
return ray_client_pb2.PutResponse(id=objectref.binary())
|
||||
|
||||
def Schedule(self, task, context=None):
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details("Unimplemented")
|
||||
return ray_client_pb2.TaskTicket()
|
||||
|
||||
|
||||
def serve(connection_str):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
task_servicer = RayletServicer()
|
||||
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
||||
task_servicer, server)
|
||||
server.add_insecure_port(connection_str)
|
||||
server.start()
|
||||
return server
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig()
|
||||
ray.init()
|
||||
server = serve("0.0.0.0:50051")
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1000)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
@@ -0,0 +1,88 @@
|
||||
import cloudpickle
|
||||
import grpc
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
|
||||
class ObjectID:
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
|
||||
def __repr__(self):
|
||||
return "ObjectID(%s)" % self.id.hex()
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(self, conn_str="", stub=None):
|
||||
if stub is None:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
else:
|
||||
self.server = stub
|
||||
|
||||
def get(self, ids):
|
||||
to_get = []
|
||||
single = False
|
||||
if isinstance(ids, list):
|
||||
to_get = [x.id for x in ids]
|
||||
elif isinstance(ids, ObjectID):
|
||||
to_get = [ids.id]
|
||||
single = True
|
||||
else:
|
||||
raise Exception(
|
||||
"Can't get something that's not a list of IDs or just an ID")
|
||||
|
||||
out = [self._get(x) for x in to_get]
|
||||
if single:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _get(self, id: bytes):
|
||||
req = ray_client_pb2.GetRequest(id=id)
|
||||
data = self.server.GetObject(req)
|
||||
return cloudpickle.loads(data.data)
|
||||
|
||||
def put(self, vals):
|
||||
to_put = []
|
||||
single = False
|
||||
if isinstance(vals, list):
|
||||
to_put = vals
|
||||
else:
|
||||
single = True
|
||||
to_put.append(vals)
|
||||
|
||||
out = [self._put(x) for x in to_put]
|
||||
if single:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _put(self, val):
|
||||
data = cloudpickle.dumps(val)
|
||||
req = ray_client_pb2.PutRequest(data=data)
|
||||
resp = self.server.PutObject(req)
|
||||
return ObjectID(resp.id)
|
||||
|
||||
def remote(self, func):
|
||||
return RemoteFunc(self, func)
|
||||
|
||||
def schedule(self, task):
|
||||
return self.server.Schedule(task)
|
||||
|
||||
def close(self):
|
||||
self.channel.close()
|
||||
|
||||
|
||||
class RemoteFunc:
|
||||
def __init__(self, worker, f):
|
||||
self._func = f
|
||||
self._name = f.__name__
|
||||
self.id = None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("Matching the old API")
|
||||
|
||||
def remote(self, *args):
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "RemoteFunc(%s, %s)" % (self._name, self.id)
|
||||
@@ -102,6 +102,7 @@ py_test_module_list(
|
||||
"test_dask_scheduler.py",
|
||||
"test_dask_callback.py",
|
||||
"test_debug_tools.py",
|
||||
"test_experimental_client.py",
|
||||
"test_job.py",
|
||||
"test_memstat.py",
|
||||
"test_metrics_agent.py",
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import pytest
|
||||
import ray.experimental.client.server as ray_client_server
|
||||
import ray.experimental.client as ray
|
||||
|
||||
|
||||
def test_put_get(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
objectref = ray.put("hello world")
|
||||
print(objectref)
|
||||
|
||||
retval = ray.get(objectref)
|
||||
assert retval == "hello world"
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
Reference in New Issue
Block a user