diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py index 8b135f60e..8d1267d24 100644 --- a/python/ray/experimental/client/__init__.py +++ b/python/ray/experimental/client/__init__.py @@ -1,6 +1,6 @@ from ray.experimental.client.api import ClientAPI from ray.experimental.client.api import APIImpl -from typing import Optional +from typing import Optional, List, Tuple import logging @@ -27,10 +27,15 @@ def restore_api(api: Optional[APIImpl]): class RayAPIStub: - def connect(self, conn_str): + def connect(self, + conn_str: str, + secure: bool = False, + metadata: List[Tuple[str, str]] = None, + stub=None): global _client_api from ray.experimental.client.worker import Worker - _client_worker = Worker(conn_str) + _client_worker = Worker( + conn_str, secure=secure, metadata=metadata, stub=stub) _client_api = ClientAPI(_client_worker) def disconnect(self): diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index f17959d3f..f63171bdc 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -2,7 +2,7 @@ It implements the Ray API functions that are forwarded through grpc calls to the server. """ -from typing import List +from typing import List, Tuple import ray.cloudpickle as cloudpickle import grpc @@ -15,9 +15,25 @@ from ray.experimental.client.common import ClientRemoteFunc class Worker: - def __init__(self, conn_str="", stub=None): + def __init__(self, + conn_str: str = "", + secure: bool = False, + metadata: List[Tuple[str, str]] = None, + stub=None): + """Initializes the worker side grpc client. + + Args: + stub: custom grpc stub. + secure: whether to use SSL secure channel or not. + metadata: additional metadata passed in the grpc request headers. + """ + self.metadata = metadata if stub is None: - self.channel = grpc.insecure_channel(conn_str) + if secure: + credentials = grpc.ssl_channel_credentials() + self.channel = grpc.secure_channel(conn_str, credentials) + else: + self.channel = grpc.insecure_channel(conn_str) self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) else: self.server = stub @@ -40,7 +56,7 @@ class Worker: def _get(self, id: bytes): req = ray_client_pb2.GetRequest(id=id) - data = self.server.GetObject(req) + data = self.server.GetObject(req, metadata=self.metadata) if not data.valid: raise Exception( "Client GetObject returned invalid data: id invalid?") @@ -63,7 +79,7 @@ class Worker: def _put(self, val): data = cloudpickle.dumps(val) req = ray_client_pb2.PutRequest(data=data) - resp = self.server.PutObject(req) + resp = self.server.PutObject(req, metadata=self.metadata) return ClientObjectRef(resp.id) def wait(self, @@ -83,7 +99,7 @@ class Worker: "timeout": timeout if timeout else -1 } req = ray_client_pb2.WaitRequest(**data) - resp = self.server.WaitObject(req) + resp = self.server.WaitObject(req, metadata=self.metadata) if not resp.valid: # TODO(ameer): improve error/exceptions messages. raise Exception("Client Wait request failed. Reference invalid?") @@ -109,7 +125,7 @@ class Worker: for arg in args: pb_arg = convert_to_arg(arg) task.args.append(pb_arg) - ticket = self.server.Schedule(task) + ticket = self.server.Schedule(task, metadata=self.metadata) return ClientObjectRef(ticket.return_id) def close(self):