[Serialization] New custom serialization API (#13291)

* new serialization API with doc & test

* add more notes

* refine notes

* doc
This commit is contained in:
Siyuan (Ryans) Zhuang
2021-01-14 13:15:31 -08:00
committed by GitHub
parent 07e97fe4c2
commit d1e9887be2
5 changed files with 148 additions and 189 deletions
+8 -170
View File
@@ -1,12 +1,9 @@
import hashlib
import logging
import time
import threading
import ray.cloudpickle as pickle
from ray import ray_constants, JobID
from ray import ray_constants
import ray.utils
from ray.utils import _random_string
from ray.gcs_utils import ErrorType
from ray.exceptions import (
RayError,
@@ -30,62 +27,10 @@ from ray._raylet import (
logger = logging.getLogger(__name__)
class RayNotDictionarySerializable(Exception):
pass
# This exception is used to represent situations where cloudpickle fails to
# pickle an object (cloudpickle can fail in many different ways).
class CloudPickleError(Exception):
pass
class DeserializationError(Exception):
pass
def _try_to_compute_deterministic_class_id(cls, depth=5):
"""Attempt to produce a deterministic class ID for a given class.
The goal here is for the class ID to be the same when this is run on
different worker processes. Pickling, loading, and pickling again seems to
produce more consistent results than simply pickling. This is a bit crazy
and could cause problems, in which case we should revert it and figure out
something better.
Args:
cls: The class to produce an ID for.
depth: The number of times to repeatedly try to load and dump the
string while trying to reach a fixed point.
Returns:
A class ID for this class. We attempt to make the class ID the same
when this function is run on different workers, but that is not
guaranteed.
Raises:
Exception: This could raise an exception if cloudpickle raises an
exception.
"""
# Pickling, loading, and pickling again seems to produce more consistent
# results than simply pickling. This is a bit
class_id = pickle.dumps(cls)
for _ in range(depth):
new_class_id = pickle.dumps(pickle.loads(class_id))
if new_class_id == class_id:
# We appear to have reached a fix point, so use this as the ID.
return hashlib.shake_128(new_class_id).digest(
ray_constants.ID_SIZE)
class_id = new_class_id
# We have not reached a fixed point, so we may end up with a different
# class ID for this custom class on each worker, which could lead to the
# same class definition being exported many many times.
logger.warning(
f"WARNING: Could not produce a deterministic class ID for class {cls}")
return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE)
def object_ref_deserializer(reduced_obj_ref, owner_address):
# NOTE(suquark): This function should be a global function so
# cloudpickle can access it directly. Otherwise couldpickle
@@ -153,8 +98,6 @@ class SerializationContext:
worker.core_worker.serialize_and_promote_object_ref(obj))
return object_ref_deserializer, (obj.__reduce__(), owner_address)
# Because objects have default __reduce__ method, we only need to
# treat ObjectRef specifically.
self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
def _register_cloudpickle_reducer(self, cls, reducer):
@@ -291,48 +234,17 @@ class SerializationContext:
# throws an exception.
return PlasmaObjectNotAvailable
def deserialize_objects(self,
data_metadata_pairs,
object_refs,
error_timeout=10):
def deserialize_objects(self, data_metadata_pairs, object_refs):
assert len(data_metadata_pairs) == len(object_refs)
start_time = time.time()
results = []
warning_sent = False
i = 0
while i < len(object_refs):
object_ref = object_refs[i]
data, metadata = data_metadata_pairs[i]
for object_ref, (data, metadata) in zip(object_refs,
data_metadata_pairs):
assert self.get_outer_object_ref() is None
self.set_outer_object_ref(object_ref)
try:
results.append(
self._deserialize_object(data, metadata, object_ref))
i += 1
except DeserializationError:
# Wait a little bit for the import thread to import the class.
# If we currently have the worker lock, we need to release it
# so that the import thread can acquire it.
time.sleep(0.01)
if time.time() - start_time > error_timeout:
warning_message = ("This worker or driver is waiting to "
"receive a class definition so that it "
"can deserialize an object from the "
"object store. This may be fine, or it "
"may be a bug.")
if not warning_sent:
ray.utils.push_error_to_driver(
self,
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
warning_message,
job_id=self.worker.current_job_id)
warning_sent = True
finally:
# Must clear ObjectRef to not hold a reference.
self.set_outer_object_ref(None)
results.append(
self._deserialize_object(data, metadata, object_ref))
# Must clear ObjectRef to not hold a reference.
self.set_outer_object_ref(None)
return results
def _serialize_to_pickle5(self, metadata, value):
@@ -405,77 +317,3 @@ class SerializationContext:
return RawSerializedObject(value)
else:
return self._serialize_to_msgpack(value)
def register_custom_serializer(self,
cls,
serializer,
deserializer,
local=False,
job_id=None,
class_id=None):
"""Enable serialization and deserialization for a particular class.
This method runs the register_class function defined below on
every worker, which will enable ray to properly serialize and
deserialize objects of this class.
Args:
cls (type): The class that ray should use this custom serializer
for.
serializer: The custom serializer to use.
deserializer: The custom deserializer to use.
local: True if the serializers should only be registered on the
current worker. This should usually be False.
job_id: ID of the job that we want to register the class for.
class_id (str): Unique ID of the class. Autogenerated if None.
Raises:
RayNotDictionarySerializable: Raised if use_dict is true and cls
cannot be efficiently serialized by Ray.
ValueError: Raised if ray could not autogenerate a class_id.
"""
assert serializer is not None and deserializer is not None, (
"Must provide serializer and deserializer.")
if class_id is None:
if not local:
# In this case, the class ID will be used to deduplicate the
# class across workers. Note that cloudpickle unfortunately
# does not produce deterministic strings, so these IDs could
# be different on different workers. We could use something
# weaker like cls.__name__, however that would run the risk
# of having collisions.
# TODO(rkn): We should improve this.
try:
# Attempt to produce a class ID that will be the same on
# each worker. However, determinism is not guaranteed,
# and the result may be different on different workers.
class_id = _try_to_compute_deterministic_class_id(cls)
except Exception:
raise ValueError(
"Failed to use pickle in generating a unique id"
f"for '{cls}'. Provide a unique class_id.")
else:
# In this case, the class ID only needs to be meaningful on
# this worker and not across workers.
class_id = _random_string()
# Make sure class_id is a string.
class_id = ray.utils.binary_to_hex(class_id)
if job_id is None:
job_id = self.worker.current_job_id
assert isinstance(job_id, JobID)
def register_class_for_serialization(worker_info):
context = worker_info["worker"].get_serialization_context(job_id)
context._register_cloudpickle_serializer(cls, serializer,
deserializer)
if not local:
self.worker.run_function_on_all_workers(
register_class_for_serialization)
else:
# Since we are pickling objects of this class, we don't actually
# need to ship the class definition.
register_class_for_serialization({"worker": self.worker})
+19
View File
@@ -598,6 +598,25 @@ def test_buffer_alignment(ray_start_shared_local_modes):
assert y.ctypes.data % 8 == 0
def test_custom_serializer(ray_start_shared_local_modes):
import threading
class A:
def __init__(self, x):
self.x = x
self.lock = threading.Lock()
def custom_serializer(a):
return a.x
def custom_deserializer(x):
return A(x)
ray.util.register_serializer(
A, serializer=custom_serializer, deserializer=custom_deserializer)
ray.get(ray.put(A(1)))
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))
+2
View File
@@ -6,6 +6,7 @@ from ray.util.debug import log_once, disable_log_once_globally, \
from ray.util.placement_group import (placement_group, placement_group_table,
remove_placement_group)
from ray.util import rpdb as pdb
from ray.util.serialization import register_serializer
from ray.util.client_connect import connect, disconnect
@@ -23,4 +24,5 @@ __all__ = [
"collective",
"connect",
"disconnect",
"register_serializer",
]
+18
View File
@@ -0,0 +1,18 @@
import ray
def register_serializer(cls, *, serializer, deserializer):
"""Use the given serializer to serialize instances of type ``cls``,
and use the deserializer to deserialize the serialized object.
Args:
cls: A Python class/type.
serializer (callable): A function that converts an instances of
type ``cls`` into a serializable object (e.g. python dict
of basic objects).
deserializer (callable): A function that constructs the
instance of type ``cls`` from the serialized object.
This function itself must be serializable.
"""
context = ray.worker.global_worker.get_serialization_context()
context._register_cloudpickle_serializer(cls, serializer, deserializer)