From 84e110a949e83c4956cca14ed59fd5a92e34f3fe Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Thu, 14 Jan 2021 14:37:00 -0800 Subject: [PATCH] [ray_client]: Support runtime_context as metadata (#13428) --- python/ray/tests/test_client_metadata.py | 20 +++++++++-- python/ray/util/client/api.py | 9 +++++ python/ray/util/client/runtime_context.py | 43 +++++++++++++++++++++++ python/ray/util/client/server/server.py | 9 +++++ python/ray/util/client/worker.py | 2 ++ src/ray/protobuf/ray_client.proto | 9 +++++ 6 files changed, 90 insertions(+), 2 deletions(-) create mode 100644 python/ray/util/client/runtime_context.py diff --git a/python/ray/tests/test_client_metadata.py b/python/ray/tests/test_client_metadata.py index 47f46dec6..ffec75a77 100644 --- a/python/ray/tests/test_client_metadata.py +++ b/python/ray/tests/test_client_metadata.py @@ -1,9 +1,13 @@ +import pytest + from ray.util.client.ray_client_helpers import ray_start_client_server +from ray._raylet import NodeID + +from ray.runtime_context import RuntimeContext def test_get_ray_metadata(ray_start_regular_shared): - """Test the ClusterInfo client data pathway and API surface - """ + """Test the ClusterInfo client data pathway and API surface""" with ray_start_client_server() as ray: ip_address = ray_start_regular_shared["node_ip_address"] @@ -22,3 +26,15 @@ def test_get_ray_metadata(ray_start_regular_shared): assert cluster_resources["CPU"] == 1.0 assert current_node_id in cluster_resources assert current_node_id in available_resources + + +def test_get_runtime_context(ray_start_regular_shared): + """Test the get_runtime_context data through the metadata API""" + with ray_start_client_server() as ray: + rtc = ray.get_runtime_context() + assert isinstance(rtc, RuntimeContext) + assert isinstance(rtc.node_id, NodeID) + assert len(rtc.node_id.hex()) == 56 + + with pytest.raises(Exception): + _ = rtc.task_id diff --git a/python/ray/util/client/api.py b/python/ray/util/client/api.py index 22ae62cc2..7d8576d1f 100644 --- a/python/ray/util/client/api.py +++ b/python/ray/util/client/api.py @@ -1,6 +1,7 @@ """This file defines the interface between the ray client worker and the overall ray module API. """ +from ray.util.client.runtime_context import ClientWorkerPropertyAPI from typing import TYPE_CHECKING if TYPE_CHECKING: from ray.util.client.common import ClientStub @@ -232,6 +233,14 @@ class ClientAPI: return self.worker.get_cluster_info( ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES) + def get_runtime_context(self): + """Return a Ray RuntimeContext describing the state on the server + + Returns: + A RuntimeContext wrapping a client making get_cluster_info calls. + """ + return ClientWorkerPropertyAPI(self.worker).build_runtime_context() + def _internal_kv_initialized(self) -> bool: """Hook for internal_kv._internal_kv_initialized.""" return self.is_initialized() diff --git a/python/ray/util/client/runtime_context.py b/python/ray/util/client/runtime_context.py new file mode 100644 index 000000000..1847dae6e --- /dev/null +++ b/python/ray/util/client/runtime_context.py @@ -0,0 +1,43 @@ +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ray.runtime_context import RuntimeContext + from ray import JobID + from ray import NodeID + + +class ClientWorkerPropertyAPI: + """Emulates the properties of the ray.worker object for the client""" + + def __init__(self, worker): + assert worker is not None + self.worker = worker + + def build_runtime_context(self) -> "RuntimeContext": + """Creates a RuntimeContext backed by the properites of this API""" + # Defer the import of RuntimeContext until needed to avoid cycles + from ray.runtime_context import RuntimeContext + return RuntimeContext(self) + + def _fetch_runtime_context(self): + import ray.core.generated.ray_client_pb2 as ray_client_pb2 + return self.worker.get_cluster_info( + ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT) + + @property + def mode(self): + from ray.worker import SCRIPT_MODE + return SCRIPT_MODE + + @property + def current_job_id(self) -> "JobID": + from ray import JobID + return JobID(self._fetch_runtime_context().job_id) + + @property + def current_node_id(self) -> "NodeID": + from ray import NodeID + return NodeID(self._fetch_runtime_context().node_id) + + @property + def should_capture_child_tasks_in_placement_group(self) -> bool: + return self._fetch_runtime_context().capture_client_tasks diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index cf095a137..19a192337 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -83,6 +83,15 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): resp.resource_table.CopyFrom( ray_client_pb2.ClusterInfoResponse.ResourceTable( table=float_resources)) + elif request.type == ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT: + ctx = ray_client_pb2.ClusterInfoResponse.RuntimeContext() + with disable_client_hook(): + rtc = ray.get_runtime_context() + ctx.job_id = rtc.job_id.binary() + ctx.node_id = rtc.node_id.binary() + ctx.capture_client_tasks = \ + rtc.should_capture_child_tasks_in_placement_group + resp.runtime_context.CopyFrom(ctx) else: with disable_client_hook(): resp.json = self._return_debug_cluster_info(request, context) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 94e6d1cd7..7515ecf04 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -278,6 +278,8 @@ class Worker: # translate from a proto map to a python dict output_dict = {k: v for k, v in resp.resource_table.table.items()} return output_dict + elif resp.WhichOneof("response_type") == "runtime_context": + return resp.runtime_context return json.loads(resp.json) def internal_kv_get(self, key: bytes) -> bytes: diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index 919e7a226..bbc86772c 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -138,6 +138,7 @@ message ClusterInfoType { NODES = 1; CLUSTER_RESOURCES = 2; AVAILABLE_RESOURCES = 3; + RUNTIME_CONTEXT = 4; } } @@ -149,10 +150,18 @@ message ClusterInfoResponse { message ResourceTable { map table = 1; } + + message RuntimeContext { + bytes job_id = 1; + bytes node_id = 2; + bool capture_client_tasks = 3; + } + ClusterInfoType.TypeEnum type = 1; oneof response_type { string json = 2; ResourceTable resource_table = 3; + RuntimeContext runtime_context = 4; } }