mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 04:23:03 +08:00
[ray client] add metadata and secure options to Worker. (#12409)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from ray.experimental.client.api import ClientAPI
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import logging
|
||||
|
||||
@@ -27,10 +27,15 @@ def restore_api(api: Optional[APIImpl]):
|
||||
|
||||
|
||||
class RayAPIStub:
|
||||
def connect(self, conn_str):
|
||||
def connect(self,
|
||||
conn_str: str,
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None):
|
||||
global _client_api
|
||||
from ray.experimental.client.worker import Worker
|
||||
_client_worker = Worker(conn_str)
|
||||
_client_worker = Worker(
|
||||
conn_str, secure=secure, metadata=metadata, stub=stub)
|
||||
_client_api = ClientAPI(_client_worker)
|
||||
|
||||
def disconnect(self):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
It implements the Ray API functions that are forwarded through grpc calls
|
||||
to the server.
|
||||
"""
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
import grpc
|
||||
@@ -15,9 +15,25 @@ from ray.experimental.client.common import ClientRemoteFunc
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(self, conn_str="", stub=None):
|
||||
def __init__(self,
|
||||
conn_str: str = "",
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None):
|
||||
"""Initializes the worker side grpc client.
|
||||
|
||||
Args:
|
||||
stub: custom grpc stub.
|
||||
secure: whether to use SSL secure channel or not.
|
||||
metadata: additional metadata passed in the grpc request headers.
|
||||
"""
|
||||
self.metadata = metadata
|
||||
if stub is None:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
if secure:
|
||||
credentials = grpc.ssl_channel_credentials()
|
||||
self.channel = grpc.secure_channel(conn_str, credentials)
|
||||
else:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
else:
|
||||
self.server = stub
|
||||
@@ -40,7 +56,7 @@ class Worker:
|
||||
|
||||
def _get(self, id: bytes):
|
||||
req = ray_client_pb2.GetRequest(id=id)
|
||||
data = self.server.GetObject(req)
|
||||
data = self.server.GetObject(req, metadata=self.metadata)
|
||||
if not data.valid:
|
||||
raise Exception(
|
||||
"Client GetObject returned invalid data: id invalid?")
|
||||
@@ -63,7 +79,7 @@ class Worker:
|
||||
def _put(self, val):
|
||||
data = cloudpickle.dumps(val)
|
||||
req = ray_client_pb2.PutRequest(data=data)
|
||||
resp = self.server.PutObject(req)
|
||||
resp = self.server.PutObject(req, metadata=self.metadata)
|
||||
return ClientObjectRef(resp.id)
|
||||
|
||||
def wait(self,
|
||||
@@ -83,7 +99,7 @@ class Worker:
|
||||
"timeout": timeout if timeout else -1
|
||||
}
|
||||
req = ray_client_pb2.WaitRequest(**data)
|
||||
resp = self.server.WaitObject(req)
|
||||
resp = self.server.WaitObject(req, metadata=self.metadata)
|
||||
if not resp.valid:
|
||||
# TODO(ameer): improve error/exceptions messages.
|
||||
raise Exception("Client Wait request failed. Reference invalid?")
|
||||
@@ -109,7 +125,7 @@ class Worker:
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg)
|
||||
task.args.append(pb_arg)
|
||||
ticket = self.server.Schedule(task)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
return ClientObjectRef(ticket.return_id)
|
||||
|
||||
def close(self):
|
||||
|
||||
Reference in New Issue
Block a user