mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 13:15:35 +08:00
[ray_client] Include multiple facets of the Ray API (#12736)
This commit is contained in:
@@ -3,9 +3,8 @@ from traceback import format_exception
|
||||
|
||||
import colorama
|
||||
|
||||
import ray
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.core.generated.common_pb2 import RayException, Language
|
||||
from ray.core.generated.common_pb2 import RayException, Language, PYTHON
|
||||
import setproctitle
|
||||
|
||||
|
||||
@@ -17,7 +16,7 @@ class RayError(Exception):
|
||||
exc_info = (type(self), self, self.__traceback__)
|
||||
formatted_exception_string = "\n".join(format_exception(*exc_info))
|
||||
return RayException(
|
||||
language=ray.Language.PYTHON.value(),
|
||||
language=PYTHON,
|
||||
serialized_exception=pickle.dumps(self),
|
||||
formatted_exception_string=formatted_exception_string
|
||||
).SerializeToString()
|
||||
@@ -26,7 +25,7 @@ class RayError(Exception):
|
||||
def from_bytes(b):
|
||||
ray_exception = RayException()
|
||||
ray_exception.ParseFromString(b)
|
||||
if ray_exception.language == ray.Language.PYTHON.value():
|
||||
if ray_exception.language == PYTHON:
|
||||
return pickle.loads(ray_exception.serialized_exception)
|
||||
else:
|
||||
return CrossLanguageError(ray_exception)
|
||||
@@ -81,6 +80,7 @@ class RayTaskError(RayError):
|
||||
pid=None,
|
||||
ip=None):
|
||||
"""Initialize a RayTaskError."""
|
||||
import ray
|
||||
if proctitle:
|
||||
self.proctitle = proctitle
|
||||
else:
|
||||
|
||||
@@ -11,8 +11,10 @@
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import TYPE_CHECKING, Any, Union, Optional
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientStub
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray._raylet import ObjectRef
|
||||
@@ -29,13 +31,13 @@ class APIImpl(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
    """
    get is the hook stub passed on to replace `ray.get`

    Args:
        vals: [Client]ObjectRef or list of these refs to retrieve.
        timeout: Optional timeout in seconds. The value is forwarded to
            `ray.get`, so it uses that function's unit (not milliseconds).
    """
    pass
|
||||
|
||||
@@ -103,6 +105,39 @@ class APIImpl(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
def kill(self, actor, *, no_restart=True):
    """
    Hook for forcibly stopping an actor that is running in the cluster.

    Concrete API implementations override this; the abstract version does
    nothing and returns None.

    Args:
        actor: The actor to stop.
        no_restart: Keyword-only flag governing whether a restartable actor
            may be restarted after being killed (presumably True forbids
            the restart, as the name suggests -- confirm against `ray.kill`).
    """
    return None
|
||||
|
||||
@abstractmethod
def cancel(self, obj, *, force=False, recursive=True):
    """
    Hook for cancelling a task on the cluster.

    A task that is still pending will not be executed. For a task that is
    already executing, the outcome depends on the ``force`` flag, as per
    `ray.cancel()`.

    Only non-actor tasks can be canceled. Canceled tasks will not be
    retried (max_retries will not be respected).

    Args:
        obj: Object ref returned by the task that should be canceled.
        force (boolean): Whether to force-kill a running task by killing
            the worker that is running the task.
        recursive (boolean): Whether to try to cancel tasks submitted by
            the task specified.
    """
    return None
|
||||
|
||||
|
||||
class ClientAPI(APIImpl):
|
||||
"""
|
||||
@@ -113,8 +148,8 @@ class ClientAPI(APIImpl):
|
||||
def __init__(self, worker):
|
||||
self.worker = worker
|
||||
|
||||
def get(self, vals, *, timeout=None):
    """Fetch values through the client worker.

    Args:
        vals: A client object ref, or a list of them, to retrieve.
        timeout: Optional timeout, passed to the worker as a keyword.

    Returns:
        Whatever `self.worker.get` returns for the request.
    """
    return self.worker.get(vals, timeout=timeout)
|
||||
|
||||
def put(self, *args, **kwargs):
|
||||
return self.worker.put(*args, **kwargs)
|
||||
@@ -131,6 +166,59 @@ class ClientAPI(APIImpl):
|
||||
def close(self) -> None:
|
||||
return self.worker.close()
|
||||
|
||||
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
    """Ask the worker to terminate `actor`.

    Args:
        actor: Client-side handle of the actor to terminate.
        no_restart: Forwarded positionally to `worker.terminate_actor`.

    Returns:
        The result of `worker.terminate_actor`.
    """
    worker = self.worker
    return worker.terminate_actor(actor, no_restart)
|
||||
|
||||
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
    """Ask the worker to cancel the task that produced `obj`.

    Args:
        obj: Client-side object ref returned by the task.
        force: Forwarded positionally to `worker.terminate_task`.
        recursive: Forwarded positionally to `worker.terminate_task`.

    Returns:
        The result of `worker.terminate_task`.
    """
    worker = self.worker
    return worker.terminate_task(obj, force, recursive)
|
||||
|
||||
# Various metadata methods for the client that are defined in the protocol.
|
||||
def is_initialized(self) -> bool:
    """Report whether the client is connected and the server initialized.

    Returns:
        The value produced by `worker.is_initialized()`.
    """
    worker = self.worker
    return worker.is_initialized()
|
||||
|
||||
def nodes(self):
    """Get a list of the nodes in the cluster (for debugging only).

    Returns:
        The worker's cluster-info answer for the NODES request type.
    """
    info_type = ray_client_pb2.ClusterInfoType.NODES
    return self.worker.get_cluster_info(info_type)
|
||||
|
||||
def cluster_resources(self):
    """Get the current total cluster resources.

    The answer can grow stale as nodes are added to or removed from the
    cluster.

    Returns:
        A dictionary mapping resource name to the total quantity of that
        resource in the cluster.
    """
    info_type = ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES
    return self.worker.get_cluster_info(info_type)
|
||||
|
||||
def available_resources(self):
    """Get the current available cluster resources.

    Unlike `cluster_resources`, this reports idle (available) resources
    rather than total resources. The answer can grow stale as tasks start
    and finish.

    Returns:
        A dictionary mapping resource name to the quantity of that
        resource currently available in the cluster.
    """
    info_type = ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES
    return self.worker.get_cluster_info(info_type)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
if not key.startswith("_"):
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -4,10 +4,13 @@ from typing import Any
|
||||
from typing import Dict
|
||||
from ray import cloudpickle
|
||||
|
||||
import base64
|
||||
|
||||
|
||||
class ClientBaseRef:
|
||||
def __init__(self, id, handle=None):
    """Store the reference's identity.

    Args:
        id: Identifier of the remote reference.
        handle: Optional opaque handle accompanying the id; elsewhere in
            this module it is passed to `cloudpickle.loads`.
    """
    self.id = id
    self.handle = handle
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%s)" % (
|
||||
@@ -21,9 +24,14 @@ class ClientBaseRef:
|
||||
def binary(self):
    """Return this reference's `id` attribute unchanged."""
    ref_id = self.id
    return ref_id
|
||||
|
||||
@classmethod
def from_remote_ref(cls, ref: ray_client_pb2.RemoteRef):
    """Alternate constructor: build an instance from a RemoteRef message.

    Args:
        ref: Message whose `id` and `handle` fields are copied over.

    Returns:
        A new `cls` instance carrying the message's id and handle.
    """
    ref_id, ref_handle = ref.id, ref.handle
    return cls(id=ref_id, handle=ref_handle)
|
||||
|
||||
|
||||
class ClientObjectRef(ClientBaseRef):
    """Client-side reference to an object held by the server."""

    def _unpack_ref(self):
        """Unpickle `self.handle` and return the resulting object."""
        packed = self.handle
        return cloudpickle.loads(packed)
|
||||
|
||||
|
||||
class ClientActorRef(ClientBaseRef):
|
||||
@@ -88,7 +96,7 @@ class ClientRemoteFunc(ClientStub):
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.FUNCTION
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.id
|
||||
task.payload_id = self._ref.handle
|
||||
return task
|
||||
|
||||
|
||||
@@ -130,7 +138,7 @@ class ClientActorClass(ClientStub):
|
||||
def remote(self, *args, **kwargs):
    """Instantiate the actor remotely and return a client handle to it.

    Args:
        *args: Positional arguments forwarded to `ray.call_remote`.
        **kwargs: Keyword arguments forwarded to `ray.call_remote`.

    Returns:
        A ClientActorHandle wrapping a ClientActorRef that carries both the
        id and the handle of the returned reference.
    """
    # Actually instantiate the actor
    ref = ray.call_remote(self, *args, **kwargs)
    # Keep the handle: the other client paths in this module send
    # `handle`, not `id`, back to the server.
    return ClientActorHandle(ClientActorRef(ref.id, ref.handle), self)
|
||||
|
||||
def __repr__(self):
    """Return a debug string containing the stub's name and reference."""
    parts = (self._name, self._ref)
    return "ClientRemoteActor(%s, %s)" % parts
|
||||
@@ -146,7 +154,7 @@ class ClientActorClass(ClientStub):
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.ACTOR
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.id
|
||||
task.payload_id = self._ref.handle
|
||||
return task
|
||||
|
||||
|
||||
@@ -174,7 +182,7 @@ class ClientActorHandle(ClientStub):
|
||||
|
||||
def _get_ray_remote_impl(self):
    """Return the real actor handle, unpickling it on first use.

    The result is cached in `self._real_actor_handle`, so the handle is
    unpickled at most once per instance.
    """
    if self._real_actor_handle is None:
        # The pickled actor lives in `handle`; `id` is only an identifier.
        self._real_actor_handle = cloudpickle.loads(self.actor_ref.handle)
    return self._real_actor_handle
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
@@ -190,6 +198,10 @@ class ClientActorHandle(ClientStub):
|
||||
self.actor_class = state["actor_class"]
|
||||
self._real_actor_handle = state["_real_actor_handle"]
|
||||
|
||||
@property
def _actor_id(self):
    """The `id` of the actor reference held by this handle."""
    actor_ref = self.actor_ref
    return actor_ref.id
|
||||
|
||||
def __getattr__(self, key):
    """Resolve an attribute not otherwise found to a ClientRemoteMethod stub.

    Args:
        key: Name of the attribute (method) being looked up.
    """
    method_stub = ClientRemoteMethod(self, key)
    return method_stub
|
||||
|
||||
@@ -244,7 +256,7 @@ class ClientRemoteMethod(ClientStub):
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.METHOD
|
||||
task.name = self.method_name
|
||||
task.payload_id = self.actor_handle.actor_ref.id
|
||||
task.payload_id = self.actor_handle.actor_ref.handle
|
||||
return task
|
||||
|
||||
|
||||
@@ -266,3 +278,13 @@ def convert_to_arg(val):
|
||||
out.local = ray_client_pb2.Arg.Locality.INTERNED
|
||||
out.data = cloudpickle.dumps(val)
|
||||
return out
|
||||
|
||||
|
||||
def encode_exception(exception) -> str:
    """Serialize an exception into a base64 text string.

    The exception is pickled with cloudpickle, then standard-base64
    encoded and decoded to `str`.
    """
    pickled = cloudpickle.dumps(exception)
    encoded = base64.standard_b64encode(pickled)
    return encoded.decode()
|
||||
|
||||
|
||||
def decode_exception(data) -> Exception:
    """Inverse of `encode_exception`: base64-decode, then unpickle.

    NOTE(review): this unpickles whatever bytes it is given; unpickling
    data from an untrusted peer can execute arbitrary code.
    """
    pickled = base64.standard_b64decode(data)
    return cloudpickle.loads(pickled)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
# making into the core-ray module are contained and well-defined.
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import ray
|
||||
@@ -24,8 +25,14 @@ class CoreRayAPI(APIImpl):
|
||||
to core ray when passed client stubs.
|
||||
"""
|
||||
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
    """Run `ray.get`, unpacking client object refs first.

    Args:
        vals: A ref or list of refs. ClientObjectRef instances are
            converted with `_unpack_ref()`; anything else is passed to
            `ray.get` as-is.
        timeout: Optional timeout forwarded to `ray.get`.

    Returns:
        The result of `ray.get`.
    """
    if isinstance(vals, list):
        # Guard the peek at vals[0]: an empty list must fall through to
        # ray.get instead of raising IndexError.
        if vals and isinstance(vals[0], ClientObjectRef):
            return ray.get(
                [val._unpack_ref() for val in vals], timeout=timeout)
    elif isinstance(vals, ClientObjectRef):
        return ray.get(vals._unpack_ref(), timeout=timeout)
    return ray.get(vals, timeout=timeout)
|
||||
|
||||
def put(self, vals: Any, *args,
|
||||
**kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]:
|
||||
@@ -43,6 +50,15 @@ class CoreRayAPI(APIImpl):
|
||||
def close(self) -> None:
    """No-op: this implementation has nothing to close."""
    return
|
||||
|
||||
def kill(self, actor, *, no_restart=True):
    """Delegate to `ray.kill`, forwarding `no_restart` as a keyword."""
    outcome = ray.kill(actor, no_restart=no_restart)
    return outcome
|
||||
|
||||
def cancel(self, obj, *, force=False, recursive=True):
    """Delegate to `ray.cancel`, forwarding `force` and `recursive`."""
    outcome = ray.cancel(obj, force=force, recursive=recursive)
    return outcome
|
||||
|
||||
def is_initialized(self) -> bool:
    """Return the result of `ray.is_initialized()`."""
    state = ray.is_initialized()
    return state
|
||||
|
||||
# Allow for generic fallback to ray.* in remote methods. This allows calls
|
||||
# like ray.nodes() to be run in remote functions even though the client
|
||||
# doesn't currently support them.
|
||||
|
||||
@@ -3,12 +3,15 @@ from concurrent import futures
|
||||
import grpc
|
||||
from ray import cloudpickle
|
||||
import ray
|
||||
import ray.state
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
import time
|
||||
import inspect
|
||||
import json
|
||||
from ray.experimental.client import stash_api_for_tests, _set_server_api
|
||||
from ray.experimental.client.common import convert_from_arg
|
||||
from ray.experimental.client.common import encode_exception
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.server.core_ray_api import RayServerAPI
|
||||
|
||||
@@ -23,19 +26,81 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
self.registered_actor_classes = {}
|
||||
self._test_mode = test_mode
|
||||
|
||||
def ClusterInfo(self, request,
                context=None) -> ray_client_pb2.ClusterInfoResponse:
    """Answer a cluster-info request.

    CLUSTER_RESOURCES and AVAILABLE_RESOURCES are answered with a resource
    table; every other request type is answered with the JSON string built
    by `_return_debug_cluster_info`.

    Args:
        request: Message whose `type` selects the information wanted.
        context: Optional RPC context, passed on to the debug helper.

    Returns:
        A ClusterInfoResponse whose `type` echoes the request's.
    """
    info_types = ray_client_pb2.ClusterInfoType
    reply = ray_client_pb2.ClusterInfoResponse()
    reply.type = request.type
    if request.type == info_types.CLUSTER_RESOURCES:
        raw = ray.cluster_resources()
    elif request.type == info_types.AVAILABLE_RESOURCES:
        raw = ray.available_resources()
    else:
        reply.json = self._return_debug_cluster_info(request, context)
        return reply
    # Normalize resources into floats
    # (the function may return values that are ints)
    as_floats = {name: float(amount) for name, amount in raw.items()}
    reply.resource_table.CopyFrom(
        ray_client_pb2.ClusterInfoResponse.ResourceTable(table=as_floats))
    return reply
|
||||
|
||||
def _return_debug_cluster_info(self, request, context=None) -> str:
    """Build the JSON payload for the debug-style cluster-info types.

    Args:
        request: Message whose `type` is NODES or IS_INITIALIZED.
        context: Unused here.

    Returns:
        `json.dumps` of `ray.nodes()` or `ray.is_initialized()`.

    Raises:
        TypeError: If the request type is neither of the two above.
    """
    wanted = request.type
    if wanted == ray_client_pb2.ClusterInfoType.NODES:
        payload = ray.nodes()
    elif wanted == ray_client_pb2.ClusterInfoType.IS_INITIALIZED:
        payload = ray.is_initialized()
    else:
        raise TypeError("Unsupported cluster info type")
    return json.dumps(payload)
|
||||
|
||||
def Terminate(self, request, context=None):
    """Cancel a task or kill an actor on behalf of the client.

    The `terminate_type` oneof selects the action: `task_object` calls
    `ray.cancel`, `actor` calls `ray.kill`. Exceptions raised by either are
    handed to `return_exception_in_context` rather than propagated.

    Args:
        request: Message carrying the pickled handle and flags.
        context: Optional RPC context used to report failures.

    Returns:
        A TerminateResponse with `ok=True`.

    Raises:
        RuntimeError: If no valid `terminate_type` is set.
    """
    kind = request.WhichOneof("terminate_type")
    if kind not in ("task_object", "actor"):
        raise RuntimeError(
            "Client requested termination without providing a valid terminate_type"
        )
    try:
        if kind == "task_object":
            target = request.task_object
            object_ref = cloudpickle.loads(target.handle)
            ray.cancel(
                object_ref, force=target.force, recursive=target.recursive)
        else:
            target = request.actor
            actor_ref = cloudpickle.loads(target.handle)
            ray.kill(actor_ref, no_restart=target.no_restart)
    except Exception as e:
        return_exception_in_context(e, context)
    return ray_client_pb2.TerminateResponse(ok=True)
|
||||
|
||||
def GetObject(self, request, context=None):
    """Fetch an object for the client and return it pickled.

    Args:
        request: Message whose `handle` is a pickled object ref and whose
            `timeout` is forwarded to `ray.get`.
        context: Optional RPC context used to report a failed `ray.get`.

    Returns:
        GetResponse(valid=False) when the ref is not tracked by this
        servicer or `ray.get` failed; otherwise GetResponse(valid=True)
        with the pickled value in `data`.

    Raises:
        Exception: Whatever `ray.get` raised, when there is no `context`
            to report it through.
    """
    request_ref = cloudpickle.loads(request.handle)
    key = request_ref.binary()
    if key not in self.object_refs:
        return ray_client_pb2.GetResponse(valid=False)
    objectref = self.object_refs[key]
    logger.info("get: %s" % objectref)
    try:
        item = ray.get(objectref, timeout=request.timeout)
    except Exception as e:
        return_exception_in_context(e, context)
        if context is None:
            # Nothing recorded the failure; do not swallow it.
            raise
        # `item` was never bound, so stop here instead of falling through
        # to the serialization below (which raised NameError).
        return ray_client_pb2.GetResponse(valid=False)
    item_ser = cloudpickle.dumps(item)
    return ray_client_pb2.GetResponse(valid=True, data=item_ser)
|
||||
|
||||
def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
    """Store a client-supplied object and return a reference to it.

    Args:
        request: Message whose `data` is the pickled object.
        context: Unused here.

    Returns:
        A PutResponse whose `ref` carries the object's binary id and the
        pickled object ref as its handle.
    """
    obj = cloudpickle.loads(request.data)
    objectref = self._put_and_retain_obj(obj)
    pickled_ref = cloudpickle.dumps(objectref)
    return ray_client_pb2.PutResponse(
        ref=make_remote_ref(objectref.binary(), pickled_ref))
|
||||
|
||||
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
|
||||
objectref = ray.put(obj)
|
||||
@@ -44,14 +109,14 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
return objectref
|
||||
|
||||
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
|
||||
object_refs = [cloudpickle.loads(o) for o in request.object_refs]
|
||||
object_refs = [cloudpickle.loads(o) for o in request.object_handles]
|
||||
num_returns = request.num_returns
|
||||
timeout = request.timeout
|
||||
object_refs_ids = []
|
||||
for object_ref in object_refs:
|
||||
if object_ref.id not in self.object_refs:
|
||||
if object_ref.binary() not in self.object_refs:
|
||||
return ray_client_pb2.WaitResponse(valid=False)
|
||||
object_refs_ids.append(self.object_refs[object_ref.id])
|
||||
object_refs_ids.append(self.object_refs[object_ref.binary()])
|
||||
try:
|
||||
ready_object_refs, remaining_object_refs = ray.wait(
|
||||
object_refs_ids,
|
||||
@@ -63,11 +128,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
logger.info("wait: %s %s" % (str(ready_object_refs),
|
||||
str(remaining_object_refs)))
|
||||
ready_object_ids = [
|
||||
ready_object_ref.binary() for ready_object_ref in ready_object_refs
|
||||
make_remote_ref(
|
||||
id=ready_object_ref.binary(),
|
||||
handle=cloudpickle.dumps(ready_object_ref),
|
||||
) for ready_object_ref in ready_object_refs
|
||||
]
|
||||
remaining_object_ids = [
|
||||
remaining_object_ref.binary()
|
||||
for remaining_object_ref in remaining_object_refs
|
||||
make_remote_ref(
|
||||
id=remaining_object_ref.binary(),
|
||||
handle=cloudpickle.dumps(remaining_object_ref),
|
||||
) for remaining_object_ref in remaining_object_refs
|
||||
]
|
||||
return ray_client_pb2.WaitResponse(
|
||||
valid=True,
|
||||
@@ -103,41 +173,46 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
output = getattr(actor_handle, task.name).remote(*arglist)
|
||||
self.object_refs[output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
pickled_ref = cloudpickle.dumps(output)
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ref=make_remote_ref(output.binary(), pickled_ref))
|
||||
|
||||
def _schedule_actor(self,
                    task: ray_client_pb2.ClientTask,
                    context=None,
                    prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
    """Create an actor for the client and return a ticket referencing it.

    `task.payload_id` is a pickled ref; its `binary()` id keys both
    `self.object_refs` and the `self.registered_actor_classes` cache, so a
    class is fetched and wrapped with `ray.remote` only on first use.

    Args:
        task: The client task describing the actor class and arguments.
        context: Unused here.
        prepared_args: Passed through to `_convert_args`.

    Returns:
        A ClientTaskTicket whose `return_ref` holds the actor id and the
        pickled actor handle.

    Raises:
        Exception: If the referenced payload is not a class.
    """
    with stash_api_for_tests(self._test_mode):
        payload_ref = cloudpickle.loads(task.payload_id)
        class_key = payload_ref.binary()
        if class_key not in self.registered_actor_classes:
            actor_class_ref = self.object_refs[class_key]
            actor_class = ray.get(actor_class_ref)
            if not inspect.isclass(actor_class):
                raise Exception("Attempting to schedule actor that "
                                "isn't a class.")
            reg_class = ray.remote(actor_class)
            self.registered_actor_classes[class_key] = reg_class
        remote_class = self.registered_actor_classes[class_key]
        arglist = _convert_args(task.args, prepared_args)
        actor = remote_class.remote(*arglist)
        actorhandle = cloudpickle.dumps(actor)
        self.actor_refs[actorhandle] = actor
    return ray_client_pb2.ClientTaskTicket(
        return_ref=make_remote_ref(actor._actor_id.binary(), actorhandle))
|
||||
|
||||
def _schedule_function(
|
||||
self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
if task.payload_id not in self.function_refs:
|
||||
funcref = self.object_refs[task.payload_id]
|
||||
payload_ref = cloudpickle.loads(task.payload_id)
|
||||
if payload_ref.binary() not in self.function_refs:
|
||||
funcref = self.object_refs[payload_ref.binary()]
|
||||
func = ray.get(funcref)
|
||||
if not inspect.isfunction(func):
|
||||
raise Exception("Attempting to schedule function that "
|
||||
"isn't a function.")
|
||||
self.function_refs[task.payload_id] = ray.remote(func)
|
||||
remote_func = self.function_refs[task.payload_id]
|
||||
self.function_refs[payload_ref.binary()] = ray.remote(func)
|
||||
remote_func = self.function_refs[payload_ref.binary()]
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
# Prepare call if we're in a test
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
@@ -145,7 +220,9 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
if output.binary() in self.object_refs:
|
||||
raise Exception("already found it")
|
||||
self.object_refs[output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
pickled_output = cloudpickle.dumps(output)
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ref=make_remote_ref(output.binary(), pickled_output))
|
||||
|
||||
|
||||
def _convert_args(arg_list, prepared_args=None):
|
||||
@@ -155,12 +232,25 @@ def _convert_args(arg_list, prepared_args=None):
|
||||
for arg in arg_list:
|
||||
t = convert_from_arg(arg)
|
||||
if isinstance(t, ClientObjectRef):
|
||||
out.append(ray.ObjectRef(t.id))
|
||||
out.append(t._unpack_ref())
|
||||
else:
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
|
||||
def make_remote_ref(id: bytes, handle: bytes) -> ray_client_pb2.RemoteRef:
    """Build a RemoteRef message from an id and a handle.

    Args:
        id: Value for the message's `id` field.
        handle: Value for the message's `handle` field.
    """
    remote_ref = ray_client_pb2.RemoteRef(id=id, handle=handle)
    return remote_ref
|
||||
|
||||
|
||||
def return_exception_in_context(err, context):
    """Record `err` on an RPC context, if one was supplied.

    The encoded exception becomes the context's details and the status
    code is set to INTERNAL. With `context` of None this does nothing.

    Args:
        err: The exception to report.
        context: RPC context, or None.
    """
    if context is None:
        return
    context.set_details(encode_exception(err))
    context.set_code(grpc.StatusCode.INTERNAL)
|
||||
|
||||
|
||||
def serve(connection_str, test_mode=False):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
task_servicer = RayletServicer(test_mode=test_mode)
|
||||
|
||||
@@ -3,19 +3,25 @@ It implements the Ray API functions that are forwarded through grpc calls
|
||||
to the server.
|
||||
"""
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Optional
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.util.inspect import is_cython
|
||||
import grpc
|
||||
|
||||
from ray.exceptions import TaskCancelledError
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.experimental.client.common import convert_to_arg
|
||||
from ray.experimental.client.common import decode_exception
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,6 +41,7 @@ class Worker:
|
||||
metadata: additional metadata passed in the grpc request headers.
|
||||
"""
|
||||
self.metadata = metadata
|
||||
self.channel = None
|
||||
if stub is None:
|
||||
if secure:
|
||||
credentials = grpc.ssl_channel_credentials()
|
||||
@@ -45,28 +52,32 @@ class Worker:
|
||||
else:
|
||||
self.server = stub
|
||||
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
    """Retrieve the values behind one client object ref or a list of them.

    Args:
        vals: A ClientObjectRef, or a list of objects exposing `.handle`.
        timeout: Optional timeout; None is sent to the server as 0.

    Returns:
        The single value for a single ref, otherwise a list of values in
        the same order as `vals`.

    Raises:
        Exception: If `vals` is neither a list nor a ClientObjectRef.
    """
    to_get = []
    single = False
    if isinstance(vals, list):
        to_get = [x.handle for x in vals]
    elif isinstance(vals, ClientObjectRef):
        to_get = [vals.handle]
        single = True
    else:
        raise Exception("Can't get something that's not a "
                        "list of IDs or just an ID: %s" % type(vals))
    if timeout is None:
        timeout = 0
    out = [self._get(x, timeout) for x in to_get]
    if single:
        out = out[0]
    return out
|
||||
|
||||
def _get(self, handle: bytes, timeout: float):
    """Fetch and unpickle one object from the server.

    Args:
        handle: Handle of the object to fetch.
        timeout: Timeout placed in the GetRequest.

    Returns:
        The unpickled object.

    Raises:
        Exception: The exception decoded from the RPC error details when
            the call fails with `grpc.RpcError`.
        TaskCancelledError: If the server marks the response as not valid.
    """
    req = ray_client_pb2.GetRequest(handle=handle, timeout=timeout)
    try:
        data = self.server.GetObject(req, metadata=self.metadata)
    except grpc.RpcError as e:
        raise decode_exception(e.details())
    if not data.valid:
        raise TaskCancelledError(handle)
    return cloudpickle.loads(data.data)
|
||||
|
||||
def put(self, vals):
|
||||
@@ -87,7 +98,7 @@ class Worker:
|
||||
data = cloudpickle.dumps(val)
|
||||
req = ray_client_pb2.PutRequest(data=data)
|
||||
resp = self.server.PutObject(req, metadata=self.metadata)
|
||||
return ClientObjectRef(resp.id)
|
||||
return ClientObjectRef.from_remote_ref(resp.ref)
|
||||
|
||||
def wait(self,
|
||||
object_refs: List[ClientObjectRef],
|
||||
@@ -99,8 +110,8 @@ class Worker:
|
||||
for ref in object_refs:
|
||||
assert isinstance(ref, ClientObjectRef)
|
||||
data = {
|
||||
"object_refs": [
|
||||
cloudpickle.dumps(object_ref) for object_ref in object_refs
|
||||
"object_handles": [
|
||||
object_ref.handle for object_ref in object_refs
|
||||
],
|
||||
"num_returns": num_returns,
|
||||
"timeout": timeout if timeout else -1
|
||||
@@ -111,10 +122,12 @@ class Worker:
|
||||
# TODO(ameer): improve error/exceptions messages.
|
||||
raise Exception("Client Wait request failed. Reference invalid?")
|
||||
client_ready_object_ids = [
|
||||
ClientObjectRef(id) for id in resp.ready_object_ids
|
||||
ClientObjectRef.from_remote_ref(ref)
|
||||
for ref in resp.ready_object_ids
|
||||
]
|
||||
client_remaining_object_ids = [
|
||||
ClientObjectRef(id) for id in resp.remaining_object_ids
|
||||
ClientObjectRef.from_remote_ref(ref)
|
||||
for ref in resp.remaining_object_ids
|
||||
]
|
||||
|
||||
return (client_ready_object_ids, client_remaining_object_ids)
|
||||
@@ -138,7 +151,53 @@ class Worker:
|
||||
task.args.append(pb_arg)
|
||||
logging.debug("Scheduling %s" % task)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
return ClientObjectRef(ticket.return_id)
|
||||
return ClientObjectRef.from_remote_ref(ticket.return_ref)
|
||||
|
||||
def close(self):
    """Drop the server stub and close the channel if one was opened.

    `self.channel` may be None, so the close is guarded instead of being
    called unconditionally.
    """
    self.server = None
    if self.channel:
        self.channel.close()
|
||||
|
||||
def terminate_actor(self, actor: ClientActorHandle,
                    no_restart: bool) -> None:
    """Ask the server to kill an actor.

    Args:
        actor: Client handle of the actor; its `actor_ref.handle` is sent.
        no_restart: Copied into the terminate request.

    Raises:
        ValueError: If `actor` is not a ClientActorHandle.
        Exception: The exception decoded from the RPC error details when
            the call fails with `grpc.RpcError`.
    """
    if not isinstance(actor, ClientActorHandle):
        raise ValueError("ray.kill() only supported for actors. "
                         "Got: {}.".format(type(actor)))
    term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
    term_actor.handle = actor.actor_ref.handle
    term_actor.no_restart = no_restart
    term = ray_client_pb2.TerminateRequest(actor=term_actor)
    try:
        # Send the request metadata like the other RPCs of this worker.
        self.server.Terminate(term, metadata=self.metadata)
    except grpc.RpcError as e:
        raise decode_exception(e.details())
|
||||
|
||||
def terminate_task(self, obj: ClientObjectRef, force: bool,
                   recursive: bool) -> None:
    """Ask the server to cancel the task behind an object ref.

    Args:
        obj: Client object ref returned by the task; its `handle` is sent.
        force: Copied into the terminate request.
        recursive: Copied into the terminate request.

    Raises:
        TypeError: If `obj` is not a ClientObjectRef.
        Exception: The exception decoded from the RPC error details when
            the call fails with `grpc.RpcError`.
    """
    if not isinstance(obj, ClientObjectRef):
        raise TypeError(
            "ray.cancel() only supported for non-actor object refs. "
            f"Got: {type(obj)}.")
    term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
    term_object.handle = obj.handle
    term_object.force = force
    term_object.recursive = recursive
    term = ray_client_pb2.TerminateRequest(task_object=term_object)
    try:
        # Send the request metadata like the other RPCs of this worker.
        self.server.Terminate(term, metadata=self.metadata)
    except grpc.RpcError as e:
        raise decode_exception(e.details())
|
||||
|
||||
def get_cluster_info(self, type: ray_client_pb2.ClusterInfoType.TypeEnum):
    """Query the server for one kind of cluster information.

    Args:
        type: The ClusterInfoType value to request.

    Returns:
        The response's resource table when the `response_type` oneof is
        `resource_table`; otherwise the response's `json` field decoded
        with `json.loads`.
    """
    req = ray_client_pb2.ClusterInfoRequest()
    req.type = type
    # Send the request metadata like the other RPCs of this worker.
    resp = self.server.ClusterInfo(req, metadata=self.metadata)
    if resp.WhichOneof("response_type") == "resource_table":
        return resp.resource_table.table
    return json.loads(resp.json)
|
||||
|
||||
def is_initialized(self) -> bool:
    """Report whether the server side is initialized.

    Returns:
        False when this worker has no server stub; otherwise the answer to
        an IS_INITIALIZED cluster-info request.
    """
    if self.server is None:
        return False
    return self.get_cluster_info(
        ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
|
||||
|
||||
@@ -95,6 +95,8 @@ py_test_module_list(
|
||||
"test_dask_callback.py",
|
||||
"test_debug_tools.py",
|
||||
"test_experimental_client.py",
|
||||
"test_experimental_client_metadata.py",
|
||||
"test_experimental_client_terminate.py",
|
||||
"test_job.py",
|
||||
"test_memstat.py",
|
||||
"test_metrics_agent.py",
|
||||
|
||||
@@ -10,10 +10,12 @@ from ray.experimental.client.common import ClientObjectRef
|
||||
def ray_start_client_server():
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
yield ray
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
reset_api()
|
||||
try:
|
||||
yield ray
|
||||
finally:
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
reset_api()
|
||||
|
||||
|
||||
def test_real_ray_fallback(ray_start_regular_shared):
|
||||
@@ -35,9 +37,6 @@ def test_real_ray_fallback(ray_start_regular_shared):
|
||||
nodes = ray.get(get_nodes.remote())
|
||||
assert len(nodes) == 1, nodes
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
print(ray.nodes())
|
||||
|
||||
|
||||
def test_nested_function(ray_start_regular_shared):
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
from ray.tests.test_experimental_client import ray_start_client_server
|
||||
|
||||
|
||||
def test_get_ray_metadata(ray_start_regular_shared):
    """
    Exercise the ClusterInfo client data pathway and API surface:
    is_initialized, nodes, cluster_resources and available_resources.
    """
    with ray_start_client_server() as ray:
        node_ip = ray_start_regular_shared["node_ip_address"]

        assert ray.is_initialized()

        node_list = ray.nodes()
        assert len(node_list) == 1, node_list
        assert node_list[0]["NodeManagerAddress"] == node_ip

        # Resource key under which the node itself is listed.
        node_resource_key = "node:" + node_ip

        total = ray.cluster_resources()
        available = ray.available_resources()

        assert total["CPU"] == 1.0
        assert node_resource_key in total
        assert node_resource_key in available
|
||||
@@ -0,0 +1,99 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from ray.tests.test_experimental_client import ray_start_client_server
|
||||
from ray.test_utils import wait_for_condition
|
||||
from ray.exceptions import TaskCancelledError
|
||||
from ray.exceptions import RayTaskError
|
||||
from ray.exceptions import WorkerCrashedError
|
||||
from ray.exceptions import ObjectLostError
|
||||
from ray.exceptions import GetTimeoutError
|
||||
|
||||
|
||||
def valid_exceptions(use_force):
    """Return the exception types a cancelled task's `get` may raise.

    Args:
        use_force: Whether the cancel was forced; forced cancels also
            allow WorkerCrashedError and ObjectLostError.

    Returns:
        A tuple of exception classes, usable with `pytest.raises`.
    """
    always = (RayTaskError, TaskCancelledError)
    if not use_force:
        return always
    return always + (WorkerCrashedError, ObjectLostError)
|
||||
|
||||
|
||||
def _all_actors_dead(ray):
    """Build a zero-argument predicate: are all known actors DEAD?

    Args:
        ray: Unused; the check always goes through the real `ray` module,
            which is imported locally under another name.

    Returns:
        A callable returning True when every entry of `ray.actors()` has
        State equal to ActorTableData.DEAD.
    """
    import ray as real_ray

    def _every_actor_is_dead():
        for info in list(real_ray.actors().values()):
            if info["State"] != real_ray.gcs_utils.ActorTableData.DEAD:
                return False
        return True

    return _every_actor_is_dead
|
||||
|
||||
|
||||
def test_kill_actor_immediately_after_creation(ray_start_regular):
    """Kill two actors right after creating them and wait until all die."""
    with ray_start_client_server() as ray:

        @ray.remote
        class A:
            pass

        # Create both actors first, then kill them in creation order.
        actors = [A.remote(), A.remote()]
        for actor in actors:
            ray.kill(actor)

        wait_for_condition(_all_actors_dead(ray), timeout=10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
def test_cancel_chain(ray_start_regular, use_force):
    """Cancel tasks in a dependency chain and check what `get` raises.

    Each chain is four `wait_for` tasks, the first blocked on a
    SignalActor. Cancelling the head of the first chain must fail all four
    links. Cancelling the third link of the second chain must fail only
    links three and four, while one and two stay pending until signalled.
    """
    with ray_start_client_server() as ray:

        @ray.remote
        class SignalActor:
            def __init__(self):
                self.ready_event = asyncio.Event()

            def send(self, clear=False):
                self.ready_event.set()
                if clear:
                    self.ready_event.clear()

            async def wait(self, should_wait=True):
                if should_wait:
                    await self.ready_event.wait()

        signaler = SignalActor.remote()

        @ray.remote
        def wait_for(t):
            return ray.get(t[0])

        def build_chain(gate):
            # Same submission order as four consecutive assignments.
            first = wait_for.remote([gate.wait.remote()])
            second = wait_for.remote([first])
            third = wait_for.remote([second])
            fourth = wait_for.remote([third])
            return first, second, third, fourth

        expected = valid_exceptions(use_force)

        obj1, obj2, obj3, obj4 = build_chain(signaler)

        ready, _ = ray.wait([obj1], timeout=.1)
        assert len(ready) == 0
        ray.cancel(obj1, force=use_force)
        for ob in (obj1, obj2, obj3, obj4):
            with pytest.raises(expected):
                ray.get(ob)

        signaler2 = SignalActor.remote()
        obj1, obj2, obj3, obj4 = build_chain(signaler2)

        ready, _ = ray.wait([obj3], timeout=.1)
        assert len(ready) == 0
        ray.cancel(obj3, force=use_force)
        for ob in (obj3, obj4):
            with pytest.raises(expected):
                ray.get(ob)

        # The two links ahead of the cancelled one are still blocked.
        with pytest.raises(GetTimeoutError):
            ray.get(obj1, timeout=.1)

        with pytest.raises(GetTimeoutError):
            ray.get(obj2, timeout=.1)

        signaler2.send.remote()
        ray.get(obj1)
|
||||
Reference in New Issue
Block a user