mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 09:42:22 +08:00
Collect contained ObjectIDs during deserialization (#7029)
This commit is contained in:
@@ -943,6 +943,15 @@ cdef class CoreWorker:
|
||||
c_owner_id,
|
||||
c_owner_address)
|
||||
|
||||
def add_contained_object_ids(
|
||||
self, ObjectID object_id, contained_object_ids):
|
||||
cdef:
|
||||
CObjectID c_object_id = object_id.native()
|
||||
c_vector[CObjectID] c_contained_ids
|
||||
c_contained_ids = ObjectIDsToVector(contained_object_ids)
|
||||
self.core_worker.get().AddContainedObjectIDs(
|
||||
c_object_id, c_contained_ids)
|
||||
|
||||
# TODO: handle noreturn better
|
||||
cdef store_task_outputs(
|
||||
self, worker, outputs, const c_vector[CObjectID] return_ids,
|
||||
|
||||
@@ -128,6 +128,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
void RegisterOwnershipInfoAndResolveFuture(
|
||||
const CObjectID &object_id, const CTaskID &owner_id, const
|
||||
CAddress &owner_address)
|
||||
void AddContainedObjectIDs(
|
||||
const CObjectID &object_id,
|
||||
const c_vector[CObjectID] &contained_object_ids)
|
||||
|
||||
CRayStatus SetClientOptions(c_string client_name, int64_t limit)
|
||||
CRayStatus Put(const CRayObject &object,
|
||||
|
||||
@@ -173,6 +173,11 @@ class SerializationContext:
|
||||
# that the ref count for the ObjectID is greater than 0 by the
|
||||
# time the core worker resolves the value of the object.
|
||||
deserialized_object_id = id_deserializer(obj_id)
|
||||
# TODO(edoakes): we should be able to just capture a reference
|
||||
# to 'self' here instead, but this function is itself pickled
|
||||
# somewhere, which causes an error.
|
||||
context = ray.worker.global_worker.get_serialization_context()
|
||||
context.add_contained_object_id(deserialized_object_id)
|
||||
if owner_id:
|
||||
worker = ray.worker.get_global_worker()
|
||||
worker.check_connected()
|
||||
@@ -225,12 +230,24 @@ class SerializationContext:
|
||||
try:
|
||||
in_band, buffers = unpack_pickle5_buffers(data)
|
||||
if len(buffers) > 0:
|
||||
return pickle.loads(in_band, buffers=buffers)
|
||||
obj = pickle.loads(in_band, buffers=buffers)
|
||||
else:
|
||||
return pickle.loads(in_band)
|
||||
obj = pickle.loads(in_band)
|
||||
# cloudpickle does not provide error types
|
||||
except pickle.pickle.PicklingError:
|
||||
raise DeserializationError()
|
||||
|
||||
# Check that there are no ObjectIDs serialized in arguments
|
||||
# that are inlined.
|
||||
if object_id.is_nil():
|
||||
assert len(self.get_and_clear_contained_object_ids()) == 0
|
||||
else:
|
||||
worker = ray.worker.global_worker
|
||||
worker.core_worker.add_contained_object_ids(
|
||||
object_id,
|
||||
self.get_and_clear_contained_object_ids(),
|
||||
)
|
||||
return obj
|
||||
# Check if the object should be returned as raw bytes.
|
||||
if metadata == ray_constants.RAW_BUFFER_METADATA:
|
||||
if data is None:
|
||||
|
||||
Reference in New Issue
Block a user