mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:56:55 +08:00
[ray_client] Add metadata to gRPC requests (#13167)
This commit is contained in:
@@ -20,12 +20,14 @@ INT32_MAX = (2**31) - 1
|
||||
|
||||
|
||||
class DataClient:
|
||||
def __init__(self, channel: "grpc._channel.Channel", client_id: str):
|
||||
def __init__(self, channel: "grpc._channel.Channel", client_id: str,
|
||||
metadata: list):
|
||||
"""Initializes a thread-safe datapath over a Ray Client gRPC channel.
|
||||
|
||||
Args:
|
||||
channel: connected gRPC channel
|
||||
client_id: the generated ID representing this client
|
||||
metadata: metadata to pass to gRPC requests
|
||||
"""
|
||||
self.channel = channel
|
||||
self.request_queue = queue.Queue()
|
||||
@@ -34,6 +36,7 @@ class DataClient:
|
||||
self.cv = threading.Condition()
|
||||
self._req_id = 0
|
||||
self._client_id = client_id
|
||||
self._metadata = metadata
|
||||
self.data_thread.start()
|
||||
|
||||
def _next_id(self) -> int:
|
||||
@@ -52,7 +55,7 @@ class DataClient:
|
||||
stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
|
||||
resp_stream = stub.Datapath(
|
||||
iter(self.request_queue.get, None),
|
||||
metadata=(("client_id", self._client_id), ),
|
||||
metadata=[("client_id", self._client_id)] + self._metadata,
|
||||
wait_for_ready=True)
|
||||
try:
|
||||
for response in resp_stream:
|
||||
|
||||
@@ -18,13 +18,15 @@ logger.propagate = False
|
||||
|
||||
|
||||
class LogstreamClient:
|
||||
def __init__(self, channel: "grpc._channel.Channel"):
|
||||
def __init__(self, channel: "grpc._channel.Channel", metadata: list):
|
||||
"""Initializes a thread-safe log stream over a Ray Client gRPC channel.
|
||||
|
||||
Args:
|
||||
channel: connected gRPC channel
|
||||
metadata: metadata to pass to gRPC requests
|
||||
"""
|
||||
self.channel = channel
|
||||
self._metadata = metadata
|
||||
self.request_queue = queue.Queue()
|
||||
self.log_thread = self._start_logthread()
|
||||
self.log_thread.start()
|
||||
@@ -34,7 +36,8 @@ class LogstreamClient:
|
||||
|
||||
def _log_main(self) -> None:
|
||||
stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.channel)
|
||||
log_stream = stub.Logstream(iter(self.request_queue.get, None))
|
||||
log_stream = stub.Logstream(
|
||||
iter(self.request_queue.get, None), metadata=self._metadata)
|
||||
try:
|
||||
for record in log_stream:
|
||||
if record.level < 0:
|
||||
|
||||
@@ -41,7 +41,7 @@ class Worker:
|
||||
secure: whether to use SSL secure channel or not.
|
||||
metadata: additional metadata passed in the grpc request headers.
|
||||
"""
|
||||
self.metadata = metadata
|
||||
self.metadata = metadata if metadata else []
|
||||
self.channel = None
|
||||
self._client_id = make_client_id()
|
||||
if secure:
|
||||
@@ -51,10 +51,11 @@ class Worker:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
|
||||
self.data_client = DataClient(self.channel, self._client_id)
|
||||
self.data_client = DataClient(self.channel, self._client_id,
|
||||
self.metadata)
|
||||
self.reference_count: Dict[bytes, int] = defaultdict(int)
|
||||
|
||||
self.log_client = LogstreamClient(self.channel)
|
||||
self.log_client = LogstreamClient(self.channel, self.metadata)
|
||||
self.log_client.set_logstream_level(logging.INFO)
|
||||
self.closed = False
|
||||
|
||||
@@ -240,7 +241,7 @@ class Worker:
|
||||
def get_cluster_info(self, type: ray_client_pb2.ClusterInfoType.TypeEnum):
|
||||
req = ray_client_pb2.ClusterInfoRequest()
|
||||
req.type = type
|
||||
resp = self.server.ClusterInfo(req)
|
||||
resp = self.server.ClusterInfo(req, metadata=self.metadata)
|
||||
if resp.WhichOneof("response_type") == "resource_table":
|
||||
# translate from a proto map to a python dict
|
||||
output_dict = {k: v for k, v in resp.resource_table.table.items()}
|
||||
|
||||
Reference in New Issue
Block a user