From a247c71e2e02d59bf4f4fbed505c4b353e6855b7 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 7 Jan 2021 23:58:15 -0800 Subject: [PATCH] [ray_client] Add metadata to gRPC requests (#13167) --- python/ray/util/client/dataclient.py | 7 +++++-- python/ray/util/client/logsclient.py | 7 +++++-- python/ray/util/client/worker.py | 9 +++++---- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/python/ray/util/client/dataclient.py b/python/ray/util/client/dataclient.py index 38f095f3f..bba9f53f3 100644 --- a/python/ray/util/client/dataclient.py +++ b/python/ray/util/client/dataclient.py @@ -20,12 +20,14 @@ INT32_MAX = (2**31) - 1 class DataClient: - def __init__(self, channel: "grpc._channel.Channel", client_id: str): + def __init__(self, channel: "grpc._channel.Channel", client_id: str, + metadata: list): """Initializes a thread-safe datapath over a Ray Client gRPC channel. Args: channel: connected gRPC channel client_id: the generated ID representing this client + metadata: metadata to pass to gRPC requests """ self.channel = channel self.request_queue = queue.Queue() @@ -34,6 +36,7 @@ class DataClient: self.cv = threading.Condition() self._req_id = 0 self._client_id = client_id + self._metadata = metadata self.data_thread.start() def _next_id(self) -> int: @@ -52,7 +55,7 @@ class DataClient: stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel) resp_stream = stub.Datapath( iter(self.request_queue.get, None), - metadata=(("client_id", self._client_id), ), + metadata=[("client_id", self._client_id)] + self._metadata, wait_for_ready=True) try: for response in resp_stream: diff --git a/python/ray/util/client/logsclient.py b/python/ray/util/client/logsclient.py index acf2619c9..a09019039 100644 --- a/python/ray/util/client/logsclient.py +++ b/python/ray/util/client/logsclient.py @@ -18,13 +18,15 @@ logger.propagate = False class LogstreamClient: - def __init__(self, channel: "grpc._channel.Channel"): + def __init__(self, channel: "grpc._channel.Channel", metadata: list): """Initializes a thread-safe log stream over a Ray Client gRPC channel. Args: channel: connected gRPC channel + metadata: metadata to pass to gRPC requests """ self.channel = channel + self._metadata = metadata self.request_queue = queue.Queue() self.log_thread = self._start_logthread() self.log_thread.start() @@ -34,7 +36,8 @@ class LogstreamClient: def _log_main(self) -> None: stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.channel) - log_stream = stub.Logstream(iter(self.request_queue.get, None)) + log_stream = stub.Logstream( + iter(self.request_queue.get, None), metadata=self._metadata) try: for record in log_stream: if record.level < 0: diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 8e8d1d851..683b59082 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -41,7 +41,7 @@ class Worker: secure: whether to use SSL secure channel or not. metadata: additional metadata passed in the grpc request headers. """ - self.metadata = metadata + self.metadata = metadata if metadata else [] self.channel = None self._client_id = make_client_id() if secure: @@ -51,10 +51,11 @@ class Worker: self.channel = grpc.insecure_channel(conn_str) self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) - self.data_client = DataClient(self.channel, self._client_id) + self.data_client = DataClient(self.channel, self._client_id, + self.metadata) self.reference_count: Dict[bytes, int] = defaultdict(int) - self.log_client = LogstreamClient(self.channel) + self.log_client = LogstreamClient(self.channel, self.metadata) self.log_client.set_logstream_level(logging.INFO) self.closed = False @@ -240,7 +241,7 @@ class Worker: def get_cluster_info(self, type: ray_client_pb2.ClusterInfoType.TypeEnum): req = ray_client_pb2.ClusterInfoRequest() req.type = type - resp = self.server.ClusterInfo(req) + resp = self.server.ClusterInfo(req, metadata=self.metadata) if resp.WhichOneof("response_type") == "resource_table": # translate from a proto map to a python dict output_dict = {k: v for k, v in resp.resource_table.table.items()}