diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 4c4fe7519..b95f993ac 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -943,6 +943,15 @@ cdef class CoreWorker: c_owner_id, c_owner_address) + def add_contained_object_ids( + self, ObjectID object_id, contained_object_ids): + cdef: + CObjectID c_object_id = object_id.native() + c_vector[CObjectID] c_contained_ids + c_contained_ids = ObjectIDsToVector(contained_object_ids) + self.core_worker.get().AddContainedObjectIDs( + c_object_id, c_contained_ids) + # TODO: handle noreturn better cdef store_task_outputs( self, worker, outputs, const c_vector[CObjectID] return_ids, diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index e1df8e4c5..360552fb0 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -128,6 +128,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: void RegisterOwnershipInfoAndResolveFuture( const CObjectID &object_id, const CTaskID &owner_id, const CAddress &owner_address) + void AddContainedObjectIDs( + const CObjectID &object_id, + const c_vector[CObjectID] &contained_object_ids) CRayStatus SetClientOptions(c_string client_name, int64_t limit) CRayStatus Put(const CRayObject &object, diff --git a/python/ray/serialization.py b/python/ray/serialization.py index 5a9e056d8..d13e16fd7 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -173,6 +173,11 @@ class SerializationContext: # that the ref count for the ObjectID is greater than 0 by the # time the core worker resolves the value of the object. deserialized_object_id = id_deserializer(obj_id) + # TODO(edoakes): we should be able to just capture a reference + # to 'self' here instead, but this function is itself pickled + # somewhere, which causes an error. 
+ context = ray.worker.global_worker.get_serialization_context() + context.add_contained_object_id(deserialized_object_id) if owner_id: worker = ray.worker.get_global_worker() worker.check_connected() @@ -225,12 +230,24 @@ class SerializationContext: try: in_band, buffers = unpack_pickle5_buffers(data) if len(buffers) > 0: - return pickle.loads(in_band, buffers=buffers) + obj = pickle.loads(in_band, buffers=buffers) else: - return pickle.loads(in_band) + obj = pickle.loads(in_band) # cloudpickle does not provide error types except pickle.pickle.PicklingError: raise DeserializationError() + + # Check that there are no ObjectIDs serialized in arguments + # that are inlined. + if object_id.is_nil(): + assert len(self.get_and_clear_contained_object_ids()) == 0 + else: + worker = ray.worker.global_worker + worker.core_worker.add_contained_object_ids( + object_id, + self.get_and_clear_contained_object_ids(), + ) + return obj # Check if the object should be returned as raw bytes. if metadata == ray_constants.RAW_BUFFER_METADATA: if data is None: diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 71b23358d..9278e358a 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -361,6 +361,11 @@ void CoreWorker::RegisterOwnershipInfoAndResolveFuture( future_resolver_->ResolveFutureAsync(object_id, owner_id, owner_address); } +void CoreWorker::AddContainedObjectIDs( + const ObjectID &object_id, const std::vector<ObjectID> &contained_object_ids) { + // TODO(edoakes,swang): integrate with the reference counting logic. +} + Status CoreWorker::SetClientOptions(std::string name, int64_t limit_bytes) { // Currently only the Plasma store supports client options.
return plasma_store_provider_->SetClientOptions(name, limit_bytes); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 5428bf08c..8142cb441 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -163,6 +163,14 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { const TaskID &owner_id, const rpc::Address &owner_address); + /// Add metadata about the object IDs contained within another object ID. + /// This should be called during deserialization of the outer object ID. + /// + /// \param[in] object_id The object containing IDs. + /// \param[in] contained_object_ids The IDs contained in the object. + void AddContainedObjectIDs(const ObjectID &object_id, + const std::vector<ObjectID> &contained_object_ids); + /// /// Public methods related to storing and retrieving objects. ///