Fix passing object ids in local mode (#6170)

This commit is contained in:
Edward Oakes
2019-11-15 15:46:39 -08:00
committed by GitHub
parent 33040d734f
commit dee696577f
2 changed files with 46 additions and 9 deletions
+35 -9
View File
@@ -5,6 +5,7 @@ from __future__ import print_function
import copy
import traceback
import ray
from ray import ObjectID
from ray.utils import format_error_message
from ray.exceptions import RayTaskError
@@ -20,7 +21,18 @@ class LocalModeObjectID(ObjectID):
it equates to the object not existing in the object store. This is
necessary because None is a valid object value.
"""
pass
def __copy__(self):
new = LocalModeObjectID(self.binary())
if hasattr(self, "value"):
new.value = self.value
return new
def __deepcopy__(self, memo=None):
new = LocalModeObjectID(self.binary())
if hasattr(self, "value"):
new.value = self.value
return new
class LocalModeManager(object):
@@ -49,23 +61,37 @@ class LocalModeManager(object):
Returns:
LocalModeObjectIDs corresponding to the function return values.
"""
object_ids = [
return_ids = [
LocalModeObjectID.from_random() for _ in range(num_return_vals)
]
try:
results = function(*copy.deepcopy(args), **copy.deepcopy(kwargs))
if num_return_vals == 1:
object_ids[0].value = results
new_args = []
for i, arg in enumerate(args):
if isinstance(arg, ObjectID):
new_args.append(ray.get(arg))
else:
for object_id, result in zip(object_ids, results):
new_args.append(copy.deepcopy(arg))
new_kwargs = {}
for k, v in kwargs.items():
if isinstance(v, ObjectID):
new_kwargs[k] = ray.get(v)
else:
new_kwargs[k] = copy.deepcopy(v)
try:
results = function(*new_args, **new_kwargs)
if num_return_vals == 1:
return_ids[0].value = results
else:
for object_id, result in zip(return_ids, results):
object_id.value = result
except Exception as e:
backtrace = format_error_message(traceback.format_exc())
task_error = RayTaskError(function_name, backtrace, e.__class__)
for object_id in object_ids:
for object_id in return_ids:
object_id.value = task_error
return object_ids
return return_ids
def put_object(self, value):
"""Store an object in the emulated object store.
+11
View File
@@ -2227,6 +2227,17 @@ def test_local_mode(shutdown_only):
_ = RemoteActor2.remote()
assert ray.get(actor1.function1.remote()) == 0
# Test passing ObjectIDs.
@ray.remote
def direct_dep(input):
return input
@ray.remote
def indirect_dep(input):
return ray.get(direct_dep.remote(input[0]))
assert ray.get(indirect_dep.remote(["hello"])) == "hello"
def test_resource_constraints(shutdown_only):
num_workers = 20