mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 09:29:05 +08:00
Collect object IDs during serialization (#6946)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
|
||||
import pyarrow.plasma as plasma
|
||||
|
||||
@@ -34,8 +35,9 @@ class DeserializationError(Exception):
|
||||
|
||||
|
||||
class SerializedObject:
|
||||
def __init__(self, metadata):
|
||||
def __init__(self, metadata, contained_object_ids=None):
|
||||
self._metadata = metadata
|
||||
self._contained_object_ids = contained_object_ids or []
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
@@ -45,11 +47,15 @@ class SerializedObject:
|
||||
def metadata(self):
|
||||
return self._metadata
|
||||
|
||||
@property
def contained_object_ids(self):
    """Return the object IDs recorded for this serialized object.

    This is the value stored on ``self._contained_object_ids``; it is
    returned as-is, not copied.
    """
    return self._contained_object_ids
|
||||
|
||||
|
||||
class Pickle5SerializedObject(SerializedObject):
|
||||
def __init__(self, inband, writer):
|
||||
super(Pickle5SerializedObject,
|
||||
self).__init__(ray_constants.PICKLE5_BUFFER_METADATA)
|
||||
def __init__(self, inband, writer, contained_object_ids):
|
||||
super(Pickle5SerializedObject, self).__init__(
|
||||
ray_constants.PICKLE5_BUFFER_METADATA, contained_object_ids)
|
||||
self.inband = inband
|
||||
self.writer = writer
|
||||
# cached total bytes
|
||||
@@ -126,6 +132,7 @@ class SerializationContext:
|
||||
self.worker = worker
|
||||
assert worker.use_pickle
|
||||
self.use_pickle = worker.use_pickle
|
||||
self._thread_local = threading.local()
|
||||
|
||||
def actor_handle_serializer(obj):
|
||||
return obj._serialization_helper(True)
|
||||
@@ -147,6 +154,7 @@ class SerializationContext:
|
||||
return serialized_obj[0](*serialized_obj[1])
|
||||
|
||||
def object_id_serializer(obj):
|
||||
self.add_contained_object_id(obj)
|
||||
owner_id = ""
|
||||
owner_address = ""
|
||||
if obj.is_direct_call_type():
|
||||
@@ -192,6 +200,21 @@ class SerializationContext:
|
||||
# construct a reducer
|
||||
pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
|
||||
|
||||
def get_and_clear_contained_object_ids(self):
    """Hand back the object IDs gathered so far and start a fresh set.

    The IDs live on ``self._thread_local.object_ids``. If that attribute
    has not been created yet, an empty set is returned. In every case
    the attribute is left pointing at a new, empty set.

    Returns:
        The set of object IDs that had been collected (possibly empty).
    """
    local_state = self._thread_local
    try:
        collected = local_state.object_ids
    except AttributeError:
        # Nothing has been recorded yet: report an empty result.
        collected = set()

    # Reset the collection so the next caller starts from scratch.
    local_state.object_ids = set()
    return collected
|
||||
|
||||
def add_contained_object_id(self, object_id):
    """Record ``object_id`` in the set kept on ``self._thread_local``.

    The set (``self._thread_local.object_ids``) is created on first use.

    Args:
        object_id: The ID to remember; it is added to a ``set``, so it
            must be hashable and duplicates collapse.
    """
    local_state = self._thread_local
    try:
        bucket = local_state.object_ids
    except AttributeError:
        # First ID recorded through this thread-local: create the set.
        bucket = local_state.object_ids = set()
    bucket.add(object_id)
|
||||
|
||||
def _deserialize_object(self, data, metadata, object_id):
|
||||
if metadata:
|
||||
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
|
||||
@@ -291,7 +314,8 @@ class SerializationContext:
|
||||
writer = Pickle5Writer()
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
return Pickle5SerializedObject(inband, writer)
|
||||
return Pickle5SerializedObject(
|
||||
inband, writer, self.get_and_clear_contained_object_ids())
|
||||
|
||||
def register_custom_serializer(self,
|
||||
cls,
|
||||
|
||||
Reference in New Issue
Block a user