Handle exchange of direct call objects between tasks and actors (#6147)

This commit is contained in:
Eric Liang
2019-11-14 17:32:04 -08:00
committed by GitHub
parent 385783fcec
commit 8ff393a7bd
23 changed files with 426 additions and 202 deletions
+6
View File
@@ -1015,6 +1015,12 @@ cdef class CoreWorker:
# Note: faster to not release GIL for short-running op.
self.core_worker.get().RemoveObjectIDReference(c_object_id)
def promote_object_to_plasma(self, ObjectID object_id):
    """Promote ``object_id`` to plasma through the core worker.

    Passes the native form of ``object_id`` to the core worker's
    ``PromoteObjectToPlasma`` and then returns whatever
    ``object_id.with_plasma_transport_type()`` produces.

    Args:
        object_id: The ObjectID to promote.

    Returns:
        The value of ``object_id.with_plasma_transport_type()``.
    """
    cdef CObjectID native_id = object_id.native()
    self.core_worker.get().PromoteObjectToPlasma(native_id)
    plasma_id = object_id.with_plasma_transport_type()
    return plasma_id
# TODO: handle noreturn better
cdef store_task_outputs(
self, worker, outputs, const c_vector[CObjectID] return_ids,
+1
View File
@@ -104,6 +104,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
*bytes)
void AddObjectIDReference(const CObjectID &object_id)
void RemoveObjectIDReference(const CObjectID &object_id)
void PromoteObjectToPlasma(const CObjectID &object_id)
CRayStatus SetClientOptions(c_string client_name, int64_t limit)
CRayStatus Put(const CRayObject &object, CObjectID *object_id)
+3 -1
View File
@@ -95,8 +95,10 @@ cdef class TaskSpec:
:self.task_spec.get().ArgMetadataSize(i)]
if metadata == RAW_BUFFER_METADATA:
obj = data
else:
elif metadata == PICKLE_BUFFER_METADATA:
obj = pickle.loads(data)
else:
obj = data
arg_list.append(obj)
elif lang == <int32_t>LANGUAGE_JAVA:
arg_list = num_args * ["<java-argument>"]
+2
View File
@@ -150,6 +150,8 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
c_bool IsDirectCallType()
CObjectID WithPlasmaTransportType()
int64_t ObjectIndex() const
CTaskID TaskId() const
+3
View File
@@ -179,6 +179,9 @@ cdef class ObjectID(BaseID):
def is_direct_call_type(self):
    """Return ``IsDirectCallType()`` of the wrapped native object ID."""
    direct_call = self.data.IsDirectCallType()
    return direct_call
def with_plasma_transport_type(self):
    """Return a new ObjectID for this ID's plasma-transport variant.

    The new ObjectID is constructed from the binary form of the native
    ID returned by ``WithPlasmaTransportType()``.
    """
    plasma_binary = self.data.WithPlasmaTransportType().Binary()
    return ObjectID(plasma_binary)
def is_nil(self):
    """Return ``IsNil()`` of the wrapped native object ID."""
    nil = self.data.IsNil()
    return nil
+2 -6
View File
@@ -158,9 +158,7 @@ class SerializationContext(object):
def id_serializer(obj):
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
raise NotImplementedError(
"Objects produced by direct actor calls cannot be "
"passed to other tasks as arguments.")
obj = self.worker.core_worker.promote_object_to_plasma(obj)
return pickle.dumps(obj)
def id_deserializer(serialized_obj):
@@ -191,9 +189,7 @@ class SerializationContext(object):
def id_serializer(obj):
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
raise NotImplementedError(
"Objects produced by direct actor calls cannot be "
"passed to other tasks as arguments.")
obj = self.worker.core_worker.promote_object_to_plasma(obj)
return obj.__reduce__()
def id_deserializer(serialized_obj):
+65 -20
View File
@@ -1218,6 +1218,71 @@ def test_direct_call_simple(ray_start_regular):
range(1, 101))
def test_direct_call_matrix(shutdown_only):
    """Pass direct-call results between every producer/consumer pairing.

    Each case picks a producer (actor method or task), a result size
    (small int or large numpy array), a consumer (actor method or task)
    and whether the object ID is handed over directly or wrapped in a
    list. The consumer echoes the value back and its type is checked.
    """
    ray.init(object_store_memory=1000 * 1024 * 1024)

    @ray.remote
    class Actor(object):
        def small_value(self):
            return 0

        def large_value(self):
            return np.zeros(10 * 1024 * 1024)

        def echo(self, x):
            # A list argument holds the object ID; fetch its value.
            return ray.get(x[0]) if isinstance(x, list) else x

    @ray.remote
    def small_value():
        return 0

    @ray.remote
    def large_value():
        return np.zeros(10 * 1024 * 1024)

    @ray.remote
    def echo(x):
        # A list argument holds the object ID; fetch its value.
        return ray.get(x[0]) if isinstance(x, list) else x

    def check(source_actor, dest_actor, is_large, out_of_band):
        source_label = "actor" if source_actor else "task"
        dest_label = "actor" if dest_actor else "task"
        size_label = "large_object" if is_large else "small_object"
        band_label = "out_of_band" if out_of_band else "in_band"
        print("CHECKING", source_label, "to", dest_label, size_label,
              band_label)

        # Produce the object with a direct call.
        if source_actor:
            producer = Actor.options(is_direct_call=True).remote()
            method = producer.large_value if is_large else producer.small_value
            x_id = method.remote()
        else:
            task = large_value if is_large else small_value
            x_id = task.options(is_direct_call=True).remote()

        # Out-of-band means the ID travels inside a list argument.
        payload = [x_id] if out_of_band else x_id

        # Consume the object with a direct call and fetch the echo.
        if dest_actor:
            consumer = Actor.options(is_direct_call=True).remote()
            x = ray.get(consumer.echo.remote(payload))
        else:
            x = ray.get(echo.options(is_direct_call=True).remote(payload))

        expected_type = np.ndarray if is_large else int
        assert isinstance(x, expected_type)

    flags = [False, True]
    cases = [(from_actor, to_actor, large, wrapped)
             for large in flags
             for from_actor in flags
             for to_actor in flags
             for wrapped in flags]
    for from_actor, to_actor, large, wrapped in cases:
        check(from_actor, to_actor, large, wrapped)
def test_direct_call_chain(ray_start_regular):
@ray.remote
def g(x):
@@ -1265,26 +1330,6 @@ def test_direct_actor_large_objects(ray_start_regular):
assert isinstance(ray.get(obj_id), np.ndarray)
def test_direct_actor_errors(ray_start_regular):
    """Expect an exception when a direct actor call result is passed to a
    task inside a list."""
    @ray.remote
    class Actor(object):
        def __init__(self):
            pass

        def f(self, x):
            return x * 2

    @ray.remote
    def f(x):
        return 1

    direct_actor = Actor._remote(is_direct_call=True)

    # cannot pass returns to other methods even in a list
    with pytest.raises(Exception):
        wrapped_ids = [direct_actor.f.remote(2)]
        ray.get(f.remote(wrapped_ids))
def test_direct_actor_pass_by_ref(ray_start_regular):
@ray.remote
class Actor(object):