mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:22:56 +08:00
Handle exchange of direct call objects between tasks and actors (#6147)
This commit is contained in:
@@ -1015,6 +1015,12 @@ cdef class CoreWorker:
|
||||
# Note: faster to not release GIL for short-running op.
|
||||
self.core_worker.get().RemoveObjectIDReference(c_object_id)
|
||||
|
||||
def promote_object_to_plasma(self, ObjectID object_id):
|
||||
cdef:
|
||||
CObjectID c_object_id = object_id.native()
|
||||
self.core_worker.get().PromoteObjectToPlasma(c_object_id)
|
||||
return object_id.with_plasma_transport_type()
|
||||
|
||||
# TODO: handle noreturn better
|
||||
cdef store_task_outputs(
|
||||
self, worker, outputs, const c_vector[CObjectID] return_ids,
|
||||
|
||||
@@ -104,6 +104,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
*bytes)
|
||||
void AddObjectIDReference(const CObjectID &object_id)
|
||||
void RemoveObjectIDReference(const CObjectID &object_id)
|
||||
void PromoteObjectToPlasma(const CObjectID &object_id)
|
||||
|
||||
CRayStatus SetClientOptions(c_string client_name, int64_t limit)
|
||||
CRayStatus Put(const CRayObject &object, CObjectID *object_id)
|
||||
|
||||
@@ -95,8 +95,10 @@ cdef class TaskSpec:
|
||||
:self.task_spec.get().ArgMetadataSize(i)]
|
||||
if metadata == RAW_BUFFER_METADATA:
|
||||
obj = data
|
||||
else:
|
||||
elif metadata == PICKLE_BUFFER_METADATA:
|
||||
obj = pickle.loads(data)
|
||||
else:
|
||||
obj = data
|
||||
arg_list.append(obj)
|
||||
elif lang == <int32_t>LANGUAGE_JAVA:
|
||||
arg_list = num_args * ["<java-argument>"]
|
||||
|
||||
@@ -150,6 +150,8 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
|
||||
|
||||
c_bool IsDirectCallType()
|
||||
|
||||
CObjectID WithPlasmaTransportType()
|
||||
|
||||
int64_t ObjectIndex() const
|
||||
|
||||
CTaskID TaskId() const
|
||||
|
||||
@@ -179,6 +179,9 @@ cdef class ObjectID(BaseID):
|
||||
def is_direct_call_type(self):
|
||||
return self.data.IsDirectCallType()
|
||||
|
||||
def with_plasma_transport_type(self):
|
||||
return ObjectID(self.data.WithPlasmaTransportType().Binary())
|
||||
|
||||
def is_nil(self):
|
||||
return self.data.IsNil()
|
||||
|
||||
|
||||
@@ -158,9 +158,7 @@ class SerializationContext(object):
|
||||
|
||||
def id_serializer(obj):
|
||||
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
|
||||
raise NotImplementedError(
|
||||
"Objects produced by direct actor calls cannot be "
|
||||
"passed to other tasks as arguments.")
|
||||
obj = self.worker.core_worker.promote_object_to_plasma(obj)
|
||||
return pickle.dumps(obj)
|
||||
|
||||
def id_deserializer(serialized_obj):
|
||||
@@ -191,9 +189,7 @@ class SerializationContext(object):
|
||||
|
||||
def id_serializer(obj):
|
||||
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
|
||||
raise NotImplementedError(
|
||||
"Objects produced by direct actor calls cannot be "
|
||||
"passed to other tasks as arguments.")
|
||||
obj = self.worker.core_worker.promote_object_to_plasma(obj)
|
||||
return obj.__reduce__()
|
||||
|
||||
def id_deserializer(serialized_obj):
|
||||
|
||||
@@ -1218,6 +1218,71 @@ def test_direct_call_simple(ray_start_regular):
|
||||
range(1, 101))
|
||||
|
||||
|
||||
def test_direct_call_matrix(shutdown_only):
|
||||
ray.init(object_store_memory=1000 * 1024 * 1024)
|
||||
|
||||
@ray.remote
|
||||
class Actor(object):
|
||||
def small_value(self):
|
||||
return 0
|
||||
|
||||
def large_value(self):
|
||||
return np.zeros(10 * 1024 * 1024)
|
||||
|
||||
def echo(self, x):
|
||||
if isinstance(x, list):
|
||||
x = ray.get(x[0])
|
||||
return x
|
||||
|
||||
@ray.remote
|
||||
def small_value():
|
||||
return 0
|
||||
|
||||
@ray.remote
|
||||
def large_value():
|
||||
return np.zeros(10 * 1024 * 1024)
|
||||
|
||||
@ray.remote
|
||||
def echo(x):
|
||||
if isinstance(x, list):
|
||||
x = ray.get(x[0])
|
||||
return x
|
||||
|
||||
def check(source_actor, dest_actor, is_large, out_of_band):
|
||||
print("CHECKING", "actor" if source_actor else "task", "to", "actor"
|
||||
if dest_actor else "task", "large_object"
|
||||
if is_large else "small_object", "out_of_band"
|
||||
if out_of_band else "in_band")
|
||||
if source_actor:
|
||||
a = Actor.options(is_direct_call=True).remote()
|
||||
if is_large:
|
||||
x_id = a.large_value.remote()
|
||||
else:
|
||||
x_id = a.small_value.remote()
|
||||
else:
|
||||
if is_large:
|
||||
x_id = large_value.options(is_direct_call=True).remote()
|
||||
else:
|
||||
x_id = small_value.options(is_direct_call=True).remote()
|
||||
if out_of_band:
|
||||
x_id = [x_id]
|
||||
if dest_actor:
|
||||
b = Actor.options(is_direct_call=True).remote()
|
||||
x = ray.get(b.echo.remote(x_id))
|
||||
else:
|
||||
x = ray.get(echo.options(is_direct_call=True).remote(x_id))
|
||||
if is_large:
|
||||
assert isinstance(x, np.ndarray)
|
||||
else:
|
||||
assert isinstance(x, int)
|
||||
|
||||
for is_large in [False, True]:
|
||||
for source_actor in [False, True]:
|
||||
for dest_actor in [False, True]:
|
||||
for out_of_band in [False, True]:
|
||||
check(source_actor, dest_actor, is_large, out_of_band)
|
||||
|
||||
|
||||
def test_direct_call_chain(ray_start_regular):
|
||||
@ray.remote
|
||||
def g(x):
|
||||
@@ -1265,26 +1330,6 @@ def test_direct_actor_large_objects(ray_start_regular):
|
||||
assert isinstance(ray.get(obj_id), np.ndarray)
|
||||
|
||||
|
||||
def test_direct_actor_errors(ray_start_regular):
|
||||
@ray.remote
|
||||
class Actor(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def f(self, x):
|
||||
return x * 2
|
||||
|
||||
@ray.remote
|
||||
def f(x):
|
||||
return 1
|
||||
|
||||
a = Actor._remote(is_direct_call=True)
|
||||
|
||||
# cannot pass returns to other methods even in a list
|
||||
with pytest.raises(Exception):
|
||||
ray.get(f.remote([a.f.remote(2)]))
|
||||
|
||||
|
||||
def test_direct_actor_pass_by_ref(ray_start_regular):
|
||||
@ray.remote
|
||||
class Actor(object):
|
||||
|
||||
Reference in New Issue
Block a user