mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
[ray_client]: Support runtime_context as metadata (#13428)
This commit is contained in:
@@ -1,9 +1,13 @@
|
||||
import pytest
|
||||
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
from ray._raylet import NodeID
|
||||
|
||||
from ray.runtime_context import RuntimeContext
|
||||
|
||||
|
||||
def test_get_ray_metadata(ray_start_regular_shared):
|
||||
"""Test the ClusterInfo client data pathway and API surface
|
||||
"""
|
||||
"""Test the ClusterInfo client data pathway and API surface"""
|
||||
with ray_start_client_server() as ray:
|
||||
ip_address = ray_start_regular_shared["node_ip_address"]
|
||||
|
||||
@@ -22,3 +26,15 @@ def test_get_ray_metadata(ray_start_regular_shared):
|
||||
assert cluster_resources["CPU"] == 1.0
|
||||
assert current_node_id in cluster_resources
|
||||
assert current_node_id in available_resources
|
||||
|
||||
|
||||
def test_get_runtime_context(ray_start_regular_shared):
    """Check the RuntimeContext obtained through the client metadata API.

    The object returned by ``get_runtime_context()`` on the client must be
    a ``RuntimeContext`` whose ``node_id`` is a ``NodeID`` with a
    56-character hex representation, and reading ``task_id`` from it must
    raise.
    """
    with ray_start_client_server() as client:
        context = client.get_runtime_context()
        assert isinstance(context, RuntimeContext)

        # node_id is read twice on purpose, exactly as many times as the
        # assertions need it.
        assert isinstance(context.node_id, NodeID)
        expected_hex_length = 56
        assert len(context.node_id.hex()) == expected_hex_length

        # This test expects task_id access on the client-built context
        # to fail.
        with pytest.raises(Exception):
            _ = context.task_id
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""This file defines the interface between the ray client worker
|
||||
and the overall ray module API.
|
||||
"""
|
||||
from ray.util.client.runtime_context import ClientWorkerPropertyAPI
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.client.common import ClientStub
|
||||
@@ -232,6 +233,14 @@ class ClientAPI:
|
||||
return self.worker.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES)
|
||||
|
||||
def get_runtime_context(self):
    """Build a Ray RuntimeContext describing the state on the server.

    Returns:
        The RuntimeContext produced by a ClientWorkerPropertyAPI wrapped
        around this API's worker.
    """
    property_api = ClientWorkerPropertyAPI(self.worker)
    return property_api.build_runtime_context()
|
||||
|
||||
def _internal_kv_initialized(self) -> bool:
    """Hook for internal_kv._internal_kv_initialized.

    Returns:
        The value reported by ``self.is_initialized()``.
    """
    initialized = self.is_initialized()
    return initialized
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from ray.runtime_context import RuntimeContext
|
||||
from ray import JobID
|
||||
from ray import NodeID
|
||||
|
||||
|
||||
class ClientWorkerPropertyAPI:
    """Emulates the properties of the ray.worker object for the client.

    Every ID/flag property issues its own ``get_cluster_info`` request
    through the wrapped worker; nothing is cached on this object.
    """

    def __init__(self, worker):
        """Store the client worker used for all cluster-info requests.

        Args:
            worker: The client worker object; must not be None.
        """
        assert worker is not None
        self.worker = worker

    def build_runtime_context(self) -> "RuntimeContext":
        """Creates a RuntimeContext backed by the properties of this API."""
        # RuntimeContext is imported lazily to avoid import cycles.
        from ray.runtime_context import RuntimeContext

        context = RuntimeContext(self)
        return context

    def _fetch_runtime_context(self):
        """Request the RUNTIME_CONTEXT cluster info from the worker.

        Returns:
            Whatever ``self.worker.get_cluster_info`` returns for the
            ``RUNTIME_CONTEXT`` info type.
        """
        import ray.core.generated.ray_client_pb2 as ray_client_pb2

        info_type = ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT
        return self.worker.get_cluster_info(info_type)

    @property
    def mode(self):
        """The worker mode reported for the client: always SCRIPT_MODE."""
        from ray.worker import SCRIPT_MODE

        client_mode = SCRIPT_MODE
        return client_mode

    @property
    def current_job_id(self) -> "JobID":
        """JobID built from the ``job_id`` field of a fresh fetch."""
        from ray import JobID

        fetched = self._fetch_runtime_context()
        return JobID(fetched.job_id)

    @property
    def current_node_id(self) -> "NodeID":
        """NodeID built from the ``node_id`` field of a fresh fetch."""
        from ray import NodeID

        fetched = self._fetch_runtime_context()
        return NodeID(fetched.node_id)

    @property
    def should_capture_child_tasks_in_placement_group(self) -> bool:
        """The ``capture_client_tasks`` field of a fresh fetch."""
        fetched = self._fetch_runtime_context()
        return fetched.capture_client_tasks
|
||||
@@ -83,6 +83,15 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
resp.resource_table.CopyFrom(
|
||||
ray_client_pb2.ClusterInfoResponse.ResourceTable(
|
||||
table=float_resources))
|
||||
elif request.type == ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT:
|
||||
ctx = ray_client_pb2.ClusterInfoResponse.RuntimeContext()
|
||||
with disable_client_hook():
|
||||
rtc = ray.get_runtime_context()
|
||||
ctx.job_id = rtc.job_id.binary()
|
||||
ctx.node_id = rtc.node_id.binary()
|
||||
ctx.capture_client_tasks = \
|
||||
rtc.should_capture_child_tasks_in_placement_group
|
||||
resp.runtime_context.CopyFrom(ctx)
|
||||
else:
|
||||
with disable_client_hook():
|
||||
resp.json = self._return_debug_cluster_info(request, context)
|
||||
|
||||
@@ -278,6 +278,8 @@ class Worker:
|
||||
# translate from a proto map to a python dict
|
||||
output_dict = {k: v for k, v in resp.resource_table.table.items()}
|
||||
return output_dict
|
||||
elif resp.WhichOneof("response_type") == "runtime_context":
|
||||
return resp.runtime_context
|
||||
return json.loads(resp.json)
|
||||
|
||||
def internal_kv_get(self, key: bytes) -> bytes:
|
||||
|
||||
Reference in New Issue
Block a user