Collect object IDs during serialization (#6946)

This commit is contained in:
Edward Oakes
2020-02-03 18:38:11 -08:00
committed by GitHub
parent ca5a9c6739
commit 984490d2be
10 changed files with 134 additions and 74 deletions
+29 -5
View File
@@ -1,6 +1,7 @@
import hashlib
import logging
import time
import threading
import pyarrow.plasma as plasma
@@ -34,8 +35,9 @@ class DeserializationError(Exception):
class SerializedObject:
def __init__(self, metadata):
def __init__(self, metadata, contained_object_ids=None):
self._metadata = metadata
self._contained_object_ids = contained_object_ids or []
@property
def total_bytes(self):
@@ -45,11 +47,15 @@ class SerializedObject:
def metadata(self):
return self._metadata
@property
def contained_object_ids(self):
return self._contained_object_ids
class Pickle5SerializedObject(SerializedObject):
def __init__(self, inband, writer):
super(Pickle5SerializedObject,
self).__init__(ray_constants.PICKLE5_BUFFER_METADATA)
def __init__(self, inband, writer, contained_object_ids):
super(Pickle5SerializedObject, self).__init__(
ray_constants.PICKLE5_BUFFER_METADATA, contained_object_ids)
self.inband = inband
self.writer = writer
# cached total bytes
@@ -126,6 +132,7 @@ class SerializationContext:
self.worker = worker
assert worker.use_pickle
self.use_pickle = worker.use_pickle
self._thread_local = threading.local()
def actor_handle_serializer(obj):
return obj._serialization_helper(True)
@@ -147,6 +154,7 @@ class SerializationContext:
return serialized_obj[0](*serialized_obj[1])
def object_id_serializer(obj):
self.add_contained_object_id(obj)
owner_id = ""
owner_address = ""
if obj.is_direct_call_type():
@@ -192,6 +200,21 @@ class SerializationContext:
# construct a reducer
pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
def get_and_clear_contained_object_ids(self):
if not hasattr(self._thread_local, "object_ids"):
self._thread_local.object_ids = set()
return set()
object_ids = self._thread_local.object_ids
self._thread_local.object_ids = set()
return object_ids
def add_contained_object_id(self, object_id):
if not hasattr(self._thread_local, "object_ids"):
self._thread_local.object_ids = set()
self._thread_local.object_ids.add(object_id)
def _deserialize_object(self, data, metadata, object_id):
if metadata:
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
@@ -291,7 +314,8 @@ class SerializationContext:
writer = Pickle5Writer()
inband = pickle.dumps(
value, protocol=5, buffer_callback=writer.buffer_callback)
return Pickle5SerializedObject(inband, writer)
return Pickle5SerializedObject(
inband, writer, self.get_and_clear_contained_object_ids())
def register_custom_serializer(self,
cls,