[ray_client]: Support runtime_context as metadata (#13428)

This commit is contained in:
Barak Michener
2021-01-14 14:37:00 -08:00
committed by GitHub
parent 9a658b568f
commit 84e110a949
6 changed files with 90 additions and 2 deletions
+18 -2
View File
@@ -1,9 +1,13 @@
import pytest
from ray.util.client.ray_client_helpers import ray_start_client_server
from ray._raylet import NodeID
from ray.runtime_context import RuntimeContext
def test_get_ray_metadata(ray_start_regular_shared):
"""Test the ClusterInfo client data pathway and API surface
"""
"""Test the ClusterInfo client data pathway and API surface"""
with ray_start_client_server() as ray:
ip_address = ray_start_regular_shared["node_ip_address"]
@@ -22,3 +26,15 @@ def test_get_ray_metadata(ray_start_regular_shared):
assert cluster_resources["CPU"] == 1.0
assert current_node_id in cluster_resources
assert current_node_id in available_resources
def test_get_runtime_context(ray_start_regular_shared):
    """Exercise ray.get_runtime_context() through the client metadata API.

    Verifies that the client hands back a real ``RuntimeContext``, that its
    ``node_id`` is a ``NodeID`` with the expected hex length, and that
    reading ``task_id`` through the client raises.
    """
    with ray_start_client_server() as ray:
        context = ray.get_runtime_context()
        assert isinstance(context, RuntimeContext)

        # node_id is read twice on purpose: each read is an independent
        # lookup on the context, exactly as a user would perform it.
        assert isinstance(context.node_id, NodeID)
        node_hex = context.node_id.hex()
        assert len(node_hex) == 56

        # Reading task_id via the client is expected to fail; the test only
        # pins that it raises, not which exception type.
        with pytest.raises(Exception):
            _ = context.task_id
+9
View File
@@ -1,6 +1,7 @@
"""This file defines the interface between the ray client worker
and the overall ray module API.
"""
from ray.util.client.runtime_context import ClientWorkerPropertyAPI
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ray.util.client.common import ClientStub
@@ -232,6 +233,14 @@ class ClientAPI:
return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES)
def get_runtime_context(self):
    """Build a Ray RuntimeContext that reflects the state on the server.

    The client worker is wrapped in a ``ClientWorkerPropertyAPI`` and that
    wrapper is asked to build the context.

    Returns:
        A RuntimeContext wrapping a client making get_cluster_info calls.
    """
    property_api = ClientWorkerPropertyAPI(self.worker)
    return property_api.build_runtime_context()
def _internal_kv_initialized(self) -> bool:
    """Hook for internal_kv._internal_kv_initialized.

    Returns:
        The result of ``self.is_initialized()``.
    """
    initialized = self.is_initialized()
    return initialized
+43
View File
@@ -0,0 +1,43 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ray.runtime_context import RuntimeContext
from ray import JobID
from ray import NodeID
class ClientWorkerPropertyAPI:
    """Client-side stand-in exposing ray.worker-style properties.

    The job ID, node ID and placement-group capture flag are not cached:
    every access to one of those properties issues a fresh
    ``get_cluster_info`` call on the wrapped client worker.
    """

    def __init__(self, worker):
        """Wrap ``worker``, the client worker used for metadata lookups.

        ``worker`` must not be None.
        """
        assert worker is not None
        self.worker = worker

    def build_runtime_context(self) -> "RuntimeContext":
        """Create a RuntimeContext backed by the properties of this API."""
        # Imported lazily: a module-level import would create an import cycle.
        from ray.runtime_context import RuntimeContext

        context = RuntimeContext(self)
        return context

    def _fetch_runtime_context(self):
        """Request the RUNTIME_CONTEXT cluster info via the client worker."""
        import ray.core.generated.ray_client_pb2 as ray_client_pb2

        info_type = ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT
        return self.worker.get_cluster_info(info_type)

    @property
    def mode(self):
        """Worker mode reported by this API; always ``SCRIPT_MODE``."""
        from ray.worker import SCRIPT_MODE

        return SCRIPT_MODE

    @property
    def current_job_id(self) -> "JobID":
        """JobID built from the ``job_id`` field of the fetched context."""
        from ray import JobID

        fetched = self._fetch_runtime_context()
        return JobID(fetched.job_id)

    @property
    def current_node_id(self) -> "NodeID":
        """NodeID built from the ``node_id`` field of the fetched context."""
        from ray import NodeID

        fetched = self._fetch_runtime_context()
        return NodeID(fetched.node_id)

    @property
    def should_capture_child_tasks_in_placement_group(self) -> bool:
        """The ``capture_client_tasks`` field of the fetched context."""
        fetched = self._fetch_runtime_context()
        return fetched.capture_client_tasks
+9
View File
@@ -83,6 +83,15 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
resp.resource_table.CopyFrom(
ray_client_pb2.ClusterInfoResponse.ResourceTable(
table=float_resources))
elif request.type == ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT:
ctx = ray_client_pb2.ClusterInfoResponse.RuntimeContext()
with disable_client_hook():
rtc = ray.get_runtime_context()
ctx.job_id = rtc.job_id.binary()
ctx.node_id = rtc.node_id.binary()
ctx.capture_client_tasks = \
rtc.should_capture_child_tasks_in_placement_group
resp.runtime_context.CopyFrom(ctx)
else:
with disable_client_hook():
resp.json = self._return_debug_cluster_info(request, context)
+2
View File
@@ -278,6 +278,8 @@ class Worker:
# translate from a proto map to a python dict
output_dict = {k: v for k, v in resp.resource_table.table.items()}
return output_dict
elif resp.WhichOneof("response_type") == "runtime_context":
return resp.runtime_context
return json.loads(resp.json)
def internal_kv_get(self, key: bytes) -> bytes: