diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py index 22cf0af5b..02aab93ff 100644 --- a/python/ray/util/client/__init__.py +++ b/python/ray/util/client/__init__.py @@ -24,7 +24,8 @@ class RayAPIStub: def connect(self, conn_str: str, secure: bool = False, - metadata: List[Tuple[str, str]] = None) -> Dict[str, Any]: + metadata: List[Tuple[str, str]] = None, + connection_retries: int = 3) -> Dict[str, Any]: """Connect the Ray Client to a server. Args: @@ -50,7 +51,10 @@ class RayAPIStub: try: self.client_worker = Worker( - conn_str, secure=secure, metadata=metadata) + conn_str, + secure=secure, + metadata=metadata, + connection_retries=connection_retries) self.api.worker = self.client_worker return self.client_worker.connection_info() except Exception: diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 2c6c50fe4..94e6d1cd7 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -38,7 +38,7 @@ class Worker: conn_str: str = "", secure: bool = False, metadata: List[Tuple[str, str]] = None, - connection_retries=3): + connection_retries: int = 3): """Initializes the worker side grpc client. Args: diff --git a/python/ray/util/client_connect.py b/python/ray/util/client_connect.py index 3f33933c4..0ff88408d 100644 --- a/python/ray/util/client_connect.py +++ b/python/ray/util/client_connect.py @@ -8,7 +8,8 @@ from typing import List, Tuple, Dict, Any def connect(conn_str: str, secure: bool = False, - metadata: List[Tuple[str, str]] = None) -> Dict[str, Any]: + metadata: List[Tuple[str, str]] = None, + connection_retries: int = 3) -> Dict[str, Any]: if ray.is_connected(): raise RuntimeError("Ray Client is already connected. " "Maybe you called ray.util.connect twice by " @@ -21,7 +22,8 @@ def connect(conn_str: str, # TODO(barakmich): https://github.com/ray-project/ray/issues/13274 # for supporting things like cert_path, ca_path, etc and creating # the correct metadata - return ray.connect(conn_str, secure=secure, metadata=metadata) + return ray.connect( + conn_str, secure=secure, metadata=metadata, connection_retries=3) def disconnect():