From 27c810a97e182fd53d3c6a72f7840406b571aa79 Mon Sep 17 00:00:00 2001
From: Barak Michener
Date: Thu, 5 Nov 2020 16:23:54 -0800
Subject: [PATCH] Basic protos for ray client (#11762)

---
Review notes. Text between the three-dash line and the first file diff is
commentary only; no hunk below has been altered.

* Layout: this copy of the patch had its line breaks collapsed. Line
  structure was restored so that every hunk matches its own @@ counts and
  the diffstat (297 insertions). Indentation widths are reconstructed, and
  the placement of the last two statements of disconnect() outside the
  "if" is assumed -- TODO confirm against the upstream commit.
* server.py, Schedule(): returns ray_client_pb2.TaskTicket(), but
  ray_client.proto in this same patch declares the message as
  ClientTaskTicket and defines no TaskTicket. Presumably this fails with an
  AttributeError once called -- verify against the generated module.
* server.py, Schedule(): "context" defaults to None yet is dereferenced
  unconditionally (context.set_code / context.set_details).
* server.py, GetObject(): self.realref[request.id] is a plain dict lookup,
  so an unknown id raises KeyError; valid=False is only returned when the
  stored value is None.
* server.py, PutObject(): every put adds an entry to self.realref and
  nothing in this patch removes entries.
* Serialization: Worker._put pickles the value, PutObject pickles
  request.data again before ray.put, GetObject unpickles once and
  Worker._get unpickles once more. The round trip is symmetric, but each
  payload is serialized twice.
* worker.py, _get(): the "valid" field of GetResponse is never checked
  before data.data is unpickled.
* worker.py, close(): self.channel is only assigned when stub is None, so
  close() on a Worker built with an injected stub would hit a missing
  attribute.
* worker.py imports the top-level "cloudpickle" package while server.py
  uses "from ray import cloudpickle" -- presumably these should match;
  verify which one is intended.
* __init__.py remote() is a stub that returns None; RemoteFunc.__init__
  ignores its "worker" argument and RemoteFunc.remote() is "pass".
* test_experimental_client.py and client_app.py hard-code port 50051, and
  the test has no try/finally, so a failing assertion skips
  ray.disconnect() and server.stop(0).

 BUILD.bazel                                  |  1 +
 python/ray/experimental/client/__init__.py   | 44 ++++++++++
 python/ray/experimental/client/client_app.py |  8 ++
 python/ray/experimental/client/server.py     | 55 ++++++++++++
 python/ray/experimental/client/worker.py     | 88 ++++++++++++++++++++
 python/ray/tests/BUILD                       |  1 +
 python/ray/tests/test_experimental_client.py | 21 +++++
 src/ray/protobuf/BUILD                       | 12 +++
 src/ray/protobuf/ray_client.proto            | 67 +++++++++++++++
 9 files changed, 297 insertions(+)
 create mode 100644 python/ray/experimental/client/__init__.py
 create mode 100644 python/ray/experimental/client/client_app.py
 create mode 100644 python/ray/experimental/client/server.py
 create mode 100644 python/ray/experimental/client/worker.py
 create mode 100644 python/ray/tests/test_experimental_client.py
 create mode 100644 src/ray/protobuf/ray_client.proto

diff --git a/BUILD.bazel b/BUILD.bazel
index 41ab0a35c..6a224cafe 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1861,6 +1861,7 @@ filegroup(
         "//src/ray/protobuf:gcs_py_proto",
         "//src/ray/protobuf:gcs_service_py_proto",
         "//src/ray/protobuf:node_manager_py_proto",
+        "//src/ray/protobuf:ray_client_py_proto",
         "//src/ray/protobuf:reporter_py_proto",
     ],
 )
diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py
new file mode 100644
index 000000000..12b7322a6
--- /dev/null
+++ b/python/ray/experimental/client/__init__.py
@@ -0,0 +1,44 @@
+import ray
+from ray.experimental.client.worker import Worker
+from typing import Optional
+
+_client_worker: Optional[Worker] = None
+_in_cluster: bool = True
+
+
+def get(*args, **kwargs):
+    global _client_worker
+    global _in_cluster
+    if _in_cluster:
+        return ray.get(*args, **kwargs)
+    else:
+        return _client_worker.get(*args, **kwargs)
+
+
+def put(*args, **kwargs):
+    global _client_worker
+    global _in_cluster
+    if _in_cluster:
+        return ray.put(*args, **kwargs)
+    else:
+        return _client_worker.put(*args, **kwargs)
+
+
+def remote(*args, **kwargs):
+    pass
+
+
+def connect(conn_str):
+    global _in_cluster
+    global _client_worker
+    _in_cluster = False
+    _client_worker = Worker(conn_str)
+
+
+def disconnect():
+    global _in_cluster
+    global _client_worker
+    if _client_worker is not None:
+        _client_worker.close()
+    _in_cluster = True
+    _client_worker = None
diff --git a/python/ray/experimental/client/client_app.py b/python/ray/experimental/client/client_app.py
new file mode 100644
index 000000000..f78317493
--- /dev/null
+++ b/python/ray/experimental/client/client_app.py
@@ -0,0 +1,8 @@
+import ray.experimental.client as ray
+
+ray.connect("localhost:50051")
+
+objectref = ray.put("hello world")
+print(objectref)
+
+print(ray.get(objectref))
diff --git a/python/ray/experimental/client/server.py b/python/ray/experimental/client/server.py
new file mode 100644
index 000000000..804fc834d
--- /dev/null
+++ b/python/ray/experimental/client/server.py
@@ -0,0 +1,55 @@
+import logging
+from concurrent import futures
+import grpc
+from ray import cloudpickle
+import ray
+import ray.core.generated.ray_client_pb2 as ray_client_pb2
+import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
+import time
+
+
+class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
+    def __init__(self):
+        self.realref = {}
+
+    def GetObject(self, request, context=None):
+        objectref = self.realref[request.id]
+        print("get: %s" % objectref)
+        item = ray.get(objectref)
+        if item is None:
+            return ray_client_pb2.GetResponse(valid=False)
+        data = cloudpickle.loads(item)
+        return ray_client_pb2.GetResponse(valid=True, data=data)
+
+    def PutObject(self, request, context=None):
+        data = cloudpickle.dumps(request.data)
+        objectref = ray.put(data)
+        self.realref[objectref.binary()] = objectref
+        print("put: %s" % objectref)
+        return ray_client_pb2.PutResponse(id=objectref.binary())
+
+    def Schedule(self, task, context=None):
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details("Unimplemented")
+        return ray_client_pb2.TaskTicket()
+
+
+def serve(connection_str):
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    task_servicer = RayletServicer()
+    ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
+        task_servicer, server)
+    server.add_insecure_port(connection_str)
+    server.start()
+    return server
+
+
+if __name__ == "__main__":
+    logging.basicConfig()
+    ray.init()
+    server = serve("0.0.0.0:50051")
+    try:
+        while True:
+            time.sleep(1000)
+    except KeyboardInterrupt:
+        server.stop(0)
diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py
new file mode 100644
index 000000000..c69fd4185
--- /dev/null
+++ b/python/ray/experimental/client/worker.py
@@ -0,0 +1,88 @@
+import cloudpickle
+import grpc
+import ray.core.generated.ray_client_pb2 as ray_client_pb2
+import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
+
+
+class ObjectID:
+    def __init__(self, id):
+        self.id = id
+
+    def __repr__(self):
+        return "ObjectID(%s)" % self.id.hex()
+
+
+class Worker:
+    def __init__(self, conn_str="", stub=None):
+        if stub is None:
+            self.channel = grpc.insecure_channel(conn_str)
+            self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
+        else:
+            self.server = stub
+
+    def get(self, ids):
+        to_get = []
+        single = False
+        if isinstance(ids, list):
+            to_get = [x.id for x in ids]
+        elif isinstance(ids, ObjectID):
+            to_get = [ids.id]
+            single = True
+        else:
+            raise Exception(
+                "Can't get something that's not a list of IDs or just an ID")
+
+        out = [self._get(x) for x in to_get]
+        if single:
+            out = out[0]
+        return out
+
+    def _get(self, id: bytes):
+        req = ray_client_pb2.GetRequest(id=id)
+        data = self.server.GetObject(req)
+        return cloudpickle.loads(data.data)
+
+    def put(self, vals):
+        to_put = []
+        single = False
+        if isinstance(vals, list):
+            to_put = vals
+        else:
+            single = True
+            to_put.append(vals)
+
+        out = [self._put(x) for x in to_put]
+        if single:
+            out = out[0]
+        return out
+
+    def _put(self, val):
+        data = cloudpickle.dumps(val)
+        req = ray_client_pb2.PutRequest(data=data)
+        resp = self.server.PutObject(req)
+        return ObjectID(resp.id)
+
+    def remote(self, func):
+        return RemoteFunc(self, func)
+
+    def schedule(self, task):
+        return self.server.Schedule(task)
+
+    def close(self):
+        self.channel.close()
+
+
+class RemoteFunc:
+    def __init__(self, worker, f):
+        self._func = f
+        self._name = f.__name__
+        self.id = None
+
+    def __call__(self, *args, **kwargs):
+        raise Exception("Matching the old API")
+
+    def remote(self, *args):
+        pass
+
+    def __repr__(self):
+        return "RemoteFunc(%s, %s)" % (self._name, self.id)
diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD
index 3f64254cc..94630021f 100644
--- a/python/ray/tests/BUILD
+++ b/python/ray/tests/BUILD
@@ -102,6 +102,7 @@ py_test_module_list(
     "test_dask_scheduler.py",
     "test_dask_callback.py",
     "test_debug_tools.py",
+    "test_experimental_client.py",
     "test_job.py",
     "test_memstat.py",
     "test_metrics_agent.py",
diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py
new file mode 100644
index 000000000..7877b57e1
--- /dev/null
+++ b/python/ray/tests/test_experimental_client.py
@@ -0,0 +1,21 @@
+import pytest
+import ray.experimental.client.server as ray_client_server
+import ray.experimental.client as ray
+
+
+def test_put_get(ray_start_regular_shared):
+    server = ray_client_server.serve("localhost:50051")
+    ray.connect("localhost:50051")
+
+    objectref = ray.put("hello world")
+    print(objectref)
+
+    retval = ray.get(objectref)
+    assert retval == "hello world"
+    ray.disconnect()
+    server.stop(0)
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/src/ray/protobuf/BUILD b/src/ray/protobuf/BUILD
index 9efff0f4c..79a4ef15e 100644
--- a/src/ray/protobuf/BUILD
+++ b/src/ray/protobuf/BUILD
@@ -155,3 +155,15 @@ cc_proto_library(
     deps = [":agent_manager_proto"],
 )
 
+# Ray Client gRPC lib
+proto_library(
+    name = "ray_client_proto",
+    srcs = ["ray_client.proto"],
+    deps = [],
+)
+
+python_grpc_compile(
+    name = "ray_client_py_proto",
+    deps = [":ray_client_proto"]
+)
+
diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto
new file mode 100644
index 000000000..dc04a5927
--- /dev/null
+++ b/src/ray/protobuf/ray_client.proto
@@ -0,0 +1,67 @@
+// Copyright 2020 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package ray.rpc;
+
+enum Type { DEFAULT = 0; }
+
+message Arg {
+  enum Locality {
+    INTERNED = 0;
+    REFERENCE = 1;
+  }
+  Locality local = 1;
+  bytes reference_id = 2;
+  bytes data = 3;
+  Type type = 4;
+}
+
+message ClientTask {
+  // Optionally Provided Task Name
+  string name = 1;
+  bytes payload_id = 2;
+  repeated Arg args = 3;
+}
+
+message ClientTaskTicket {
+  bytes return_id = 1;
+}
+
+message PutRequest {
+  bytes data = 1;
+}
+
+message PutResponse {
+  bytes id = 1;
+}
+
+message GetRequest {
+  bytes id = 1;
+}
+
+message GetResponse {
+  bool valid = 1;
+  bytes data = 2;
+}
+
+service RayletDriver {
+  rpc GetObject(GetRequest) returns (GetResponse) {
+  }
+  rpc PutObject(PutRequest) returns (PutResponse) {
+  }
+  rpc Schedule(ClientTask) returns (ClientTaskTicket) {
+  }
+}