diff --git a/doc/source/serialization.rst b/doc/source/serialization.rst
index 6bdada7d0..a5e58a339 100644
--- a/doc/source/serialization.rst
+++ b/doc/source/serialization.rst
@@ -17,9 +17,9 @@ Each node has its own object store. When data is put into the object store, it d
Overview
--------
-Ray has decided to use a customed `Pickle protocol version 5 `_ backport to replace the original PyArrow serializer. This gets rid of several previous limitations (e.g. cannot serialize recursive objects).
+Ray has decided to use a customized `Pickle protocol version 5 `_ backport to replace the original PyArrow serializer. This gets rid of several previous limitations (e.g. cannot serialize recursive objects).
-Ray is currently compatible with Pickle protocol version 5, while Ray supports serialization of a wilder range of objects (e.g. lambda & nested functions, dynamic classes) with the support of cloudpickle.
+Ray is currently compatible with Pickle protocol version 5, while Ray supports serialization of a wider range of objects (e.g. lambda & nested functions, dynamic classes) with the help of cloudpickle.
Numpy Arrays
------------
@@ -34,23 +34,6 @@ Serialization notes
- Ray is currently using Pickle protocol version 5. The default pickle protocol used by most python distributions is protocol 3. Protocol 4 & 5 are more efficient than protocol 3 for larger objects.
-- Ray may create extra copies of simple native objects (e.g. list, and this is also the default behavior of Pickle Protocol 4 & 5), but recursive objects are treated carefully without any issues:
-
- .. code-block:: python
-
- l1 = [0]
- l2 = [l1, l1]
- l3 = ray.get(ray.put(l2))
-
- assert l2[0] is l2[1]
- assert l3[0] is l3[1] # will raise AssertionError for protocol 4 & 5, but not protocol 3
-
- l = []
- l.append(l)
-
- # Try to put this list that recursively contains itself in the object store.
- ray.put(l) # ok
-
- For non-native objects, Ray will always keep a single copy even it is referred multiple times in an object:
.. code-block:: python
@@ -64,6 +47,105 @@ Serialization notes
- Lock objects are mostly unserializable, because copying a lock is meaningless and could cause serious concurrency problems. You may have to come up with a workaround if your object contains a lock.
+Customized Serialization
+________________________
+
+Sometimes you may want to customize your serialization process because
+the default serializer used by Ray (pickle5 + cloudpickle) does
+not work for you (fail to serialize some objects, too slow for certain objects, etc.).
+
+There are at least 3 ways to define your custom serialization process:
+
+1. If you want to customize the serialization of a type of objects,
+ and you have access to the code, you can define ``__reduce__``
+ function inside the corresponding class. This is commonly done
+ by most Python libraries. Example code:
+
+.. code-block:: python
+
+ import ray
+ import sqlite3
+
+ ray.init()
+
+ class DBConnection:
+ def __init__(self, path):
+ self.path = path
+ self.conn = sqlite3.connect(path)
+
+ # without '__reduce__', the instance is unserializable.
+ def __reduce__(self):
+ deserializer = DBConnection
+ serialized_data = (self.path,)
+ return deserializer, serialized_data
+
+ original = DBConnection("/tmp/db")
+ print(original.conn)
+
+ copied = ray.get(ray.put(original))
+ print(copied.conn)
+
+2. If you want to customize the serialization of a type of objects,
+ but you cannot access or modify the corresponding class, you can
+ register the class with the serializer you use:
+
+ .. code-block:: python
+
+ import ray
+ import threading
+
+ class A:
+ def __init__(self, x):
+ self.x = x
+ self.lock = threading.Lock() # could not be serialized!
+
+ ray.get(ray.put(A(1))) # fail!
+
+ def custom_serializer(a):
+ return a.x
+
+ def custom_deserializer(b):
+ return A(b)
+
+ # Register serializer and deserializer for class A:
+ ray.util.register_serializer(
+ A, serializer=custom_serializer, deserializer=custom_deserializer)
+ ray.get(ray.put(A(1))) # success!
+
+ NOTE: Serializers are managed locally for each Ray worker. So for every Ray worker,
+ if you want to use the serializer, you need to register the serializer.
+ If you register a new serializer for a class, the new serializer would replace
+ the old serializer immediately in the worker. This API is also idempotent, there are
+ no side effects caused by re-registering the same serializer.
+
+3. We also provide you an example, if you want to customize the serialization
+ of a specific object:
+
+.. code-block:: python
+
+ import threading
+
+ class A:
+ def __init__(self, x):
+ self.x = x
+ self.lock = threading.Lock() # could not serialize!
+
+ ray.get(ray.put(A(1))) # fail!
+
+ class SerializationHelperForA:
+ """A helper class for serialization."""
+ def __init__(self, a):
+ self.a = a
+
+ def __reduce__(self):
+ return A, (self.a.x,)
+
+ ray.get(ray.put(SerializationHelperForA(A(1)))) # success!
+ # the serializer only works for a specific object, not all A
+ # instances, so we still expect failure here.
+ ray.get(ray.put(A(1))) # still fail!
+
+
Troubleshooting
---------------
diff --git a/python/ray/serialization.py b/python/ray/serialization.py
index 9a24f3ccc..724cf477e 100644
--- a/python/ray/serialization.py
+++ b/python/ray/serialization.py
@@ -1,12 +1,9 @@
-import hashlib
import logging
-import time
import threading
import ray.cloudpickle as pickle
-from ray import ray_constants, JobID
+from ray import ray_constants
import ray.utils
-from ray.utils import _random_string
from ray.gcs_utils import ErrorType
from ray.exceptions import (
RayError,
@@ -30,62 +27,10 @@ from ray._raylet import (
logger = logging.getLogger(__name__)
-class RayNotDictionarySerializable(Exception):
- pass
-
-
-# This exception is used to represent situations where cloudpickle fails to
-# pickle an object (cloudpickle can fail in many different ways).
-class CloudPickleError(Exception):
- pass
-
-
class DeserializationError(Exception):
pass
-def _try_to_compute_deterministic_class_id(cls, depth=5):
- """Attempt to produce a deterministic class ID for a given class.
-
- The goal here is for the class ID to be the same when this is run on
- different worker processes. Pickling, loading, and pickling again seems to
- produce more consistent results than simply pickling. This is a bit crazy
- and could cause problems, in which case we should revert it and figure out
- something better.
-
- Args:
- cls: The class to produce an ID for.
- depth: The number of times to repeatedly try to load and dump the
- string while trying to reach a fixed point.
-
- Returns:
- A class ID for this class. We attempt to make the class ID the same
- when this function is run on different workers, but that is not
- guaranteed.
-
- Raises:
- Exception: This could raise an exception if cloudpickle raises an
- exception.
- """
- # Pickling, loading, and pickling again seems to produce more consistent
- # results than simply pickling. This is a bit
- class_id = pickle.dumps(cls)
- for _ in range(depth):
- new_class_id = pickle.dumps(pickle.loads(class_id))
- if new_class_id == class_id:
- # We appear to have reached a fix point, so use this as the ID.
- return hashlib.shake_128(new_class_id).digest(
- ray_constants.ID_SIZE)
- class_id = new_class_id
-
- # We have not reached a fixed point, so we may end up with a different
- # class ID for this custom class on each worker, which could lead to the
- # same class definition being exported many many times.
- logger.warning(
- f"WARNING: Could not produce a deterministic class ID for class {cls}")
- return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE)
-
-
def object_ref_deserializer(reduced_obj_ref, owner_address):
# NOTE(suquark): This function should be a global function so
# cloudpickle can access it directly. Otherwise couldpickle
@@ -153,8 +98,6 @@ class SerializationContext:
worker.core_worker.serialize_and_promote_object_ref(obj))
return object_ref_deserializer, (obj.__reduce__(), owner_address)
- # Because objects have default __reduce__ method, we only need to
- # treat ObjectRef specifically.
self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
def _register_cloudpickle_reducer(self, cls, reducer):
@@ -291,48 +234,17 @@ class SerializationContext:
# throws an exception.
return PlasmaObjectNotAvailable
- def deserialize_objects(self,
- data_metadata_pairs,
- object_refs,
- error_timeout=10):
+ def deserialize_objects(self, data_metadata_pairs, object_refs):
assert len(data_metadata_pairs) == len(object_refs)
-
- start_time = time.time()
results = []
- warning_sent = False
- i = 0
- while i < len(object_refs):
- object_ref = object_refs[i]
- data, metadata = data_metadata_pairs[i]
+ for object_ref, (data, metadata) in zip(object_refs,
+ data_metadata_pairs):
assert self.get_outer_object_ref() is None
self.set_outer_object_ref(object_ref)
- try:
- results.append(
- self._deserialize_object(data, metadata, object_ref))
- i += 1
- except DeserializationError:
- # Wait a little bit for the import thread to import the class.
- # If we currently have the worker lock, we need to release it
- # so that the import thread can acquire it.
- time.sleep(0.01)
-
- if time.time() - start_time > error_timeout:
- warning_message = ("This worker or driver is waiting to "
- "receive a class definition so that it "
- "can deserialize an object from the "
- "object store. This may be fine, or it "
- "may be a bug.")
- if not warning_sent:
- ray.utils.push_error_to_driver(
- self,
- ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
- warning_message,
- job_id=self.worker.current_job_id)
- warning_sent = True
- finally:
- # Must clear ObjectRef to not hold a reference.
- self.set_outer_object_ref(None)
-
+ results.append(
+ self._deserialize_object(data, metadata, object_ref))
+ # Must clear ObjectRef to not hold a reference.
+ self.set_outer_object_ref(None)
return results
def _serialize_to_pickle5(self, metadata, value):
@@ -405,77 +317,3 @@ class SerializationContext:
return RawSerializedObject(value)
else:
return self._serialize_to_msgpack(value)
-
- def register_custom_serializer(self,
- cls,
- serializer,
- deserializer,
- local=False,
- job_id=None,
- class_id=None):
- """Enable serialization and deserialization for a particular class.
-
- This method runs the register_class function defined below on
- every worker, which will enable ray to properly serialize and
- deserialize objects of this class.
-
- Args:
- cls (type): The class that ray should use this custom serializer
- for.
- serializer: The custom serializer to use.
- deserializer: The custom deserializer to use.
- local: True if the serializers should only be registered on the
- current worker. This should usually be False.
- job_id: ID of the job that we want to register the class for.
- class_id (str): Unique ID of the class. Autogenerated if None.
-
- Raises:
- RayNotDictionarySerializable: Raised if use_dict is true and cls
- cannot be efficiently serialized by Ray.
- ValueError: Raised if ray could not autogenerate a class_id.
- """
- assert serializer is not None and deserializer is not None, (
- "Must provide serializer and deserializer.")
-
- if class_id is None:
- if not local:
- # In this case, the class ID will be used to deduplicate the
- # class across workers. Note that cloudpickle unfortunately
- # does not produce deterministic strings, so these IDs could
- # be different on different workers. We could use something
- # weaker like cls.__name__, however that would run the risk
- # of having collisions.
- # TODO(rkn): We should improve this.
- try:
- # Attempt to produce a class ID that will be the same on
- # each worker. However, determinism is not guaranteed,
- # and the result may be different on different workers.
- class_id = _try_to_compute_deterministic_class_id(cls)
- except Exception:
- raise ValueError(
- "Failed to use pickle in generating a unique id"
- f"for '{cls}'. Provide a unique class_id.")
- else:
- # In this case, the class ID only needs to be meaningful on
- # this worker and not across workers.
- class_id = _random_string()
-
- # Make sure class_id is a string.
- class_id = ray.utils.binary_to_hex(class_id)
-
- if job_id is None:
- job_id = self.worker.current_job_id
- assert isinstance(job_id, JobID)
-
- def register_class_for_serialization(worker_info):
- context = worker_info["worker"].get_serialization_context(job_id)
- context._register_cloudpickle_serializer(cls, serializer,
- deserializer)
-
- if not local:
- self.worker.run_function_on_all_workers(
- register_class_for_serialization)
- else:
- # Since we are pickling objects of this class, we don't actually
- # need to ship the class definition.
- register_class_for_serialization({"worker": self.worker})
diff --git a/python/ray/tests/test_serialization.py b/python/ray/tests/test_serialization.py
index 0c88ebd22..8c72ba209 100644
--- a/python/ray/tests/test_serialization.py
+++ b/python/ray/tests/test_serialization.py
@@ -598,6 +598,25 @@ def test_buffer_alignment(ray_start_shared_local_modes):
assert y.ctypes.data % 8 == 0
+def test_custom_serializer(ray_start_shared_local_modes):
+ import threading
+
+ class A:
+ def __init__(self, x):
+ self.x = x
+ self.lock = threading.Lock()
+
+ def custom_serializer(a):
+ return a.x
+
+ def custom_deserializer(x):
+ return A(x)
+
+ ray.util.register_serializer(
+ A, serializer=custom_serializer, deserializer=custom_deserializer)
+ ray.get(ray.put(A(1)))
+
+
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/util/__init__.py b/python/ray/util/__init__.py
index 252432ad8..b2dc97bbd 100644
--- a/python/ray/util/__init__.py
+++ b/python/ray/util/__init__.py
@@ -6,6 +6,7 @@ from ray.util.debug import log_once, disable_log_once_globally, \
from ray.util.placement_group import (placement_group, placement_group_table,
remove_placement_group)
from ray.util import rpdb as pdb
+from ray.util.serialization import register_serializer
from ray.util.client_connect import connect, disconnect
@@ -23,4 +24,5 @@ __all__ = [
"collective",
"connect",
"disconnect",
+ "register_serializer",
]
diff --git a/python/ray/util/serialization.py b/python/ray/util/serialization.py
new file mode 100644
index 000000000..a93bbab55
--- /dev/null
+++ b/python/ray/util/serialization.py
@@ -0,0 +1,18 @@
+import ray
+
+
+def register_serializer(cls, *, serializer, deserializer):
+ """Use the given serializer to serialize instances of type ``cls``,
+ and use the deserializer to deserialize the serialized object.
+
+ Args:
+ cls: A Python class/type.
+ serializer (callable): A function that converts an instances of
+ type ``cls`` into a serializable object (e.g. python dict
+ of basic objects).
+ deserializer (callable): A function that constructs the
+ instance of type ``cls`` from the serialized object.
+ This function itself must be serializable.
+ """
+ context = ray.worker.global_worker.get_serialization_context()
+ context._register_cloudpickle_serializer(cls, serializer, deserializer)