diff --git a/python/ray/local_mode_manager.py b/python/ray/local_mode_manager.py index 3d9d70a85..11d74784a 100644 --- a/python/ray/local_mode_manager.py +++ b/python/ray/local_mode_manager.py @@ -5,6 +5,7 @@ from __future__ import print_function import copy import traceback +import ray from ray import ObjectID from ray.utils import format_error_message from ray.exceptions import RayTaskError @@ -20,7 +21,18 @@ class LocalModeObjectID(ObjectID): it equates to the object not existing in the object store. This is necessary because None is a valid object value. """ - pass + + def __copy__(self): + new = LocalModeObjectID(self.binary()) + if hasattr(self, "value"): + new.value = self.value + return new + + def __deepcopy__(self, memo=None): + new = LocalModeObjectID(self.binary()) + if hasattr(self, "value"): + new.value = self.value + return new class LocalModeManager(object): @@ -49,23 +61,37 @@ class LocalModeManager(object): Returns: LocalModeObjectIDs corresponding to the function return values. """ - object_ids = [ + return_ids = [ LocalModeObjectID.from_random() for _ in range(num_return_vals) ] - try: - results = function(*copy.deepcopy(args), **copy.deepcopy(kwargs)) - if num_return_vals == 1: - object_ids[0].value = results + new_args = [] + for i, arg in enumerate(args): + if isinstance(arg, ObjectID): + new_args.append(ray.get(arg)) else: - for object_id, result in zip(object_ids, results): + new_args.append(copy.deepcopy(arg)) + + new_kwargs = {} + for k, v in kwargs.items(): + if isinstance(v, ObjectID): + new_kwargs[k] = ray.get(v) + else: + new_kwargs[k] = copy.deepcopy(v) + + try: + results = function(*new_args, **new_kwargs) + if num_return_vals == 1: + return_ids[0].value = results + else: + for object_id, result in zip(return_ids, results): object_id.value = result except Exception as e: backtrace = format_error_message(traceback.format_exc()) task_error = RayTaskError(function_name, backtrace, e.__class__) - for object_id in object_ids: + for object_id in return_ids: object_id.value = task_error - return object_ids + return return_ids def put_object(self, value): """Store an object in the emulated object store. diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 6e8cadd5f..d229d1767 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2227,6 +2227,17 @@ def test_local_mode(shutdown_only): _ = RemoteActor2.remote() assert ray.get(actor1.function1.remote()) == 0 + # Test passing ObjectIDs. + @ray.remote + def direct_dep(input): + return input + + @ray.remote + def indirect_dep(input): + return ray.get(direct_dep.remote(input[0])) + + assert ray.get(indirect_dep.remote(["hello"])) == "hello" + def test_resource_constraints(shutdown_only): num_workers = 20