diff --git a/python/ray/actor.py b/python/ray/actor.py index 3917042f2..7c2420802 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -109,10 +109,36 @@ def method(*args, **kwargs): # Create objects to wrap method invocations. This is done so that we can # invoke methods with actor.method.remote() instead of actor.method(). class ActorMethod(object): - def __init__(self, actor, method_name, num_return_vals): + """A class used to invoke an actor method. + + Note: This class is instantiated only while the actor method is being + invoked (so that it doesn't keep a reference to the actor handle and + prevent it from going out of scope). + + Attributes: + _actor: A handle to the actor. + _method_name: The name of the actor method. + _num_return_vals: The default number of return values that the method + invocation should return. + _decorator: An optional decorator that should be applied to the actor + method invocation (as opposed to the actor method execution) before + invoking the method. The decorator must return a function that + takes in two arguments ("args" and "kwargs"). In most cases, it + should call the function that was passed into the decorator and + return the resulting ObjectIDs. For an example, see + "test_decorated_method" in "python/ray/tests/test_actor.py". + """ + + def __init__(self, actor, method_name, num_return_vals, decorator=None): self._actor = actor self._method_name = method_name self._num_return_vals = num_return_vals + # This is a decorator that is used to wrap the function invocation (as + # opposed to the function execution). The decorator must return a + # function that takes in two arguments ("args" and "kwargs"). In most + # cases, it should call the function that was passed into the decorator + # and return the resulting ObjectIDs. + self._decorator = decorator def __call__(self, *args, **kwargs): raise Exception("Actor methods cannot be called directly. Instead " @@ -131,11 +157,18 @@ class ActorMethod(object): if num_return_vals is None: num_return_vals = self._num_return_vals - return self._actor._actor_method_call( - self._method_name, - args=args, - kwargs=kwargs, - num_return_vals=num_return_vals) + def invocation(args, kwargs): + return self._actor._actor_method_call( + self._method_name, + args=args, + kwargs=kwargs, + num_return_vals=num_return_vals) + + # Apply the decorator if there is one. + if self._decorator is not None: + invocation = self._decorator(invocation) + + return invocation(args, kwargs) class ActorClass(object): @@ -157,6 +190,10 @@ class ActorClass(object): _exported: True if the actor class has been exported and false otherwise. _actor_methods: The actor methods. + _method_decorators: Optional decorators that should be applied to the + method invocation function before invoking the actor methods. These + can be set by attaching the attribute + "__ray_invocation_decorator__" to the actor method. _method_signatures: The signatures of the methods. _actor_method_names: The names of the actor methods. _actor_method_num_return_vals: The default number of return values for @@ -196,6 +233,7 @@ class ActorClass(object): # Extract the signatures of each of the methods. This will be used # to catch some errors if the methods are called with inappropriate # arguments. + self._method_decorators = {} self._method_signatures = {} self._actor_method_num_return_vals = {} for method_name, method in self._actor_methods: @@ -214,6 +252,10 @@ class ActorClass(object): self._actor_method_num_return_vals[method_name] = ( ray_constants.DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS) + if hasattr(method, "__ray_invocation_decorator__"): + self._method_decorators[method_name] = ( + method.__ray_invocation_decorator__) + def __call__(self, *args, **kwargs): raise Exception("Actors methods cannot be instantiated directly. " "Instead of running '{}()', try '{}.remote()'.".format( @@ -337,9 +379,9 @@ class ActorClass(object): actor_handle = ActorHandle( actor_id, self._modified_class.__module__, self._class_name, - actor_cursor, self._actor_method_names, self._method_signatures, - self._actor_method_num_return_vals, actor_cursor, actor_method_cpu, - worker.task_driver_id) + actor_cursor, self._actor_method_names, self._method_decorators, + self._method_signatures, self._actor_method_num_return_vals, + actor_cursor, actor_method_cpu, worker.task_driver_id) # We increment the actor counter by 1 to account for the actor creation # task. actor_handle._ray_actor_counter += 1 @@ -381,6 +423,10 @@ class ActorHandle(object): _ray_actor_counter: The number of actor method invocations that we've called so far. _ray_actor_method_names: The names of the actor methods. + _ray_method_decorators: Optional decorators for the function + invocation. This can be used to change the behavior on the + invocation side, whereas a regular decorator can be used to change + the behavior on the execution side. _ray_method_signatures: The signatures of the actor methods. _ray_method_num_return_vals: The default number of return values for each method. @@ -407,6 +453,7 @@ class ActorHandle(object): class_name, actor_cursor, actor_method_names, + method_decorators, method_signatures, method_num_return_vals, actor_creation_dummy_object_id, @@ -428,6 +475,7 @@ class ActorHandle(object): self._ray_actor_cursor = actor_cursor self._ray_actor_counter = 0 self._ray_actor_method_names = actor_method_names + self._ray_method_decorators = method_decorators self._ray_method_signatures = method_signatures self._ray_method_num_return_vals = method_num_return_vals self._ray_class_name = class_name @@ -530,8 +578,11 @@ class ActorHandle(object): # this was causing cyclic references which were prevent # object deallocation from behaving in a predictable # manner. - return ActorMethod(self, attr, - self._ray_method_num_return_vals[attr]) + return ActorMethod( + self, + attr, + self._ray_method_num_return_vals[attr], + decorator=self._ray_method_decorators.get(attr)) except AttributeError: pass @@ -600,6 +651,7 @@ class ActorHandle(object): "class_name": self._ray_class_name, "actor_cursor": self._ray_actor_cursor, "actor_method_names": self._ray_actor_method_names, + "method_decorators": self._ray_method_decorators, "method_signatures": self._ray_method_signatures, "method_num_return_vals": self._ray_method_num_return_vals, # Actors in local mode don't have dummy objects. @@ -662,6 +714,7 @@ class ActorHandle(object): state["class_name"], state["actor_cursor"], state["actor_method_names"], + state["method_decorators"], state["method_signatures"], state["method_num_return_vals"], state["actor_creation_dummy_object_id"], diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index c26dd5884..3bc3fc2bd 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -35,6 +35,13 @@ class RemoteFunction(object): of this remote function. _max_calls: The number of times a worker can execute this function before executing. + _decorator: An optional decorator that should be applied to the remote + function invocation (as opposed to the function execution) before + invoking the function. The decorator must return a function that + takes in two arguments ("args" and "kwargs"). In most cases, it + should call the function that was passed into the decorator and + return the resulting ObjectIDs. For an example, see + "test_decorated_function" in "python/ray/tests/test_basic.py". _function_signature: The function signature. """ @@ -52,6 +59,8 @@ class RemoteFunction(object): num_return_vals is None else num_return_vals) self._max_calls = (DEFAULT_REMOTE_FUNCTION_MAX_CALLS if max_calls is None else max_calls) + self._decorator = getattr(function, "__ray_invocation_decorator__", + None) ray.signature.check_signature_supported(self._function) self._function_signature = ray.signature.extract_signature( @@ -108,8 +117,6 @@ class RemoteFunction(object): kwargs = {} if kwargs is None else kwargs args = [] if args is None else args - args = ray.signature.extend_args(self._function_signature, args, - kwargs) if num_return_vals is None: num_return_vals = self._num_return_vals @@ -117,19 +124,29 @@ class RemoteFunction(object): resources = ray.utils.resources_from_resource_arguments( self._num_cpus, self._num_gpus, self._resources, num_cpus, num_gpus, resources) - if worker.mode == ray.worker.LOCAL_MODE: - # In LOCAL_MODE, remote calls simply execute the function. - # We copy the arguments to prevent the function call from - # mutating them and to match the usual behavior of - # immutable remote objects. - result = self._function(*copy.deepcopy(args)) - return result - object_ids = worker.submit_task( - self._function_descriptor, - args, - num_return_vals=num_return_vals, - resources=resources) - if len(object_ids) == 1: - return object_ids[0] - elif len(object_ids) > 1: - return object_ids + + def invocation(args, kwargs): + args = ray.signature.extend_args(self._function_signature, args, + kwargs) + + if worker.mode == ray.worker.LOCAL_MODE: + # In LOCAL_MODE, remote calls simply execute the function. + # We copy the arguments to prevent the function call from + # mutating them and to match the usual behavior of + # immutable remote objects. + result = self._function(*copy.deepcopy(args)) + return result + object_ids = worker.submit_task( + self._function_descriptor, + args, + num_return_vals=num_return_vals, + resources=resources) + if len(object_ids) == 1: + return object_ids[0] + elif len(object_ids) > 1: + return object_ids + + if self._decorator is not None: + invocation = self._decorator(invocation) + + return invocation(args, kwargs) diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index b072d6926..d7da081fd 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -2576,3 +2576,35 @@ def test_init_exception_in_checkpointable_actor(ray_start_regular, errors = relevant_errors(ray_constants.TASK_PUSH_ERROR) assert len(errors) == 2 assert error_message1 in errors[1]["message"] + + +def test_decorated_method(ray_start_regular): + def method_invocation_decorator(f): + def new_f_invocation(args, kwargs): + # Split one argument into two. Return th kwargs without passing + # them into the actor. + return f([args[0], args[0]], {}), kwargs + + return new_f_invocation + + def method_execution_decorator(f): + def new_f_execution(self, b, c): + # Turn two arguments into one. + return f(self, b + c) + + new_f_execution.__ray_invocation_decorator__ = ( + method_invocation_decorator) + return new_f_execution + + @ray.remote + class Actor(object): + @method_execution_decorator + def decorated_method(self, x): + return x + 1 + + a = Actor.remote() + + object_id, extra = a.decorated_method.remote(3, kwarg=3) + assert isinstance(object_id, ray.ObjectID) + assert extra == {"kwarg": 3} + assert ray.get(object_id) == 7 # 2 * 3 + 1 diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 5fa7f3dbf..3f8c7cb2b 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2892,3 +2892,32 @@ def test_redis_lru_with_set(ray_start_object_store_memory): # Now evict the object from the object store. ray.put(x) # This should not crash. + + +def test_decorated_function(ray_start_regular): + def function_invocation_decorator(f): + def new_f(args, kwargs): + # Reverse the arguments. + return f(args[::-1], {"d": 5}), kwargs + + return new_f + + def f(a, b, c, d=None): + return a, b, c, d + + f.__ray_invocation_decorator__ = function_invocation_decorator + f = ray.remote(f) + + result_id, kwargs = f.remote(1, 2, 3, d=4) + assert kwargs == {"d": 4} + assert ray.get(result_id) == (3, 2, 1, 5) + + +def test_get_postprocess(ray_start_regular): + def get_postprocessor(object_ids, values): + return [value for value in values if value > 0] + + ray.worker.global_worker._post_get_hooks.append(get_postprocessor) + + assert ray.get( + [ray.put(i) for i in [0, 1, 3, 5, -1, -3, 4]]) == [1, 3, 5, 4] diff --git a/python/ray/worker.py b/python/ray/worker.py index b8fd28b30..c3a9f53ce 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -156,6 +156,9 @@ class Worker(object): # increment every time when `ray.shutdown` is called. self._session_index = 0 self._current_task = None + # Functions to run to process the values returned by ray.get. Each + # postprocessor must take two arguments ("object_ids", and "values"). + self._post_get_hooks = [] @property def connected(self): @@ -1455,7 +1458,7 @@ def init(redis_address=None, return _global_node.address_info -# Functions to run as callback after a successful ray init +# Functions to run as callback after a successful ray init. _post_init_hooks = [] @@ -1493,7 +1496,10 @@ def shutdown(exiting_interpreter=False): _global_node.kill_all_processes(check_alive=False, allow_graceful=True) _global_node = None + # TODO(rkn): Instead of manually reseting some of the worker fields, we + # should simply set "global_worker" to equal "None" or something like that. global_worker.set_mode(None) + global_worker._post_get_hooks = [] atexit.register(shutdown, True) @@ -2175,23 +2181,29 @@ def get(object_ids): # In LOCAL_MODE, ray.get is the identity operation (the input will # actually be a value not an objectid). return object_ids + + is_individual_id = isinstance(object_ids, ray.ObjectID) + if is_individual_id: + object_ids = [object_ids] + + if not isinstance(object_ids, list): + raise ValueError("'object_ids' must either by an object ID " + "or a list of object IDs.") + global last_task_error_raise_time - if isinstance(object_ids, list): - values = worker.get_object(object_ids) - for i, value in enumerate(values): - if isinstance(value, RayError): - last_task_error_raise_time = time.time() - raise value - return values - else: - value = worker.get_object([object_ids])[0] + values = worker.get_object(object_ids) + for i, value in enumerate(values): if isinstance(value, RayError): - # If the result is a RayError, then the task that created - # this object failed, and we should propagate the error message - # here. last_task_error_raise_time = time.time() raise value - return value + + # Run post processors. + for post_processor in worker._post_get_hooks: + values = post_processor(object_ids, values) + + if is_individual_id: + values = values[0] + return values def put(value):