mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:02:56 +08:00
Fix passing object ids in local mode (#6170)
This commit is contained in:
@@ -5,6 +5,7 @@ from __future__ import print_function
|
||||
import copy
|
||||
import traceback
|
||||
|
||||
import ray
|
||||
from ray import ObjectID
|
||||
from ray.utils import format_error_message
|
||||
from ray.exceptions import RayTaskError
|
||||
@@ -20,7 +21,18 @@ class LocalModeObjectID(ObjectID):
|
||||
it equates to the object not existing in the object store. This is
|
||||
necessary because None is a valid object value.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __copy__(self):
|
||||
new = LocalModeObjectID(self.binary())
|
||||
if hasattr(self, "value"):
|
||||
new.value = self.value
|
||||
return new
|
||||
|
||||
def __deepcopy__(self, memo=None):
|
||||
new = LocalModeObjectID(self.binary())
|
||||
if hasattr(self, "value"):
|
||||
new.value = self.value
|
||||
return new
|
||||
|
||||
|
||||
class LocalModeManager(object):
|
||||
@@ -49,23 +61,37 @@ class LocalModeManager(object):
|
||||
Returns:
|
||||
LocalModeObjectIDs corresponding to the function return values.
|
||||
"""
|
||||
object_ids = [
|
||||
return_ids = [
|
||||
LocalModeObjectID.from_random() for _ in range(num_return_vals)
|
||||
]
|
||||
try:
|
||||
results = function(*copy.deepcopy(args), **copy.deepcopy(kwargs))
|
||||
if num_return_vals == 1:
|
||||
object_ids[0].value = results
|
||||
new_args = []
|
||||
for i, arg in enumerate(args):
|
||||
if isinstance(arg, ObjectID):
|
||||
new_args.append(ray.get(arg))
|
||||
else:
|
||||
for object_id, result in zip(object_ids, results):
|
||||
new_args.append(copy.deepcopy(arg))
|
||||
|
||||
new_kwargs = {}
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, ObjectID):
|
||||
new_kwargs[k] = ray.get(v)
|
||||
else:
|
||||
new_kwargs[k] = copy.deepcopy(v)
|
||||
|
||||
try:
|
||||
results = function(*new_args, **new_kwargs)
|
||||
if num_return_vals == 1:
|
||||
return_ids[0].value = results
|
||||
else:
|
||||
for object_id, result in zip(return_ids, results):
|
||||
object_id.value = result
|
||||
except Exception as e:
|
||||
backtrace = format_error_message(traceback.format_exc())
|
||||
task_error = RayTaskError(function_name, backtrace, e.__class__)
|
||||
for object_id in object_ids:
|
||||
for object_id in return_ids:
|
||||
object_id.value = task_error
|
||||
|
||||
return object_ids
|
||||
return return_ids
|
||||
|
||||
def put_object(self, value):
|
||||
"""Store an object in the emulated object store.
|
||||
|
||||
@@ -2227,6 +2227,17 @@ def test_local_mode(shutdown_only):
|
||||
_ = RemoteActor2.remote()
|
||||
assert ray.get(actor1.function1.remote()) == 0
|
||||
|
||||
# Test passing ObjectIDs.
|
||||
@ray.remote
|
||||
def direct_dep(input):
|
||||
return input
|
||||
|
||||
@ray.remote
|
||||
def indirect_dep(input):
|
||||
return ray.get(direct_dep.remote(input[0]))
|
||||
|
||||
assert ray.get(indirect_dep.remote(["hello"])) == "hello"
|
||||
|
||||
|
||||
def test_resource_constraints(shutdown_only):
|
||||
num_workers = 20
|
||||
|
||||
Reference in New Issue
Block a user