Basic protos for ray client (#11762)

This commit is contained in:
Barak Michener
2020-11-05 16:23:54 -08:00
committed by GitHub
parent f86c4f992c
commit 27c810a97e
9 changed files with 297 additions and 0 deletions
@@ -0,0 +1,44 @@
import ray
from ray.experimental.client.worker import Worker
from typing import Optional
_client_worker: Optional[Worker] = None
_in_cluster: bool = True
def get(*args, **kwargs):
global _client_worker
global _in_cluster
if _in_cluster:
return ray.get(*args, **kwargs)
else:
return _client_worker.get(*args, **kwargs)
def put(*args, **kwargs):
global _client_worker
global _in_cluster
if _in_cluster:
return ray.put(*args, **kwargs)
else:
return _client_worker.put(*args, **kwargs)
def remote(*args, **kwargs):
pass
def connect(conn_str):
global _in_cluster
global _client_worker
_in_cluster = False
_client_worker = Worker(conn_str)
def disconnect():
global _in_cluster
global _client_worker
if _client_worker is not None:
_client_worker.close()
_in_cluster = True
_client_worker = None
@@ -0,0 +1,8 @@
import ray.experimental.client as ray
ray.connect("localhost:50051")
objectref = ray.put("hello world")
print(objectref)
print(ray.get(objectref))
+55
View File
@@ -0,0 +1,55 @@
import logging
from concurrent import futures
import grpc
from ray import cloudpickle
import ray
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import time
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self):
self.realref = {}
def GetObject(self, request, context=None):
objectref = self.realref[request.id]
print("get: %s" % objectref)
item = ray.get(objectref)
if item is None:
return ray_client_pb2.GetResponse(valid=False)
data = cloudpickle.loads(item)
return ray_client_pb2.GetResponse(valid=True, data=data)
def PutObject(self, request, context=None):
data = cloudpickle.dumps(request.data)
objectref = ray.put(data)
self.realref[objectref.binary()] = objectref
print("put: %s" % objectref)
return ray_client_pb2.PutResponse(id=objectref.binary())
def Schedule(self, task, context=None):
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Unimplemented")
return ray_client_pb2.TaskTicket()
def serve(connection_str):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer()
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
server.add_insecure_port(connection_str)
server.start()
return server
if __name__ == "__main__":
logging.basicConfig()
ray.init()
server = serve("0.0.0.0:50051")
try:
while True:
time.sleep(1000)
except KeyboardInterrupt:
server.stop(0)
+88
View File
@@ -0,0 +1,88 @@
import cloudpickle
import grpc
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
class ObjectID:
def __init__(self, id):
self.id = id
def __repr__(self):
return "ObjectID(%s)" % self.id.hex()
class Worker:
def __init__(self, conn_str="", stub=None):
if stub is None:
self.channel = grpc.insecure_channel(conn_str)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
else:
self.server = stub
def get(self, ids):
to_get = []
single = False
if isinstance(ids, list):
to_get = [x.id for x in ids]
elif isinstance(ids, ObjectID):
to_get = [ids.id]
single = True
else:
raise Exception(
"Can't get something that's not a list of IDs or just an ID")
out = [self._get(x) for x in to_get]
if single:
out = out[0]
return out
def _get(self, id: bytes):
req = ray_client_pb2.GetRequest(id=id)
data = self.server.GetObject(req)
return cloudpickle.loads(data.data)
def put(self, vals):
to_put = []
single = False
if isinstance(vals, list):
to_put = vals
else:
single = True
to_put.append(vals)
out = [self._put(x) for x in to_put]
if single:
out = out[0]
return out
def _put(self, val):
data = cloudpickle.dumps(val)
req = ray_client_pb2.PutRequest(data=data)
resp = self.server.PutObject(req)
return ObjectID(resp.id)
def remote(self, func):
return RemoteFunc(self, func)
def schedule(self, task):
return self.server.Schedule(task)
def close(self):
self.channel.close()
class RemoteFunc:
def __init__(self, worker, f):
self._func = f
self._name = f.__name__
self.id = None
def __call__(self, *args, **kwargs):
raise Exception("Matching the old API")
def remote(self, *args):
pass
def __repr__(self):
return "RemoteFunc(%s, %s)" % (self._name, self.id)
+1
View File
@@ -102,6 +102,7 @@ py_test_module_list(
"test_dask_scheduler.py",
"test_dask_callback.py",
"test_debug_tools.py",
"test_experimental_client.py",
"test_job.py",
"test_memstat.py",
"test_metrics_agent.py",
@@ -0,0 +1,21 @@
import pytest
import ray.experimental.client.server as ray_client_server
import ray.experimental.client as ray
def test_put_get(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051")
ray.connect("localhost:50051")
objectref = ray.put("hello world")
print(objectref)
retval = ray.get(objectref)
assert retval == "hello world"
ray.disconnect()
server.stop(0)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))