mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 18:57:37 +08:00
273 lines
10 KiB
Python
273 lines
10 KiB
Python
"""This file includes the Worker class which sits on the client side.
|
|
It implements the Ray API functions that are forwarded through grpc calls
|
|
to the server.
|
|
"""
|
|
import base64
|
|
import inspect
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from collections import defaultdict
|
|
from typing import Any
|
|
from typing import Dict
|
|
from typing import List
|
|
from typing import Tuple
|
|
from typing import Optional
|
|
|
|
from ray.util.inspect import is_cython
|
|
import grpc
|
|
|
|
import ray.cloudpickle as cloudpickle
|
|
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
|
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
|
from ray.experimental.client.client_pickler import convert_to_arg
|
|
from ray.experimental.client.client_pickler import dumps_from_client
|
|
from ray.experimental.client.client_pickler import loads_from_server
|
|
from ray.experimental.client.common import ClientActorClass
|
|
from ray.experimental.client.common import ClientActorHandle
|
|
from ray.experimental.client.common import ClientObjectRef
|
|
from ray.experimental.client.common import ClientRemoteFunc
|
|
from ray.experimental.client.common import ClientStub
|
|
from ray.experimental.client.dataclient import DataClient
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class Worker:
    """Client-side Ray API that forwards calls over gRPC to a client server.

    Scheduling, waiting, termination and cluster-info requests go through the
    ``self.server`` stub; object get/put/release go through
    ``self.data_client``.
    """

    def __init__(self,
                 conn_str: str = "",
                 secure: bool = False,
                 metadata: Optional[List[Tuple[str, str]]] = None):
        """Initializes the worker side grpc client.

        Args:
            conn_str: address passed to grpc.secure_channel /
                grpc.insecure_channel.
            secure: whether to use SSL secure channel or not.
            metadata: additional metadata passed in the grpc request headers.
        """
        self.metadata = metadata
        self.channel = None
        self._client_id = make_client_id()
        if secure:
            credentials = grpc.ssl_channel_credentials()
            self.channel = grpc.secure_channel(conn_str, credentials)
        else:
            self.channel = grpc.insecure_channel(conn_str)
        self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
        self.data_client = DataClient(self.channel, self._client_id)
        # Local per-object reference counts; the release is only sent to the
        # server once a count drops back to zero (see call_release).
        self.reference_count: Dict[bytes, int] = defaultdict(int)

    def get(self, vals, *, timeout: Optional[float] = None) -> Any:
        """Fetch the value(s) behind one ClientObjectRef or a list of them.

        Args:
            vals: a ClientObjectRef, or a list of ClientObjectRefs.
            timeout: forwarded in each GetRequest; None is sent as 0
                (presumably "no timeout" on the server side -- confirm).

        Returns:
            The single value when ``vals`` is one ref, otherwise a list of
            values in the same order as ``vals``.

        Raises:
            Exception: if ``vals`` is neither a list nor a ClientObjectRef.
        """
        to_get = []
        single = False
        if isinstance(vals, list):
            to_get = vals
        elif isinstance(vals, ClientObjectRef):
            to_get = [vals]
            single = True
        else:
            raise Exception("Can't get something that's not a "
                            "list of IDs or just an ID: %s" % type(vals))
        if timeout is None:
            timeout = 0
        out = [self._get(x, timeout) for x in to_get]
        if single:
            out = out[0]
        return out

    def _get(self, ref: ClientObjectRef, timeout: float):
        """Fetch and deserialize a single object over the data channel.

        Raises:
            The exception decoded from a failed RPC, or the unpickled error
            the server reports when the response is not valid.
        """
        req = ray_client_pb2.GetRequest(id=ref.id, timeout=timeout)
        try:
            data = self.data_client.GetObject(req)
        except grpc.RpcError as e:
            # e.details() is a str and cannot be raised directly; decode it
            # into an exception like the other RPC wrappers in this class.
            raise decode_exception(e.details())
        if not data.valid:
            err = cloudpickle.loads(data.error)
            logger.error(err)
            raise err
        return loads_from_server(data.data)

    def put(self, vals):
        """Store one value, or each element of a list, on the server.

        Returns:
            A ClientObjectRef for a single value, or a list of refs when
            ``vals`` is a list.
        """
        to_put = []
        single = False
        if isinstance(vals, list):
            to_put = vals
        else:
            single = True
            to_put.append(vals)

        out = [self._put(x) for x in to_put]
        if single:
            out = out[0]
        return out

    def _put(self, val):
        """Serialize ``val`` and store it through the data channel.

        Raises:
            TypeError: if ``val`` is itself a ClientObjectRef.
        """
        if isinstance(val, ClientObjectRef):
            raise TypeError(
                "Calling 'put' on an ObjectRef is not allowed "
                "(similarly, returning an ObjectRef from a remote "
                "function is not allowed). If you really want to "
                "do this, you can wrap the ObjectRef in a list and "
                "call 'put' on it (or return it).")
        data = dumps_from_client(val, self._client_id)
        req = ray_client_pb2.PutRequest(data=data)
        resp = self.data_client.PutObject(req)
        return ClientObjectRef(resp.id)

    def wait(self,
             object_refs: List[ClientObjectRef],
             *,
             num_returns: int = 1,
             timeout: Optional[float] = None,
             fetch_local: bool = True
             ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
        """Ask the server which of ``object_refs`` are ready.

        Args:
            object_refs: list of ClientObjectRefs to wait on.
            num_returns: number of ready refs requested from the server.
            timeout: forwarded as-is; None is sent as -1 (presumably the
                server's "no timeout" sentinel -- confirm).
            fetch_local: accepted for API compatibility; not sent to the
                server.

        Returns:
            ``(ready, remaining)`` lists of ClientObjectRefs.

        Raises:
            TypeError: if ``object_refs`` is not a list of ClientObjectRefs.
            Exception: if the server reports the request as invalid.
        """
        if not isinstance(object_refs, list):
            raise TypeError("wait() expected a list of ClientObjectRef, "
                            f"got {type(object_refs)}")
        for ref in object_refs:
            if not isinstance(ref, ClientObjectRef):
                raise TypeError("wait() expected a list of ClientObjectRef, "
                                f"got list containing {type(ref)}")
        data = {
            "object_ids": [object_ref.id for object_ref in object_refs],
            "num_returns": num_returns,
            # Only None maps to the -1 sentinel; an explicit 0 must reach the
            # server rather than being swallowed by a truthiness test.
            "timeout": timeout if timeout is not None else -1,
            "client_id": self._client_id,
        }
        req = ray_client_pb2.WaitRequest(**data)
        resp = self.server.WaitObject(req, metadata=self.metadata)
        if not resp.valid:
            # TODO(ameer): improve error/exceptions messages.
            raise Exception("Client Wait request failed. Reference invalid?")
        client_ready_object_ids = [
            ClientObjectRef(ref) for ref in resp.ready_object_ids
        ]
        client_remaining_object_ids = [
            ClientObjectRef(ref) for ref in resp.remaining_object_ids
        ]

        return (client_ready_object_ids, client_remaining_object_ids)

    def remote(self, *args, **kwargs):
        """Implement ``@ray.remote`` and ``@ray.remote(**options)``.

        A single callable positional argument with no kwargs is the bare
        decorator form; otherwise only keyword options are accepted and a
        decorator is returned.
        """
        if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
            # This is the case where the decorator is just @ray.remote.
            return remote_decorator(options=None)(args[0])
        error_string = ("The @ray.remote decorator must be applied either "
                        "with no arguments and no parentheses, for example "
                        "'@ray.remote', or it must be applied using some of "
                        "the arguments 'num_returns', 'num_cpus', 'num_gpus', "
                        "'memory', 'object_store_memory', 'resources', "
                        "'max_calls', or 'max_restarts', like "
                        "'@ray.remote(num_returns=2, "
                        "resources={\"CustomResource\": 1})'.")
        assert len(args) == 0 and len(kwargs) > 0, error_string
        return remote_decorator(options=kwargs)

    def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
        """Schedule a remote call described by ``instance`` on the server.

        Args:
            instance: stub object providing ``_prepare_client_task()``.
            *args, **kwargs: call arguments, each converted with
                ``convert_to_arg``.

        Returns:
            The ``return_ids`` of the ticket the server hands back.

        Raises:
            The server-side exception decoded from a failed RPC or carried
            in an invalid ticket.
        """
        task = instance._prepare_client_task()
        for arg in args:
            task.args.append(convert_to_arg(arg, self._client_id))
        for k, v in kwargs.items():
            task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id))
        task.client_id = self._client_id
        logger.debug("Scheduling %s", task)
        try:
            ticket = self.server.Schedule(task, metadata=self.metadata)
        except grpc.RpcError as e:
            # details() must be *called*; passing the bound method made
            # base64 decoding fail with an unrelated TypeError.
            raise decode_exception(e.details())
        if not ticket.valid:
            raise cloudpickle.loads(ticket.error)
        return ticket.return_ids

    def call_release(self, id: bytes) -> None:
        """Drop one local reference; release on the server at zero."""
        self.reference_count[id] -= 1
        if self.reference_count[id] == 0:
            self._release_server(id)
            del self.reference_count[id]

    def _release_server(self, id: bytes) -> None:
        """Send a ReleaseRequest for ``id`` unless the client was closed."""
        if self.data_client is not None:
            logger.debug(f"Releasing {id}")
            self.data_client.ReleaseObject(
                ray_client_pb2.ReleaseRequest(ids=[id]))

    def call_retain(self, id: bytes) -> None:
        """Record one more local reference to ``id``."""
        logger.debug(f"Retaining {id}")
        self.reference_count[id] += 1

    def close(self):
        """Close the data client (and its channel) and drop the stubs.

        Safe to call more than once.
        """
        if self.data_client is not None:
            self.data_client.close(close_channel=True)
            # Clearing it lets _release_server() skip releases for refs
            # that are garbage-collected after shutdown.
            self.data_client = None
        self.server = None
        self.channel = None

    def terminate_actor(self, actor: ClientActorHandle,
                        no_restart: bool) -> None:
        """Ask the server to kill ``actor``.

        Raises:
            ValueError: if ``actor`` is not a ClientActorHandle.
            The decoded server-side exception if the RPC fails.
        """
        if not isinstance(actor, ClientActorHandle):
            raise ValueError("ray.kill() only supported for actors. "
                             "Got: {}.".format(type(actor)))
        term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
        term_actor.id = actor.actor_ref.id
        term_actor.no_restart = no_restart
        term = ray_client_pb2.TerminateRequest(actor=term_actor)
        term.client_id = self._client_id
        try:
            self.server.Terminate(term, metadata=self.metadata)
        except grpc.RpcError as e:
            raise decode_exception(e.details())

    def terminate_task(self, obj: ClientObjectRef, force: bool,
                       recursive: bool) -> None:
        """Ask the server to cancel the task producing ``obj``.

        Raises:
            TypeError: if ``obj`` is not a ClientObjectRef.
            The decoded server-side exception if the RPC fails.
        """
        if not isinstance(obj, ClientObjectRef):
            raise TypeError(
                "ray.cancel() only supported for non-actor object refs. "
                f"Got: {type(obj)}.")
        term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
        term_object.id = obj.id
        term_object.force = force
        term_object.recursive = recursive
        term = ray_client_pb2.TerminateRequest(task_object=term_object)
        term.client_id = self._client_id
        try:
            self.server.Terminate(term, metadata=self.metadata)
        except grpc.RpcError as e:
            raise decode_exception(e.details())

    def get_cluster_info(self, type: ray_client_pb2.ClusterInfoType.TypeEnum):
        """Query cluster information of the given ``type`` from the server.

        Returns:
            ``resp.resource_table.table`` when the server answers with a
            resource table, otherwise the JSON-decoded ``resp.json``.
        """
        req = ray_client_pb2.ClusterInfoRequest()
        req.type = type
        resp = self.server.ClusterInfo(req, metadata=self.metadata)
        if resp.WhichOneof("response_type") == "resource_table":
            return resp.resource_table.table
        return json.loads(resp.json)

    def is_initialized(self) -> bool:
        """Return the server's IS_INITIALIZED answer; False once closed."""
        if self.server is not None:
            return self.get_cluster_info(
                ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
        return False
|
|
|
|
|
def remote_decorator(options: Optional[Dict[str, Any]]):
    """Build the decorator applied by ``Worker.remote``.

    Args:
        options: keyword options from ``@ray.remote(...)``, or None for the
            bare ``@ray.remote`` form; forwarded unchanged to the stub.

    Returns:
        A one-argument decorator that wraps a function (plain Python or
        Cython) in a ClientRemoteFunc and a class in a ClientActorClass.
    """

    def decorator(function_or_class) -> ClientStub:
        is_func = (inspect.isfunction(function_or_class)
                   or is_cython(function_or_class))
        if is_func:
            return ClientRemoteFunc(function_or_class, options=options)
        if not inspect.isclass(function_or_class):
            raise TypeError("The @ray.remote decorator must be applied to "
                            "either a function or to a class.")
        return ClientActorClass(function_or_class, options=options)

    return decorator
|
|
|
|
|
|
def make_client_id() -> str:
    """Return a fresh random client id as a 32-character hex string."""
    return uuid.uuid4().hex
|
|
|
|
|
|
def decode_exception(data) -> Exception:
    """Rebuild an exception the server sent as base64 text.

    Args:
        data: standard-alphabet base64 payload, e.g. the ``details()``
            string of a ``grpc.RpcError``.

    Returns:
        The object produced by ``loads_from_server`` on the decoded bytes;
        callers raise it.
    """
    return loads_from_server(base64.standard_b64decode(data))
|