[Serialization] New custom serialization API (#13291)

* new serialization API with doc & test

* add more notes

* refine notes

* doc
This commit is contained in:
Siyuan (Ryans) Zhuang
2021-01-14 13:15:31 -08:00
committed by GitHub
parent 07e97fe4c2
commit d1e9887be2
5 changed files with 148 additions and 189 deletions
+101 -19
View File
@@ -17,9 +17,9 @@ Each node has its own object store. When data is put into the object store, it d
Overview
--------
Ray has decided to use a customed `Pickle protocol version 5 <https://www.python.org/dev/peps/pep-0574/>`_ backport to replace the original PyArrow serializer. This gets rid of several previous limitations (e.g. cannot serialize recursive objects).
Ray has decided to use a customized `Pickle protocol version 5 <https://www.python.org/dev/peps/pep-0574/>`_ backport to replace the original PyArrow serializer. This gets rid of several previous limitations (e.g. cannot serialize recursive objects).
Ray is currently compatible with Pickle protocol version 5, while Ray supports serialization of a wilder range of objects (e.g. lambda & nested functions, dynamic classes) with the support of cloudpickle.
Ray is currently compatible with Pickle protocol version 5, while Ray supports serialization of a wider range of objects (e.g. lambda & nested functions, dynamic classes) with the help of cloudpickle.
Numpy Arrays
------------
@@ -34,23 +34,6 @@ Serialization notes
- Ray is currently using Pickle protocol version 5. The default pickle protocol used by most python distributions is protocol 3. Protocol 4 & 5 are more efficient than protocol 3 for larger objects.
- Ray may create extra copies of simple native objects (e.g. list, and this is also the default behavior of Pickle Protocol 4 & 5), but recursive objects are treated carefully without any issues:
.. code-block:: python
l1 = [0]
l2 = [l1, l1]
l3 = ray.get(ray.put(l2))
assert l2[0] is l2[1]
assert l3[0] is l3[1] # will raise AssertionError for protocol 4 & 5, but not protocol 3
l = []
l.append(l)
# Try to put this list that recursively contains itself in the object store.
ray.put(l) # ok
- For non-native objects, Ray will always keep a single copy even it is referred multiple times in an object:
.. code-block:: python
@@ -64,6 +47,105 @@ Serialization notes
- Lock objects are mostly unserializable, because copying a lock is meaningless and could cause serious concurrency problems. You may have to come up with a workaround if your object contains a lock.
Customized Serialization
________________________
Sometimes you may want to customize your serialization process because
the default serializer used by Ray (pickle5 + cloudpickle) does
not work for you (fail to serialize some objects, too slow for certain objects, etc.).
There are at least 3 ways to define your custom serialization process:
1. If you want to customize the serialization of a type of objects,
and you have access to the code, you can define ``__reduce__``
function inside the corresponding class. This is commonly done
by most Python libraries. Example code:
.. code-block:: python
import ray
import sqlite3
ray.init()
class DBConnection:
def __init__(self, path):
self.path = path
self.conn = sqlite3.connect(path)
# without '__reduce__', the instance is unserializable.
def __reduce__(self):
deserializer = DBConnection
serialized_data = (self.path,)
return deserializer, serialized_data
original = DBConnection("/tmp/db")
print(original.conn)
copied = ray.get(ray.put(original))
print(copied.conn)
2. If you want to customize the serialization of a type of objects,
but you cannot access or modify the corresponding class, you can
register the class with the serializer you use:
.. code-block:: python
import ray
import threading
class A:
def __init__(self, x):
self.x = x
self.lock = threading.Lock() # could not be serialized!
ray.get(ray.put(A(1))) # fail!
def custom_serializer(a):
return a.x
def custom_deserializer(b):
return A(b)
# Register serializer and deserializer for class A:
ray.util.register_serializer(
A, serializer=custom_serializer, deserializer=custom_deserializer)
ray.get(ray.put(A(1))) # success!
NOTE: Serializers are managed locally for each Ray worker. So for every Ray worker,
if you want to use the serializer, you need to register the serializer.
If you register a new serializer for a class, the new serializer would replace
the old serializer immediately in the worker. This API is also idempotent, there are
no side effects caused by re-registering the same serializer.
3. We also provide you an example, if you want to customize the serialization
of a specific object:
.. code-block:: python
import threading
class A:
def __init__(self, x):
self.x = x
self.lock = threading.Lock() # could not serialize!
ray.get(ray.put(A(1))) # fail!
class SerializationHelperForA:
"""A helper class for serialization."""
def __init__(self, a):
self.a = a
def __reduce__(self):
return A, (self.a.x,)
ray.get(ray.put(SerializationHelperForA(A(1)))) # success!
# the serializer only works for a specific object, not all A
# instances, so we still expect failure here.
ray.get(ray.put(A(1))) # still fail!
Troubleshooting
---------------
+8 -170
View File
@@ -1,12 +1,9 @@
import hashlib
import logging
import time
import threading
import ray.cloudpickle as pickle
from ray import ray_constants, JobID
from ray import ray_constants
import ray.utils
from ray.utils import _random_string
from ray.gcs_utils import ErrorType
from ray.exceptions import (
RayError,
@@ -30,62 +27,10 @@ from ray._raylet import (
logger = logging.getLogger(__name__)
class RayNotDictionarySerializable(Exception):
pass
# This exception is used to represent situations where cloudpickle fails to
# pickle an object (cloudpickle can fail in many different ways).
class CloudPickleError(Exception):
pass
class DeserializationError(Exception):
pass
def _try_to_compute_deterministic_class_id(cls, depth=5):
"""Attempt to produce a deterministic class ID for a given class.
The goal here is for the class ID to be the same when this is run on
different worker processes. Pickling, loading, and pickling again seems to
produce more consistent results than simply pickling. This is a bit crazy
and could cause problems, in which case we should revert it and figure out
something better.
Args:
cls: The class to produce an ID for.
depth: The number of times to repeatedly try to load and dump the
string while trying to reach a fixed point.
Returns:
A class ID for this class. We attempt to make the class ID the same
when this function is run on different workers, but that is not
guaranteed.
Raises:
Exception: This could raise an exception if cloudpickle raises an
exception.
"""
# Pickling, loading, and pickling again seems to produce more consistent
# results than simply pickling. This is a bit
class_id = pickle.dumps(cls)
for _ in range(depth):
new_class_id = pickle.dumps(pickle.loads(class_id))
if new_class_id == class_id:
# We appear to have reached a fix point, so use this as the ID.
return hashlib.shake_128(new_class_id).digest(
ray_constants.ID_SIZE)
class_id = new_class_id
# We have not reached a fixed point, so we may end up with a different
# class ID for this custom class on each worker, which could lead to the
# same class definition being exported many many times.
logger.warning(
f"WARNING: Could not produce a deterministic class ID for class {cls}")
return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE)
def object_ref_deserializer(reduced_obj_ref, owner_address):
# NOTE(suquark): This function should be a global function so
# cloudpickle can access it directly. Otherwise couldpickle
@@ -153,8 +98,6 @@ class SerializationContext:
worker.core_worker.serialize_and_promote_object_ref(obj))
return object_ref_deserializer, (obj.__reduce__(), owner_address)
# Because objects have default __reduce__ method, we only need to
# treat ObjectRef specifically.
self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
def _register_cloudpickle_reducer(self, cls, reducer):
@@ -291,48 +234,17 @@ class SerializationContext:
# throws an exception.
return PlasmaObjectNotAvailable
def deserialize_objects(self,
data_metadata_pairs,
object_refs,
error_timeout=10):
def deserialize_objects(self, data_metadata_pairs, object_refs):
assert len(data_metadata_pairs) == len(object_refs)
start_time = time.time()
results = []
warning_sent = False
i = 0
while i < len(object_refs):
object_ref = object_refs[i]
data, metadata = data_metadata_pairs[i]
for object_ref, (data, metadata) in zip(object_refs,
data_metadata_pairs):
assert self.get_outer_object_ref() is None
self.set_outer_object_ref(object_ref)
try:
results.append(
self._deserialize_object(data, metadata, object_ref))
i += 1
except DeserializationError:
# Wait a little bit for the import thread to import the class.
# If we currently have the worker lock, we need to release it
# so that the import thread can acquire it.
time.sleep(0.01)
if time.time() - start_time > error_timeout:
warning_message = ("This worker or driver is waiting to "
"receive a class definition so that it "
"can deserialize an object from the "
"object store. This may be fine, or it "
"may be a bug.")
if not warning_sent:
ray.utils.push_error_to_driver(
self,
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
warning_message,
job_id=self.worker.current_job_id)
warning_sent = True
finally:
# Must clear ObjectRef to not hold a reference.
self.set_outer_object_ref(None)
results.append(
self._deserialize_object(data, metadata, object_ref))
# Must clear ObjectRef to not hold a reference.
self.set_outer_object_ref(None)
return results
def _serialize_to_pickle5(self, metadata, value):
@@ -405,77 +317,3 @@ class SerializationContext:
return RawSerializedObject(value)
else:
return self._serialize_to_msgpack(value)
def register_custom_serializer(self,
cls,
serializer,
deserializer,
local=False,
job_id=None,
class_id=None):
"""Enable serialization and deserialization for a particular class.
This method runs the register_class function defined below on
every worker, which will enable ray to properly serialize and
deserialize objects of this class.
Args:
cls (type): The class that ray should use this custom serializer
for.
serializer: The custom serializer to use.
deserializer: The custom deserializer to use.
local: True if the serializers should only be registered on the
current worker. This should usually be False.
job_id: ID of the job that we want to register the class for.
class_id (str): Unique ID of the class. Autogenerated if None.
Raises:
RayNotDictionarySerializable: Raised if use_dict is true and cls
cannot be efficiently serialized by Ray.
ValueError: Raised if ray could not autogenerate a class_id.
"""
assert serializer is not None and deserializer is not None, (
"Must provide serializer and deserializer.")
if class_id is None:
if not local:
# In this case, the class ID will be used to deduplicate the
# class across workers. Note that cloudpickle unfortunately
# does not produce deterministic strings, so these IDs could
# be different on different workers. We could use something
# weaker like cls.__name__, however that would run the risk
# of having collisions.
# TODO(rkn): We should improve this.
try:
# Attempt to produce a class ID that will be the same on
# each worker. However, determinism is not guaranteed,
# and the result may be different on different workers.
class_id = _try_to_compute_deterministic_class_id(cls)
except Exception:
raise ValueError(
"Failed to use pickle in generating a unique id"
f"for '{cls}'. Provide a unique class_id.")
else:
# In this case, the class ID only needs to be meaningful on
# this worker and not across workers.
class_id = _random_string()
# Make sure class_id is a string.
class_id = ray.utils.binary_to_hex(class_id)
if job_id is None:
job_id = self.worker.current_job_id
assert isinstance(job_id, JobID)
def register_class_for_serialization(worker_info):
context = worker_info["worker"].get_serialization_context(job_id)
context._register_cloudpickle_serializer(cls, serializer,
deserializer)
if not local:
self.worker.run_function_on_all_workers(
register_class_for_serialization)
else:
# Since we are pickling objects of this class, we don't actually
# need to ship the class definition.
register_class_for_serialization({"worker": self.worker})
+19
View File
@@ -598,6 +598,25 @@ def test_buffer_alignment(ray_start_shared_local_modes):
assert y.ctypes.data % 8 == 0
def test_custom_serializer(ray_start_shared_local_modes):
import threading
class A:
def __init__(self, x):
self.x = x
self.lock = threading.Lock()
def custom_serializer(a):
return a.x
def custom_deserializer(x):
return A(x)
ray.util.register_serializer(
A, serializer=custom_serializer, deserializer=custom_deserializer)
ray.get(ray.put(A(1)))
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))
+2
View File
@@ -6,6 +6,7 @@ from ray.util.debug import log_once, disable_log_once_globally, \
from ray.util.placement_group import (placement_group, placement_group_table,
remove_placement_group)
from ray.util import rpdb as pdb
from ray.util.serialization import register_serializer
from ray.util.client_connect import connect, disconnect
@@ -23,4 +24,5 @@ __all__ = [
"collective",
"connect",
"disconnect",
"register_serializer",
]
+18
View File
@@ -0,0 +1,18 @@
import ray
def register_serializer(cls, *, serializer, deserializer):
"""Use the given serializer to serialize instances of type ``cls``,
and use the deserializer to deserialize the serialized object.
Args:
cls: A Python class/type.
serializer (callable): A function that converts an instances of
type ``cls`` into a serializable object (e.g. python dict
of basic objects).
deserializer (callable): A function that constructs the
instance of type ``cls`` from the serialized object.
This function itself must be serializable.
"""
context = ray.worker.global_worker.get_serialization_context()
context._register_cloudpickle_serializer(cls, serializer, deserializer)