Implement internal kv in ray client (#13344)

* kv internal

* fix
This commit is contained in:
Eric Liang
2021-01-11 14:54:52 -08:00
committed by GitHub
parent fbb9795374
commit de5bc24c60
6 changed files with 131 additions and 4 deletions
+14 -4
View File
@@ -1,18 +1,26 @@
from typing import List, Union
import ray
from ray._private.client_mode_hook import client_mode_hook
@client_mode_hook
def _internal_kv_initialized():
worker = ray.worker.global_worker
return hasattr(worker, "mode") and worker.mode is not None
def _internal_kv_get(key):
@client_mode_hook
def _internal_kv_get(key: Union[str, bytes]) -> bytes:
"""Fetch the value of a binary key."""
return ray.worker.global_worker.redis_client.hget(key, "value")
def _internal_kv_put(key, value, overwrite=False):
@client_mode_hook
def _internal_kv_put(key: Union[str, bytes],
value: Union[str, bytes],
overwrite: bool = False) -> bool:
"""Globally associates a value with a given binary key.
This only has an effect if the key does not already have a value.
@@ -30,11 +38,13 @@ def _internal_kv_put(key, value, overwrite=False):
return updated == 0 # already exists
def _internal_kv_del(key):
@client_mode_hook
def _internal_kv_del(key: Union[str, bytes]):
return ray.worker.global_worker.redis_client.delete(key)
def _internal_kv_list(prefix):
@client_mode_hook
def _internal_kv_list(prefix: Union[str, bytes]) -> List[bytes]:
"""List all keys in the internal KV store that start with the prefix."""
if isinstance(prefix, bytes):
pattern = prefix + b"*"
+11
View File
@@ -341,5 +341,16 @@ def test_basic_named_actor(ray_start_regular_shared):
assert ray.get(new_actor.get.remote()) == 3
def test_internal_kv(ray_start_regular_shared):
with ray_start_client_server() as ray:
assert ray._internal_kv_initialized()
assert not ray._internal_kv_put("apple", "b")
assert ray._internal_kv_put("apple", "b")
assert ray._internal_kv_get("apple") == b"b"
assert ray._internal_kv_list("a") == [b"apple"]
ray._internal_kv_del("apple")
assert ray._internal_kv_get("apple") == b""
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
+26
View File
@@ -8,6 +8,12 @@ if TYPE_CHECKING:
from ray.util.client.common import ClientObjectRef
def as_bytes(value):
if isinstance(value, str):
return value.encode("utf-8")
return value
class ClientAPI:
"""The Client-side methods corresponding to the ray API. Delegates
to the Client Worker that contains the connection to the ClientServer.
@@ -226,6 +232,26 @@ class ClientAPI:
return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES)
def _internal_kv_initialized(self) -> bool:
"""Hook for internal_kv._internal_kv_initialized."""
return self.is_initialized()
def _internal_kv_get(self, key: bytes) -> bytes:
"""Hook for internal_kv._internal_kv_get."""
return self.worker.internal_kv_get(as_bytes(key))
def _internal_kv_put(self, key: bytes, value: bytes) -> bool:
"""Hook for internal_kv._internal_kv_put."""
return self.worker.internal_kv_put(as_bytes(key), as_bytes(value))
def _internal_kv_del(self, key: bytes) -> None:
"""Hook for internal_kv._internal_kv_del."""
return self.worker.internal_kv_del(as_bytes(key))
def _internal_kv_list(self, prefix: bytes) -> bytes:
"""Hook for internal_kv._internal_kv_list."""
return self.worker.internal_kv_list(as_bytes(prefix))
def __getattr__(self, key: str):
if not key.startswith("_"):
raise NotImplementedError(
+22
View File
@@ -38,6 +38,28 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
self.registered_actor_classes = {}
self._current_function_stub = None
def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
with disable_client_hook():
already_exists = ray.experimental.internal_kv._internal_kv_put(
request.key, request.value)
return ray_client_pb2.KVPutResponse(already_exists=already_exists)
def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse:
with disable_client_hook():
value = ray.experimental.internal_kv._internal_kv_get(request.key)
return ray_client_pb2.KVGetResponse(value=value)
def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
with disable_client_hook():
ray.experimental.internal_kv._internal_kv_del(request.key)
return ray_client_pb2.KVDelResponse()
def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse:
with disable_client_hook():
keys = ray.experimental.internal_kv._internal_kv_list(
request.prefix)
return ray_client_pb2.KVListResponse(keys=keys)
def ClusterInfo(self, request,
context=None) -> ray_client_pb2.ClusterInfoResponse:
resp = ray_client_pb2.ClusterInfoResponse()
+18
View File
@@ -255,6 +255,24 @@ class Worker:
return output_dict
return json.loads(resp.json)
def internal_kv_get(self, key: bytes) -> bytes:
req = ray_client_pb2.KVGetRequest(key=key)
resp = self.server.KVGet(req, metadata=self.metadata)
return resp.value
def internal_kv_put(self, key: bytes, value: bytes) -> bool:
req = ray_client_pb2.KVPutRequest(key=key, value=value)
resp = self.server.KVPut(req, metadata=self.metadata)
return resp.already_exists
def internal_kv_del(self, key: bytes) -> None:
req = ray_client_pb2.KVDelRequest(key=key)
self.server.KVDel(req, metadata=self.metadata)
def internal_kv_list(self, prefix: bytes) -> bytes:
req = ray_client_pb2.KVListRequest(prefix=prefix)
return self.server.KVList(req, metadata=self.metadata).keys
def is_initialized(self) -> bool:
if self.server is not None:
return self.get_cluster_info(