From de5bc24c602e3ac8580e5e822b79e54fe81e2920 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 11 Jan 2021 14:54:52 -0800 Subject: [PATCH] Implement internal kv in ray client (#13344) * kv internal * fix --- python/ray/experimental/internal_kv.py | 18 ++++++++--- python/ray/tests/test_client.py | 11 +++++++ python/ray/util/client/api.py | 26 ++++++++++++++++ python/ray/util/client/server/server.py | 22 ++++++++++++++ python/ray/util/client/worker.py | 18 +++++++++++ src/ray/protobuf/ray_client.proto | 40 +++++++++++++++++++++++++ 6 files changed, 131 insertions(+), 4 deletions(-) diff --git a/python/ray/experimental/internal_kv.py b/python/ray/experimental/internal_kv.py index 6ce2ad162..388d06037 100644 --- a/python/ray/experimental/internal_kv.py +++ b/python/ray/experimental/internal_kv.py @@ -1,18 +1,26 @@ +from typing import List, Union + import ray +from ray._private.client_mode_hook import client_mode_hook +@client_mode_hook def _internal_kv_initialized(): worker = ray.worker.global_worker return hasattr(worker, "mode") and worker.mode is not None -def _internal_kv_get(key): +@client_mode_hook +def _internal_kv_get(key: Union[str, bytes]) -> bytes: """Fetch the value of a binary key.""" return ray.worker.global_worker.redis_client.hget(key, "value") -def _internal_kv_put(key, value, overwrite=False): +@client_mode_hook +def _internal_kv_put(key: Union[str, bytes], + value: Union[str, bytes], + overwrite: bool = False) -> bool: """Globally associates a value with a given binary key. This only has an effect if the key does not already have a value. @@ -30,11 +38,13 @@ def _internal_kv_put(key, value, overwrite=False): return updated == 0 # already exists -def _internal_kv_del(key): +@client_mode_hook +def _internal_kv_del(key: Union[str, bytes]): return ray.worker.global_worker.redis_client.delete(key) -def _internal_kv_list(prefix): +@client_mode_hook +def _internal_kv_list(prefix: Union[str, bytes]) -> List[bytes]: """List all keys in the internal KV store that start with the prefix.""" if isinstance(prefix, bytes): pattern = prefix + b"*" diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index 0b5c34e99..39606ed98 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -341,5 +341,16 @@ def test_basic_named_actor(ray_start_regular_shared): assert ray.get(new_actor.get.remote()) == 3 +def test_internal_kv(ray_start_regular_shared): + with ray_start_client_server() as ray: + assert ray._internal_kv_initialized() + assert not ray._internal_kv_put("apple", "b") + assert ray._internal_kv_put("apple", "b") + assert ray._internal_kv_get("apple") == b"b" + assert ray._internal_kv_list("a") == [b"apple"] + ray._internal_kv_del("apple") + assert ray._internal_kv_get("apple") == b"" + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/client/api.py b/python/ray/util/client/api.py index e8bcb1a2b..c6a3a871b 100644 --- a/python/ray/util/client/api.py +++ b/python/ray/util/client/api.py @@ -8,6 +8,12 @@ if TYPE_CHECKING: from ray.util.client.common import ClientObjectRef +def as_bytes(value): + if isinstance(value, str): + return value.encode("utf-8") + return value + + class ClientAPI: """The Client-side methods corresponding to the ray API. Delegates to the Client Worker that contains the connection to the ClientServer. @@ -226,6 +232,26 @@ class ClientAPI: return self.worker.get_cluster_info( ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES) + def _internal_kv_initialized(self) -> bool: + """Hook for internal_kv._internal_kv_initialized.""" + return self.is_initialized() + + def _internal_kv_get(self, key: bytes) -> bytes: + """Hook for internal_kv._internal_kv_get.""" + return self.worker.internal_kv_get(as_bytes(key)) + + def _internal_kv_put(self, key: bytes, value: bytes) -> bool: + """Hook for internal_kv._internal_kv_put.""" + return self.worker.internal_kv_put(as_bytes(key), as_bytes(value)) + + def _internal_kv_del(self, key: bytes) -> None: + """Hook for internal_kv._internal_kv_del.""" + return self.worker.internal_kv_del(as_bytes(key)) + + def _internal_kv_list(self, prefix: bytes) -> bytes: + """Hook for internal_kv._internal_kv_list.""" + return self.worker.internal_kv_list(as_bytes(prefix)) + def __getattr__(self, key: str): if not key.startswith("_"): raise NotImplementedError( diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index 4b59f1797..9c75b9404 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -38,6 +38,28 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): self.registered_actor_classes = {} self._current_function_stub = None + def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse: + with disable_client_hook(): + already_exists = ray.experimental.internal_kv._internal_kv_put( + request.key, request.value) + return ray_client_pb2.KVPutResponse(already_exists=already_exists) + + def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse: + with disable_client_hook(): + value = ray.experimental.internal_kv._internal_kv_get(request.key) + return ray_client_pb2.KVGetResponse(value=value) + + def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse: + with disable_client_hook(): + ray.experimental.internal_kv._internal_kv_del(request.key) + return ray_client_pb2.KVDelResponse() + + def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse: + with disable_client_hook(): + keys = ray.experimental.internal_kv._internal_kv_list( + request.prefix) + return ray_client_pb2.KVListResponse(keys=keys) + def ClusterInfo(self, request, context=None) -> ray_client_pb2.ClusterInfoResponse: resp = ray_client_pb2.ClusterInfoResponse() diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 303d39c8d..e86690867 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -255,6 +255,24 @@ class Worker: return output_dict return json.loads(resp.json) + def internal_kv_get(self, key: bytes) -> bytes: + req = ray_client_pb2.KVGetRequest(key=key) + resp = self.server.KVGet(req, metadata=self.metadata) + return resp.value + + def internal_kv_put(self, key: bytes, value: bytes) -> bool: + req = ray_client_pb2.KVPutRequest(key=key, value=value) + resp = self.server.KVPut(req, metadata=self.metadata) + return resp.already_exists + + def internal_kv_del(self, key: bytes) -> None: + req = ray_client_pb2.KVDelRequest(key=key) + self.server.KVDel(req, metadata=self.metadata) + + def internal_kv_list(self, prefix: bytes) -> bytes: + req = ray_client_pb2.KVListRequest(prefix=prefix) + return self.server.KVList(req, metadata=self.metadata).keys + def is_initialized(self) -> bool: if self.server is not None: return self.get_cluster_info( diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index 49ec8475e..856815cbf 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -178,6 +178,38 @@ message TerminateResponse { bool ok = 1; } +message KVGetRequest { + bytes key = 1; +} + +message KVGetResponse { + bytes value = 1; +} + +message KVPutRequest { + bytes key = 1; + bytes value = 2; +} + +message KVPutResponse { + bool already_exists = 1; +} + +message KVDelRequest { + bytes key = 1; +} + +message KVDelResponse { +} + +message KVListRequest { + bytes prefix = 1; +} + +message KVListResponse { + repeated bytes keys = 1; +} + service RayletDriver { rpc GetObject(GetRequest) returns (GetResponse) { } @@ -191,6 +223,14 @@ service RayletDriver { } rpc ClusterInfo(ClusterInfoRequest) returns (ClusterInfoResponse) { } + rpc KVGet(KVGetRequest) returns (KVGetResponse) { + } + rpc KVPut(KVPutRequest) returns (KVPutResponse) { + } + rpc KVDel(KVDelRequest) returns (KVDelResponse) { + } + rpc KVList(KVListRequest) returns (KVListResponse) { + } } message ReleaseRequest {