[ray_client] Add metadata to gRPC requests (#13167)

This commit is contained in:
Philipp Moritz
2021-01-07 23:58:15 -08:00
committed by GitHub
parent 77cd0d5a21
commit a247c71e2e
3 changed files with 15 additions and 8 deletions
+5 -2
View File
@@ -20,12 +20,14 @@ INT32_MAX = (2**31) - 1
class DataClient:
def __init__(self, channel: "grpc._channel.Channel", client_id: str):
def __init__(self, channel: "grpc._channel.Channel", client_id: str,
metadata: list):
"""Initializes a thread-safe datapath over a Ray Client gRPC channel.
Args:
channel: connected gRPC channel
client_id: the generated ID representing this client
metadata: metadata to pass to gRPC requests
"""
self.channel = channel
self.request_queue = queue.Queue()
@@ -34,6 +36,7 @@ class DataClient:
self.cv = threading.Condition()
self._req_id = 0
self._client_id = client_id
self._metadata = metadata
self.data_thread.start()
def _next_id(self) -> int:
@@ -52,7 +55,7 @@ class DataClient:
stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
resp_stream = stub.Datapath(
iter(self.request_queue.get, None),
metadata=(("client_id", self._client_id), ),
metadata=[("client_id", self._client_id)] + self._metadata,
wait_for_ready=True)
try:
for response in resp_stream:
+5 -2
View File
@@ -18,13 +18,15 @@ logger.propagate = False
class LogstreamClient:
def __init__(self, channel: "grpc._channel.Channel"):
def __init__(self, channel: "grpc._channel.Channel", metadata: list):
"""Initializes a thread-safe log stream over a Ray Client gRPC channel.
Args:
channel: connected gRPC channel
metadata: metadata to pass to gRPC requests
"""
self.channel = channel
self._metadata = metadata
self.request_queue = queue.Queue()
self.log_thread = self._start_logthread()
self.log_thread.start()
@@ -34,7 +36,8 @@ class LogstreamClient:
def _log_main(self) -> None:
stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.channel)
log_stream = stub.Logstream(iter(self.request_queue.get, None))
log_stream = stub.Logstream(
iter(self.request_queue.get, None), metadata=self._metadata)
try:
for record in log_stream:
if record.level < 0:
+5 -4
View File
@@ -41,7 +41,7 @@ class Worker:
secure: whether to use SSL secure channel or not.
metadata: additional metadata passed in the grpc request headers.
"""
self.metadata = metadata
self.metadata = metadata if metadata else []
self.channel = None
self._client_id = make_client_id()
if secure:
@@ -51,10 +51,11 @@ class Worker:
self.channel = grpc.insecure_channel(conn_str)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
self.data_client = DataClient(self.channel, self._client_id)
self.data_client = DataClient(self.channel, self._client_id,
self.metadata)
self.reference_count: Dict[bytes, int] = defaultdict(int)
self.log_client = LogstreamClient(self.channel)
self.log_client = LogstreamClient(self.channel, self.metadata)
self.log_client.set_logstream_level(logging.INFO)
self.closed = False
@@ -240,7 +241,7 @@ class Worker:
def get_cluster_info(self, type: ray_client_pb2.ClusterInfoType.TypeEnum):
req = ray_client_pb2.ClusterInfoRequest()
req.type = type
resp = self.server.ClusterInfo(req)
resp = self.server.ClusterInfo(req, metadata=self.metadata)
if resp.WhichOneof("response_type") == "resource_table":
# translate from a proto map to a python dict
output_dict = {k: v for k, v in resp.resource_table.table.items()}