Collect contained ObjectIDs during deserialization (#7029)

This commit is contained in:
Edward Oakes
2020-02-03 22:49:14 -08:00
committed by GitHub
parent 5e8ded344a
commit 844f607c93
5 changed files with 44 additions and 2 deletions
+9
View File
@@ -943,6 +943,15 @@ cdef class CoreWorker:
c_owner_id,
c_owner_address)
def add_contained_object_ids(
        self, ObjectID object_id, contained_object_ids):
    """Pass the ObjectIDs contained in ``object_id`` to the core worker.

    Both arguments are converted to their native representations and
    forwarded to the core worker's ``AddContainedObjectIDs``.

    Args:
        object_id: The ObjectID whose contained IDs are being recorded.
        contained_object_ids: Python-level ObjectIDs, converted with
            ``ObjectIDsToVector`` (presumably a list of ObjectID --
            confirm against callers).
    """
    # Conversion order is kept the same as before: outer ID first, then
    # the contained IDs.
    cdef CObjectID native_object_id = object_id.native()
    cdef c_vector[CObjectID] native_contained_ids = ObjectIDsToVector(
        contained_object_ids)

    self.core_worker.get().AddContainedObjectIDs(
        native_object_id, native_contained_ids)
# TODO: handle noreturn better
cdef store_task_outputs(
self, worker, outputs, const c_vector[CObjectID] return_ids,
+3
View File
@@ -128,6 +128,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
void RegisterOwnershipInfoAndResolveFuture(
const CObjectID &object_id, const CTaskID &owner_id, const
CAddress &owner_address)
# Declaration of the C++ AddContainedObjectIDs method from
# ray/core_worker/core_worker.h: takes an object ID together with the
# vector of object IDs reported as contained in it.
void AddContainedObjectIDs(
const CObjectID &object_id,
const c_vector[CObjectID] &contained_object_ids)
CRayStatus SetClientOptions(c_string client_name, int64_t limit)
CRayStatus Put(const CRayObject &object,
+19 -2
View File
@@ -173,6 +173,11 @@ class SerializationContext:
# that the ref count for the ObjectID is greater than 0 by the
# time the core worker resolves the value of the object.
deserialized_object_id = id_deserializer(obj_id)
# TODO(edoakes): we should be able to just capture a reference
# to 'self' here instead, but this function is itself pickled
# somewhere, which causes an error.
context = ray.worker.global_worker.get_serialization_context()
context.add_contained_object_id(deserialized_object_id)
if owner_id:
worker = ray.worker.get_global_worker()
worker.check_connected()
@@ -225,12 +230,24 @@ class SerializationContext:
try:
in_band, buffers = unpack_pickle5_buffers(data)
if len(buffers) > 0:
return pickle.loads(in_band, buffers=buffers)
obj = pickle.loads(in_band, buffers=buffers)
else:
return pickle.loads(in_band)
obj = pickle.loads(in_band)
# cloudpickle does not provide error types
except pickle.pickle.PicklingError:
raise DeserializationError()
# Check that there are no ObjectIDs serialized in arguments
# that are inlined.
if object_id.is_nil():
assert len(self.get_and_clear_contained_object_ids()) == 0
else:
worker = ray.worker.global_worker
worker.core_worker.add_contained_object_ids(
object_id,
self.get_and_clear_contained_object_ids(),
)
return obj
# Check if the object should be returned as raw bytes.
if metadata == ray_constants.RAW_BUFFER_METADATA:
if data is None: