diff --git a/python/ray/experimental/client/client_pickler.py b/python/ray/experimental/client/client_pickler.py index 2496199ea..7ba83b3ac 100644 --- a/python/ray/experimental/client/client_pickler.py +++ b/python/ray/experimental/client/client_pickler.py @@ -28,6 +28,7 @@ import sys from typing import NamedTuple from typing import Any +from typing import Dict from typing import Optional from ray.experimental.client import RayAPIStub @@ -37,6 +38,7 @@ from ray.experimental.client.common import ClientActorRef from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientRemoteFunc from ray.experimental.client.common import ClientRemoteMethod +from ray.experimental.client.common import OptionWrapper from ray.experimental.client.common import SelfReferenceSentinel import ray.core.generated.ray_client_pb2 as ray_client_pb2 @@ -52,7 +54,8 @@ else: # the data for an exectuion, with no arguments. Combine the two? PickleStub = NamedTuple("PickleStub", [("type", str), ("client_id", str), ("ref_id", bytes), - ("name", Optional[str])]) + ("name", Optional[str]), + ("baseline_options", Optional[Dict])]) class ClientPickler(cloudpickle.CloudPickler): @@ -67,6 +70,7 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=b"", name=None, + baseline_options=None, ) elif isinstance(obj, ClientObjectRef): return PickleStub( @@ -74,6 +78,7 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=obj.id, name=None, + baseline_options=None, ) elif isinstance(obj, ClientActorHandle): return PickleStub( @@ -81,6 +86,7 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=obj._actor_id, name=None, + baseline_options=None, ) elif isinstance(obj, ClientRemoteFunc): # TODO(barakmich): This is going to have trouble with mutually @@ -95,12 +101,14 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=b"", name=None, + 
baseline_options=None, ) return PickleStub( type="RemoteFunc", client_id=self.client_id, ref_id=obj._ref.id, name=None, + baseline_options=obj._options, ) elif isinstance(obj, ClientActorClass): # TODO(barakmich): Mutual recursion, as above. @@ -112,12 +120,14 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=b"", name=None, + baseline_options=None, ) return PickleStub( type="RemoteActor", client_id=self.client_id, ref_id=obj._ref.id, name=None, + baseline_options=obj._options, ) elif isinstance(obj, ClientRemoteMethod): return PickleStub( @@ -125,7 +135,11 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=obj.actor_handle.actor_ref.id, name=obj.method_name, + baseline_options=None, ) + elif isinstance(obj, OptionWrapper): + raise NotImplementedError( + "Sending a partial option is unimplemented") return None diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index 60901c661..49eee05d6 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -1,9 +1,21 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.experimental.client import ray +from ray.experimental.client.options import validate_options + +import json +import threading +from typing import Any +from typing import List +from typing import Dict +from typing import Optional +from typing import Union class ClientBaseRef: def __init__(self, id: bytes): + self.id = None + if not isinstance(id, bytes): + raise TypeError("ClientRefs must be created with bytes IDs") self.id: bytes = id ray.call_retain(id) @@ -23,7 +35,7 @@ class ClientBaseRef: return hash(self.id) def __del__(self): - if ray.is_connected(): + if ray.is_connected() and self.id is not None: ray.call_release(self.id) @@ -52,33 +64,42 @@ class ClientRemoteFunc(ClientStub): _ref: The ClientObjectRef of the pickled code of the function, _func """ - def __init__(self, f): + def 
__init__(self, f, options=None): + self._lock = threading.Lock() self._func = f self._name = f.__name__ self._ref = None + self._options = validate_options(options) def __call__(self, *args, **kwargs): raise TypeError(f"Remote function cannot be called directly. " "Use {self._name}.remote method instead") def remote(self, *args, **kwargs): - return ClientObjectRef(ray.call_remote(self, *args, **kwargs)) + return return_refs(ray.call_remote(self, *args, **kwargs)) + + def options(self, **kwargs): + return OptionWrapper(self, kwargs) + + def _remote(self, args=[], kwargs={}, **option_args): + return self.options(**option_args).remote(*args, **kwargs) def __repr__(self): return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref) def _ensure_ref(self): - if self._ref is None: - # While calling ray.put() on our function, if - # our function is recursive, it will attempt to - # encode the ClientRemoteFunc -- itself -- and - # infinitely recurse on _ensure_ref. - # - # So we set the state of the reference to be an - # in-progress self reference value, which - # the encoding can detect and handle correctly. - self._ref = SelfReferenceSentinel() - self._ref = ray.put(self._func) + with self._lock: + if self._ref is None: + # While calling ray.put() on our function, if + # our function is recursive, it will attempt to + # encode the ClientRemoteFunc -- itself -- and + # infinitely recurse on _ensure_ref. + # + # So we set the state of the reference to be an + # in-progress self reference value, which + # the encoding can detect and handle correctly. 
+ self._ref = SelfReferenceSentinel() + self._ref = ray.put(self._func) def _prepare_client_task(self) -> ray_client_pb2.ClientTask: self._ensure_ref() @@ -86,6 +107,7 @@ class ClientRemoteFunc(ClientStub): task.type = ray_client_pb2.ClientTask.FUNCTION task.name = self._name task.payload_id = self._ref.id + set_task_options(task, self._options, "baseline_options") return task @@ -100,10 +122,11 @@ class ClientActorClass(ClientStub): _ref: The ClientObjectRef of the pickled `actor_cls` """ - def __init__(self, actor_cls): + def __init__(self, actor_cls, options=None): self.actor_cls = actor_cls self._name = actor_cls.__name__ self._ref = None + self._options = validate_options(options) def __call__(self, *args, **kwargs): raise TypeError(f"Remote actor cannot be instantiated directly. " @@ -119,8 +142,15 @@ class ClientActorClass(ClientStub): def remote(self, *args, **kwargs) -> "ClientActorHandle": # Actually instantiate the actor - ref_id = ray.call_remote(self, *args, **kwargs) - return ClientActorHandle(ClientActorRef(ref_id), self) + ref_ids = ray.call_remote(self, *args, **kwargs) + assert len(ref_ids) == 1 + return ClientActorHandle(ClientActorRef(ref_ids[0]), self) + + def options(self, **kwargs): + return ActorOptionWrapper(self, kwargs) + + def _remote(self, args=[], kwargs={}, **option_args): + return self.options(**option_args).remote(*args, **kwargs) def __repr__(self): return "ClientActorClass(%s, %s)" % (self._name, self._ref) @@ -136,6 +166,7 @@ class ClientActorClass(ClientStub): task.type = ray_client_pb2.ClientTask.ACTOR task.name = self._name task.payload_id = self._ref.id + set_task_options(task, self._options, "baseline_options") return task @@ -160,7 +191,8 @@ class ClientActorHandle(ClientStub): self.actor_ref = actor_ref def __del__(self) -> None: - ray.call_release(self.actor_ref.id) + if ray.is_connected(): + ray.call_release(self.actor_ref.id) @property def _actor_id(self): @@ -193,12 +225,18 @@ class ClientRemoteMethod(ClientStub): 
f"Use {self._name}.remote() instead") def remote(self, *args, **kwargs): - return ClientObjectRef(ray.call_remote(self, *args, **kwargs)) + return return_refs(ray.call_remote(self, *args, **kwargs)) def __repr__(self): return "ClientRemoteMethod(%s, %s)" % (self.method_name, self.actor_handle) + def options(self, **kwargs): + return OptionWrapper(self, kwargs) + + def _remote(self, args=[], kwargs={}, **option_args): + return self.options(**option_args).remote(*args, **kwargs) + def _prepare_client_task(self) -> ray_client_pb2.ClientTask: task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.METHOD @@ -207,6 +245,49 @@ class ClientRemoteMethod(ClientStub): return task +class OptionWrapper: + def __init__(self, stub: ClientStub, options: Optional[Dict[str, Any]]): + self.remote_stub = stub + self.options = validate_options(options) + + def remote(self, *args, **kwargs): + return return_refs(ray.call_remote(self, *args, **kwargs)) + + def __getattr__(self, key): + return getattr(self.remote_stub, key) + + def _prepare_client_task(self): + task = self.remote_stub._prepare_client_task() + set_task_options(task, self.options) + return task + + +class ActorOptionWrapper(OptionWrapper): + def remote(self, *args, **kwargs): + ref_ids = ray.call_remote(self, *args, **kwargs) + assert len(ref_ids) == 1 + return ClientActorHandle(ClientActorRef(ref_ids[0]), self) + + +def set_task_options(task: ray_client_pb2.ClientTask, + options: Optional[Dict[str, Any]], + field: str = "options") -> None: + if options is None: + task.ClearField(field) + return + options_str = json.dumps(options) + getattr(task, field).json_options = options_str + + +def return_refs(ids: List[bytes] + ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]: + if len(ids) == 1: + return ClientObjectRef(ids[0]) + if len(ids) == 0: + return None + return [ClientObjectRef(id) for id in ids] + + class DataEncodingSentinel: def __repr__(self) -> str: return self.__class__.__name__ diff --git 
a/python/ray/experimental/client/options.py b/python/ray/experimental/client/options.py new file mode 100644 index 000000000..79727b126 --- /dev/null +++ b/python/ray/experimental/client/options.py @@ -0,0 +1,54 @@ +from typing import Any +from typing import Dict +from typing import Optional + +options = { + "num_returns": (int, lambda x: x >= 0, + "The keyword 'num_returns' only accepts 0 " + "or a positive integer"), + "num_cpus": (), + "num_gpus": (), + "resources": (), + "accelerator_type": (), + "max_calls": (int, lambda x: x >= 0, + "The keyword 'max_calls' only accepts 0 " + "or a positive integer"), + "max_restarts": (int, lambda x: x >= -1, + "The keyword 'max_restarts' only accepts -1, 0 " + "or a positive integer"), + "max_task_retries": (int, lambda x: x >= -1, + "The keyword 'max_task_retries' only accepts -1, 0 " + "or a positive integer"), + "max_retries": (int, lambda x: x >= -1, + "The keyword 'max_retries' only accepts 0, -1 " + "or a positive integer"), + "max_concurrency": (), + "name": (), + "lifetime": (), + "memory": (), + "object_store_memory": (), + "placement_group": (), + "placement_group_bundle_index": (), + "placement_group_capture_child_tasks": (), + "override_environment_variables": (), +} + + +def validate_options( + kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + if kwargs_dict is None: + return None + if len(kwargs_dict) == 0: + return None + out = {} + for k, v in kwargs_dict.items(): + if k not in options.keys(): + raise TypeError(f"Invalid option passed to remote(): {k}") + validator = options[k] + if len(validator) != 0: + if not isinstance(v, validator[0]): + raise ValueError(validator[2]) + if not validator[1](v): + raise ValueError(validator[2]) + out[k] = v + return out diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 2841384d8..5f86ddee2 100644 --- a/python/ray/experimental/client/server/server.py +++ 
b/python/ray/experimental/client/server/server.py @@ -4,8 +4,10 @@ import grpc import base64 from collections import defaultdict +from typing import Any from typing import Dict from typing import Set +from typing import Optional from ray import cloudpickle import ray @@ -187,9 +189,11 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): ready_object_refs, remaining_object_refs = ray.wait( object_refs, num_returns=num_returns, - timeout=timeout if timeout != -1 else None) - except Exception: + timeout=timeout if timeout != -1 else None, + ) + except Exception as e: # TODO(ameer): improve exception messages. + logger.error(f"Exception {e}") return ray_client_pb2.WaitResponse(valid=False) logger.debug("wait: %s %s" % (str(ready_object_refs), str(remaining_object_refs))) @@ -206,9 +210,10 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): remaining_object_ids=remaining_object_ids) def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket: - logger.info("schedule: %s %s" % - (task.name, - ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))) + logger.debug( + "schedule: %s %s" % (task.name, + ray_client_pb2.ClientTask.RemoteExecType.Name( + task.type))) with stash_api_for_tests(self._test_mode): try: if task.type == ray_client_pb2.ClientTask.FUNCTION: @@ -226,6 +231,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): return result except Exception as e: logger.error(f"Caught schedule exception {e}") + raise e return ray_client_pb2.ClientTaskTicket( valid=False, error=cloudpickle.dumps(e)) @@ -236,34 +242,44 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): raise Exception( "Can't run an actor the server doesn't have a handle for") arglist, kwargs = self._convert_args(task.args, task.kwargs) - output = getattr(actor_handle, task.name).remote(*arglist, **kwargs) - self.object_refs[task.client_id][output.binary()] = output - return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) 
+ method = getattr(actor_handle, task.name) + opts = decode_options(task.options) + if opts is not None: + method = method.options(**opts) + output = method.remote(*arglist, **kwargs) + ids = self.unify_and_track_outputs(output, task.client_id) + return ray_client_pb2.ClientTaskTicket(return_ids=ids) def _schedule_actor(self, task: ray_client_pb2.ClientTask, context=None) -> ray_client_pb2.ClientTaskTicket: - remote_class = self.lookup_or_register_actor(task.payload_id, - task.client_id) + remote_class = self.lookup_or_register_actor( + task.payload_id, task.client_id, + decode_options(task.baseline_options)) arglist, kwargs = self._convert_args(task.args, task.kwargs) + opts = decode_options(task.options) + if opts is not None: + remote_class = remote_class.options(**opts) with current_remote(remote_class): actor = remote_class.remote(*arglist, **kwargs) self.actor_refs[actor._actor_id.binary()] = actor self.actor_owners[task.client_id].add(actor._actor_id.binary()) return ray_client_pb2.ClientTaskTicket( - return_id=actor._actor_id.binary()) + return_ids=[actor._actor_id.binary()]) def _schedule_function(self, task: ray_client_pb2.ClientTask, context=None) -> ray_client_pb2.ClientTaskTicket: - remote_func = self.lookup_or_register_func(task.payload_id, - task.client_id) + remote_func = self.lookup_or_register_func( + task.payload_id, task.client_id, + decode_options(task.baseline_options)) arglist, kwargs = self._convert_args(task.args, task.kwargs) + opts = decode_options(task.options) + if opts is not None: + remote_func = remote_func.options(**opts) with current_remote(remote_func): output = remote_func.remote(*arglist, **kwargs) - if output.binary() in self.object_refs[task.client_id]: - raise Exception("already found it") - self.object_refs[task.client_id][output.binary()] = output - return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) + ids = self.unify_and_track_outputs(output, task.client_id) + return 
ray_client_pb2.ClientTaskTicket(return_ids=ids) def _convert_args(self, arg_list, kwarg_map): argout = [] @@ -275,28 +291,50 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): kwargout[k] = convert_from_arg(kwarg_map[k], self) return argout, kwargout - def lookup_or_register_func(self, id: bytes, client_id: str - ) -> ray.remote_function.RemoteFunction: + def lookup_or_register_func( + self, id: bytes, client_id: str, + options: Optional[Dict]) -> ray.remote_function.RemoteFunction: if id not in self.function_refs: funcref = self.object_refs[client_id][id] func = ray.get(funcref) if not inspect.isfunction(func): raise Exception("Attempting to register function that " "isn't a function.") - self.function_refs[id] = ray.remote(func) + if options is None or len(options) == 0: + self.function_refs[id] = ray.remote(func) + else: + self.function_refs[id] = ray.remote(**options)(func) return self.function_refs[id] - def lookup_or_register_actor(self, id: bytes, client_id: str): + def lookup_or_register_actor(self, id: bytes, client_id: str, + options: Optional[Dict]): if id not in self.registered_actor_classes: actor_class_ref = self.object_refs[client_id][id] actor_class = ray.get(actor_class_ref) if not inspect.isclass(actor_class): raise Exception("Attempting to schedule actor that " "isn't a class.") - reg_class = ray.remote(actor_class) + if options is None or len(options) == 0: + reg_class = ray.remote(actor_class) + else: + reg_class = ray.remote(**options)(actor_class) self.registered_actor_classes[id] = reg_class + return self.registered_actor_classes[id] + def unify_and_track_outputs(self, output, client_id): + if output is None: + outputs = [] + elif isinstance(output, list): + outputs = output + else: + outputs = [output] + for out in outputs: + if out.binary() in self.object_refs[client_id]: + logger.warning(f"Already saw object_ref {out}") + self.object_refs[client_id][out.binary()] = out + return [out.binary() for out in outputs] + def 
return_exception_in_context(err, context): if context is not None: @@ -309,6 +347,15 @@ def encode_exception(exception) -> str: return base64.standard_b64encode(data).decode() +def decode_options( + options: ray_client_pb2.TaskOptions) -> Optional[Dict[str, Any]]: + if options.json_options == "": + return None + opts = json.loads(options.json_options) + assert isinstance(opts, dict) + return opts + + def serve(connection_str, test_mode=False): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) task_servicer = RayletServicer(test_mode=test_mode) diff --git a/python/ray/experimental/client/server/server_pickler.py b/python/ray/experimental/client/server/server_pickler.py index c3cd161bd..10da70cc1 100644 --- a/python/ray/experimental/client/server/server_pickler.py +++ b/python/ray/experimental/client/server/server_pickler.py @@ -56,6 +56,7 @@ class ServerPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=obj_id, name=None, + baseline_options=None, ) elif isinstance(obj, ray.actor.ActorHandle): actor_id = obj._actor_id.binary() @@ -69,6 +70,7 @@ class ServerPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=obj._actor_id.binary(), name=None, + baseline_options=None, ) return None @@ -89,13 +91,13 @@ class ClientUnpickler(pickle.Unpickler): elif pid.type == "RemoteFuncSelfReference": return ServerSelfReferenceSentinel() elif pid.type == "RemoteFunc": - return self.server.lookup_or_register_func(pid.ref_id, - pid.client_id) + return self.server.lookup_or_register_func( + pid.ref_id, pid.client_id, pid.baseline_options) elif pid.type == "RemoteActorSelfReference": return ServerSelfReferenceSentinel() elif pid.type == "RemoteActor": return self.server.lookup_or_register_actor( - pid.ref_id, pid.client_id) + pid.ref_id, pid.client_id, pid.baseline_options) elif pid.type == "RemoteMethod": actor = self.server.actor_refs[pid.ref_id] return getattr(actor, pid.name) diff --git a/python/ray/experimental/client/worker.py 
b/python/ray/experimental/client/worker.py index 6bfab6b75..d2ba52d62 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -21,12 +21,13 @@ import ray.cloudpickle as cloudpickle import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.experimental.client.client_pickler import convert_to_arg -from ray.experimental.client.client_pickler import loads_from_server from ray.experimental.client.client_pickler import dumps_from_client -from ray.experimental.client.common import ClientObjectRef +from ray.experimental.client.client_pickler import loads_from_server from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientActorHandle +from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientRemoteFunc +from ray.experimental.client.common import ClientStub from ray.experimental.client.dataclient import DataClient logger = logging.getLogger(__name__) @@ -80,7 +81,9 @@ class Worker: except grpc.RpcError as e: raise e.details() if not data.valid: - raise cloudpickle.loads(data.error) + err = cloudpickle.loads(data.error) + logger.error(err) + raise err return loads_from_server(data.data) def put(self, vals): @@ -98,6 +101,13 @@ class Worker: return out def _put(self, val): + if isinstance(val, ClientObjectRef): + raise TypeError( + "Calling 'put' on an ObjectRef is not allowed " + "(similarly, returning an ObjectRef from a remote " + "function is not allowed). 
If you really want to " + "do this, you can wrap the ObjectRef in a list and " + "call 'put' on it (or return it).") data = dumps_from_client(val, self._client_id) req = ray_client_pb2.PutRequest(data=data) resp = self.data_client.PutObject(req) @@ -107,7 +117,8 @@ class Worker: object_refs: List[ClientObjectRef], *, num_returns: int = 1, - timeout: float = None + timeout: float = None, + fetch_local: bool = True ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]: if not isinstance(object_refs, list): raise TypeError("wait() expected a list of ClientObjectRef, " @@ -136,19 +147,22 @@ class Worker: return (client_ready_object_ids, client_remaining_object_ids) - def remote(self, function_or_class, *args, **kwargs): - # TODO(barakmich): Arguments to ray.remote - # get captured here. - if (inspect.isfunction(function_or_class) - or is_cython(function_or_class)): - return ClientRemoteFunc(function_or_class) - elif inspect.isclass(function_or_class): - return ClientActorClass(function_or_class) - else: - raise TypeError("The @ray.remote decorator must be applied to " - "either a function or to a class.") + def remote(self, *args, **kwargs): + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # This is the case where the decorator is just @ray.remote. 
+ return remote_decorator(options=None)(args[0]) + error_string = ("The @ray.remote decorator must be applied either " + "with no arguments and no parentheses, for example " + "'@ray.remote', or it must be applied using some of " + "the arguments 'num_returns', 'num_cpus', 'num_gpus', " + "'memory', 'object_store_memory', 'resources', " + "'max_calls', or 'max_restarts', like " + "'@ray.remote(num_returns=2, " + "resources={\"CustomResource\": 1})'.") + assert len(args) == 0 and len(kwargs) > 0, error_string + return remote_decorator(options=kwargs) - def call_remote(self, instance, *args, **kwargs) -> bytes: + def call_remote(self, instance, *args, **kwargs) -> List[bytes]: task = instance._prepare_client_task() for arg in args: pb_arg = convert_to_arg(arg, self._client_id) @@ -160,10 +174,10 @@ class Worker: try: ticket = self.server.Schedule(task, metadata=self.metadata) except grpc.RpcError as e: - raise e.details() + raise decode_exception(e.details()) if not ticket.valid: raise cloudpickle.loads(ticket.error) - return ticket.return_id + return ticket.return_ids def call_release(self, id: bytes) -> None: self.reference_count[id] -= 1 @@ -234,6 +248,20 @@ class Worker: return False +def remote_decorator(options: Optional[Dict[str, Any]]): + def decorator(function_or_class) -> ClientStub: + if (inspect.isfunction(function_or_class) + or is_cython(function_or_class)): + return ClientRemoteFunc(function_or_class, options=options) + elif inspect.isclass(function_or_class): + return ClientActorClass(function_or_class, options=options) + else: + raise TypeError("The @ray.remote decorator must be applied to " + "either a function or to a class.") + + return decorator + + def make_client_id() -> str: id = uuid.uuid4() return id.hex diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index e88986475..7e552e616 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -158,6 +158,7 @@ py_test( py_test_module_list( files = [ "test_actor.py", + 
"test_advanced.py", "test_basic.py", "test_basic_2.py", ], diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index 1e761762e..3ba2ed7eb 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -25,7 +25,9 @@ else: import setproctitle # noqa -@pytest.mark.skipif(client_test_enabled(), reason="test setup order") +@pytest.mark.skipif( + client_test_enabled(), + reason="defining early, no ray package injection yet") def test_caching_actors(shutdown_only): # Test defining actors before ray.init() has been called. @@ -564,7 +566,6 @@ def test_actor_static_attributes(ray_start_regular_shared): assert ray.get(t.g.remote()) == 3 -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_decorator_args(ray_start_regular_shared): # This is an invalid way of using the actor decorator. with pytest.raises(Exception): @@ -655,7 +656,7 @@ def test_actor_inheritance(ray_start_regular_shared): pass -@pytest.mark.skipif(client_test_enabled(), reason="remote args") +@pytest.mark.skipif(client_test_enabled(), reason="ray.method unimplemented") def test_multiple_return_values(ray_start_regular_shared): @ray.remote class Foo: @@ -689,7 +690,6 @@ def test_multiple_return_values(ray_start_regular_shared): assert ray.get([id3a, id3b, id3c]) == [1, 2, 3] -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_options_num_returns(ray_start_regular_shared): @ray.remote class Foo: diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py index 08dd168fa..ea2a6c693 100644 --- a/python/ray/tests/test_advanced.py +++ b/python/ray/tests/test_advanced.py @@ -10,16 +10,22 @@ import time import numpy as np import pytest -import ray import ray.cluster_utils import ray.test_utils +from ray.test_utils import client_test_enabled from ray.test_utils import RayTestTimeoutException +if client_test_enabled(): + from ray.experimental.client import ray +else: + import ray + logger = 
logging.getLogger(__name__) # issue https://github.com/ray-project/ray/issues/7105 +@pytest.mark.skipif(client_test_enabled(), reason="message size") def test_internal_free(shutdown_only): ray.init(num_cpus=1) @@ -60,14 +66,14 @@ def test_multiple_waits_and_gets(shutdown_only): return 1 @ray.remote - def g(l): - # The argument l should be a list containing one object ref. - ray.wait([l[0]]) + def g(input_list): + # The argument input_list should be a list containing one object ref. + ray.wait([input_list[0]]) @ray.remote - def h(l): - # The argument l should be a list containing one object ref. - ray.get(l[0]) + def h(input_list): + # The argument input_list should be a list containing one object ref. + ray.get(input_list[0]) # Make sure that multiple wait requests involving the same object ref # all return. @@ -80,6 +86,7 @@ def test_multiple_waits_and_gets(shutdown_only): ray.get([h.remote([x]), h.remote([x])]) +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_caching_functions_to_run(shutdown_only): # Test that we export functions to run on all workers before the driver # is connected. 
@@ -125,6 +132,7 @@ def test_caching_functions_to_run(shutdown_only): ray.worker.global_worker.run_function_on_all_workers(f) +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_running_function_on_all_workers(ray_start_regular): def f(worker_info): sys.path.append("fake_directory") @@ -152,6 +160,7 @@ def test_running_function_on_all_workers(ray_start_regular): assert "fake_directory" not in ray.get(get_path2.remote()) +@pytest.mark.skipif(client_test_enabled(), reason="ray.timeline") def test_profiling_api(ray_start_2_cpus): @ray.remote def f(): @@ -482,6 +491,7 @@ def test_multithreading(ray_start_2_cpus): ray.get(actor.join.remote()) == "ok" +@pytest.mark.skipif(client_test_enabled(), reason="message size") def test_wait_makes_object_local(ray_start_cluster): cluster = ray_start_cluster cluster.add_node(num_cpus=0) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 709b467e6..38330645b 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) # https://github.com/ray-project/ray/issues/6662 -@pytest.mark.skipif(client_test_enabled(), reason="internal api") +@pytest.mark.skipif(client_test_enabled(), reason="interferes with grpc") def test_ignore_http_proxy(shutdown_only): ray.init(num_cpus=1) os.environ["http_proxy"] = "http://example.com" @@ -55,14 +55,12 @@ def test_grpc_message_size(shutdown_only): # https://github.com/ray-project/ray/issues/7287 -@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_omp_threads_set(shutdown_only): ray.init(num_cpus=1) # Should have been auto set by ray init. 
assert os.environ["OMP_NUM_THREADS"] == "1" -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_submit_api(shutdown_only): ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) @@ -121,7 +119,6 @@ def test_submit_api(shutdown_only): assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_invalid_arguments(shutdown_only): ray.init(num_cpus=2) @@ -176,7 +173,6 @@ def test_invalid_arguments(shutdown_only): x = 1 -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_many_fractional_resources(shutdown_only): ray.init(num_cpus=2, num_gpus=2, resources={"Custom": 2}) @@ -244,7 +240,6 @@ def test_many_fractional_resources(shutdown_only): assert False, "Did not get correct available resources." -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_background_tasks_with_max_calls(shutdown_only): ray.init(num_cpus=2) @@ -360,8 +355,9 @@ def test_function_descriptor(): assert d.get(python_descriptor2) == 123 -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_ray_options(shutdown_only): + ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2}) + @ray.remote( num_cpus=2, num_gpus=3, memory=150 * 2**20, resources={"custom1": 1}) def foo(): @@ -370,8 +366,6 @@ def test_ray_options(shutdown_only): time.sleep(0.1) return ray.available_resources() - ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2}) - without_options = ray.get(foo.remote()) with_options = ray.get( foo.options( diff --git a/python/ray/tests/test_basic_2.py b/python/ray/tests/test_basic_2.py index 25688a6f7..cd8114aa8 100644 --- a/python/ray/tests/test_basic_2.py +++ b/python/ray/tests/test_basic_2.py @@ -537,7 +537,6 @@ def test_actor_recursive(ray_start_regular_shared): assert result == [x * 2 for x in range(100)] -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_actor_concurrent(ray_start_regular_shared): 
@ray.remote class Batcher: diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index ea4939738..cbd6679dd 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -35,6 +35,18 @@ message Arg { Type type = 4; } +// A message representing the valid options to modify a task execution +// +// TODO(barakmich): In the longer term, if everything were a client, +// this message could be the actual standard for which options are +// allowed in the API. Today, however, it's a bit flexible and defined in the +// Python code. So for now, it's a stand-in message with a json field, but +// this is forwards-compatible with deprecating that field and instituting +// strongly defined and typed fields, without migrating the original ClientTask. +message TaskOptions { + string json_options = 1; +} + // Represents one unit of work to be executed by the server. message ClientTask { enum RemoteExecType { @@ -45,8 +57,8 @@ message ClientTask { } // Which type of work this request represents. RemoteExecType type = 1; - // A name parameter, if the payload can be called in more than one way (like a method on - // a payload object). + // A name parameter, if the payload can be called in more than one way + // (like a method on a payload object). string name = 2; // A reference to the payload. bytes payload_id = 3; @@ -54,16 +66,20 @@ message ClientTask { repeated Arg args = 4; // Keyword parameters to pass to this call. map<string, Arg> kwargs = 5; - // The ID of the client namespace associated with the Datapath stream making this - // request. + // The ID of the client namespace associated with the Datapath stream + // making this request. string client_id = 6; + // Options for modifying the remote task execution environment. + TaskOptions options = 7; + // Options passed to create the default remote task execution environment. + TaskOptions baseline_options = 8; } message ClientTaskTicket { // Was the task successful? 
bool valid = 1; - // A reference to the returned value from the execution. - bytes return_id = 2; + // A reference to the returned values from the execution. + repeated bytes return_ids = 2; // If unsuccessful, an encoding of the error. bytes error = 3; }