diff --git a/doc/source/serialization.rst b/doc/source/serialization.rst index 6bdada7d0..a5e58a339 100644 --- a/doc/source/serialization.rst +++ b/doc/source/serialization.rst @@ -17,9 +17,9 @@ Each node has its own object store. When data is put into the object store, it d Overview -------- -Ray has decided to use a customed `Pickle protocol version 5 `_ backport to replace the original PyArrow serializer. This gets rid of several previous limitations (e.g. cannot serialize recursive objects). +Ray has decided to use a customized `Pickle protocol version 5 `_ backport to replace the original PyArrow serializer. This gets rid of several previous limitations (e.g. cannot serialize recursive objects). -Ray is currently compatible with Pickle protocol version 5, while Ray supports serialization of a wilder range of objects (e.g. lambda & nested functions, dynamic classes) with the support of cloudpickle. +Ray is currently compatible with Pickle protocol version 5, while Ray supports serialization of a wider range of objects (e.g. lambda & nested functions, dynamic classes) with the help of cloudpickle. Numpy Arrays ------------ @@ -34,23 +34,6 @@ Serialization notes - Ray is currently using Pickle protocol version 5. The default pickle protocol used by most python distributions is protocol 3. Protocol 4 & 5 are more efficient than protocol 3 for larger objects. -- Ray may create extra copies of simple native objects (e.g. list, and this is also the default behavior of Pickle Protocol 4 & 5), but recursive objects are treated carefully without any issues: - - .. code-block:: python - - l1 = [0] - l2 = [l1, l1] - l3 = ray.get(ray.put(l2)) - - assert l2[0] is l2[1] - assert l3[0] is l3[1] # will raise AssertionError for protocol 4 & 5, but not protocol 3 - - l = [] - l.append(l) - - # Try to put this list that recursively contains itself in the object store. 
- ray.put(l) # ok - - For non-native objects, Ray will always keep a single copy even it is referred multiple times in an object: .. code-block:: python @@ -64,6 +47,105 @@ Serialization notes - Lock objects are mostly unserializable, because copying a lock is meaningless and could cause serious concurrency problems. You may have to come up with a workaround if your object contains a lock. +Customized Serialization +________________________ + +Sometimes you may want to customize your serialization process because +the default serializer used by Ray (pickle5 + cloudpickle) does +not work for you (fail to serialize some objects, too slow for certain objects, etc.). + +There are at least 3 ways to define your custom serialization process: + +1. If you want to customize the serialization of a type of objects, + and you have access to the code, you can define ``__reduce__`` + function inside the corresponding class. This is commonly done + by most Python libraries. Example code: + +.. code-block:: python + + import ray + import sqlite3 + + ray.init() + + class DBConnection: + def __init__(self, path): + self.path = path + self.conn = sqlite3.connect(path) + + # without '__reduce__', the instance is unserializable. + def __reduce__(self): + deserializer = DBConnection + serialized_data = (self.path,) + return deserializer, serialized_data + + original = DBConnection("/tmp/db") + print(original.conn) + + copied = ray.get(ray.put(original)) + print(copied.conn) + +2. If you want to customize the serialization of a type of objects, + but you cannot access or modify the corresponding class, you can + register the class with the serializer you use: + + .. code-block:: python + + import ray + import threading + + class A: + def __init__(self, x): + self.x = x + self.lock = threading.Lock() # could not be serialized! + + ray.get(ray.put(A(1))) # fail! 
+ + def custom_serializer(a): + return a.x + + def custom_deserializer(b): + return A(b) + + # Register serializer and deserializer for class A: + ray.util.register_serializer( + A, serializer=custom_serializer, deserializer=custom_deserializer) + ray.get(ray.put(A(1))) # success! + + NOTE: Serializers are managed locally for each Ray worker, so the serializer + must be registered on every Ray worker that needs to use it. + If you register a new serializer for a class, the new serializer replaces + the old serializer immediately in that worker. This API is also idempotent: + re-registering the same serializer has no side effects. + +3. If you want to customize the serialization of one specific object rather + than of every instance of its class, wrap it in a helper class. Example code: + +.. code-block:: python + + import threading + + class A: + def __init__(self, x): + self.x = x + self.lock = threading.Lock() # could not be serialized! + + ray.get(ray.put(A(1))) # fail! + + class SerializationHelperForA: + """A helper class for serialization.""" + def __init__(self, a): + self.a = a + + def __reduce__(self): + return A, (self.a.x,) + + ray.get(ray.put(SerializationHelperForA(A(1)))) # success! + # the serializer only works for a specific object, not all A + # instances, so we still expect failure here. + ray.get(ray.put(A(1))) # still fail!
+ + Troubleshooting --------------- diff --git a/python/ray/serialization.py b/python/ray/serialization.py index 9a24f3ccc..724cf477e 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -1,12 +1,9 @@ -import hashlib import logging -import time import threading import ray.cloudpickle as pickle -from ray import ray_constants, JobID +from ray import ray_constants import ray.utils -from ray.utils import _random_string from ray.gcs_utils import ErrorType from ray.exceptions import ( RayError, @@ -30,62 +27,10 @@ from ray._raylet import ( logger = logging.getLogger(__name__) -class RayNotDictionarySerializable(Exception): - pass - - -# This exception is used to represent situations where cloudpickle fails to -# pickle an object (cloudpickle can fail in many different ways). -class CloudPickleError(Exception): - pass - - class DeserializationError(Exception): pass -def _try_to_compute_deterministic_class_id(cls, depth=5): - """Attempt to produce a deterministic class ID for a given class. - - The goal here is for the class ID to be the same when this is run on - different worker processes. Pickling, loading, and pickling again seems to - produce more consistent results than simply pickling. This is a bit crazy - and could cause problems, in which case we should revert it and figure out - something better. - - Args: - cls: The class to produce an ID for. - depth: The number of times to repeatedly try to load and dump the - string while trying to reach a fixed point. - - Returns: - A class ID for this class. We attempt to make the class ID the same - when this function is run on different workers, but that is not - guaranteed. - - Raises: - Exception: This could raise an exception if cloudpickle raises an - exception. - """ - # Pickling, loading, and pickling again seems to produce more consistent - # results than simply pickling. 
This is a bit - class_id = pickle.dumps(cls) - for _ in range(depth): - new_class_id = pickle.dumps(pickle.loads(class_id)) - if new_class_id == class_id: - # We appear to have reached a fix point, so use this as the ID. - return hashlib.shake_128(new_class_id).digest( - ray_constants.ID_SIZE) - class_id = new_class_id - - # We have not reached a fixed point, so we may end up with a different - # class ID for this custom class on each worker, which could lead to the - # same class definition being exported many many times. - logger.warning( - f"WARNING: Could not produce a deterministic class ID for class {cls}") - return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE) - - def object_ref_deserializer(reduced_obj_ref, owner_address): # NOTE(suquark): This function should be a global function so # cloudpickle can access it directly. Otherwise couldpickle @@ -153,8 +98,6 @@ class SerializationContext: worker.core_worker.serialize_and_promote_object_ref(obj)) return object_ref_deserializer, (obj.__reduce__(), owner_address) - # Because objects have default __reduce__ method, we only need to - # treat ObjectRef specifically. self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer) def _register_cloudpickle_reducer(self, cls, reducer): @@ -291,48 +234,17 @@ class SerializationContext: # throws an exception. 
return PlasmaObjectNotAvailable - def deserialize_objects(self, - data_metadata_pairs, - object_refs, - error_timeout=10): + def deserialize_objects(self, data_metadata_pairs, object_refs): assert len(data_metadata_pairs) == len(object_refs) - - start_time = time.time() results = [] - warning_sent = False - i = 0 - while i < len(object_refs): - object_ref = object_refs[i] - data, metadata = data_metadata_pairs[i] + for object_ref, (data, metadata) in zip(object_refs, + data_metadata_pairs): assert self.get_outer_object_ref() is None self.set_outer_object_ref(object_ref) - try: - results.append( - self._deserialize_object(data, metadata, object_ref)) - i += 1 - except DeserializationError: - # Wait a little bit for the import thread to import the class. - # If we currently have the worker lock, we need to release it - # so that the import thread can acquire it. - time.sleep(0.01) - - if time.time() - start_time > error_timeout: - warning_message = ("This worker or driver is waiting to " - "receive a class definition so that it " - "can deserialize an object from the " - "object store. This may be fine, or it " - "may be a bug.") - if not warning_sent: - ray.utils.push_error_to_driver( - self, - ray_constants.WAIT_FOR_CLASS_PUSH_ERROR, - warning_message, - job_id=self.worker.current_job_id) - warning_sent = True - finally: - # Must clear ObjectRef to not hold a reference. - self.set_outer_object_ref(None) - + results.append( + self._deserialize_object(data, metadata, object_ref)) + # Must clear ObjectRef to not hold a reference. + self.set_outer_object_ref(None) return results def _serialize_to_pickle5(self, metadata, value): @@ -405,77 +317,3 @@ class SerializationContext: return RawSerializedObject(value) else: return self._serialize_to_msgpack(value) - - def register_custom_serializer(self, - cls, - serializer, - deserializer, - local=False, - job_id=None, - class_id=None): - """Enable serialization and deserialization for a particular class. 
- - This method runs the register_class function defined below on - every worker, which will enable ray to properly serialize and - deserialize objects of this class. - - Args: - cls (type): The class that ray should use this custom serializer - for. - serializer: The custom serializer to use. - deserializer: The custom deserializer to use. - local: True if the serializers should only be registered on the - current worker. This should usually be False. - job_id: ID of the job that we want to register the class for. - class_id (str): Unique ID of the class. Autogenerated if None. - - Raises: - RayNotDictionarySerializable: Raised if use_dict is true and cls - cannot be efficiently serialized by Ray. - ValueError: Raised if ray could not autogenerate a class_id. - """ - assert serializer is not None and deserializer is not None, ( - "Must provide serializer and deserializer.") - - if class_id is None: - if not local: - # In this case, the class ID will be used to deduplicate the - # class across workers. Note that cloudpickle unfortunately - # does not produce deterministic strings, so these IDs could - # be different on different workers. We could use something - # weaker like cls.__name__, however that would run the risk - # of having collisions. - # TODO(rkn): We should improve this. - try: - # Attempt to produce a class ID that will be the same on - # each worker. However, determinism is not guaranteed, - # and the result may be different on different workers. - class_id = _try_to_compute_deterministic_class_id(cls) - except Exception: - raise ValueError( - "Failed to use pickle in generating a unique id" - f"for '{cls}'. Provide a unique class_id.") - else: - # In this case, the class ID only needs to be meaningful on - # this worker and not across workers. - class_id = _random_string() - - # Make sure class_id is a string. 
- class_id = ray.utils.binary_to_hex(class_id) - - if job_id is None: - job_id = self.worker.current_job_id - assert isinstance(job_id, JobID) - - def register_class_for_serialization(worker_info): - context = worker_info["worker"].get_serialization_context(job_id) - context._register_cloudpickle_serializer(cls, serializer, - deserializer) - - if not local: - self.worker.run_function_on_all_workers( - register_class_for_serialization) - else: - # Since we are pickling objects of this class, we don't actually - # need to ship the class definition. - register_class_for_serialization({"worker": self.worker}) diff --git a/python/ray/tests/test_serialization.py b/python/ray/tests/test_serialization.py index 0c88ebd22..8c72ba209 100644 --- a/python/ray/tests/test_serialization.py +++ b/python/ray/tests/test_serialization.py @@ -598,6 +598,25 @@ def test_buffer_alignment(ray_start_shared_local_modes): assert y.ctypes.data % 8 == 0 +def test_custom_serializer(ray_start_shared_local_modes): + import threading + + class A: + def __init__(self, x): + self.x = x + self.lock = threading.Lock() + + def custom_serializer(a): + return a.x + + def custom_deserializer(x): + return A(x) + + ray.util.register_serializer( + A, serializer=custom_serializer, deserializer=custom_deserializer) + ray.get(ray.put(A(1))) + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/__init__.py b/python/ray/util/__init__.py index 252432ad8..b2dc97bbd 100644 --- a/python/ray/util/__init__.py +++ b/python/ray/util/__init__.py @@ -6,6 +6,7 @@ from ray.util.debug import log_once, disable_log_once_globally, \ from ray.util.placement_group import (placement_group, placement_group_table, remove_placement_group) from ray.util import rpdb as pdb +from ray.util.serialization import register_serializer from ray.util.client_connect import connect, disconnect @@ -23,4 +24,5 @@ __all__ = [ "collective", "connect", "disconnect", + "register_serializer", ] 
diff --git a/python/ray/util/serialization.py b/python/ray/util/serialization.py new file mode 100644 index 000000000..a93bbab55 --- /dev/null +++ b/python/ray/util/serialization.py @@ -0,0 +1,18 @@ +import ray + + +def register_serializer(cls, *, serializer, deserializer): + """Use the given serializer to serialize instances of type ``cls``, + and use the deserializer to deserialize the serialized object. + + Args: + cls: A Python class/type. + serializer (callable): A function that converts an instance of + type ``cls`` into a serializable object (e.g. python dict + of basic objects). + deserializer (callable): A function that constructs the + instance of type ``cls`` from the serialized object. + This function itself must be serializable. + """ + context = ray.worker.global_worker.get_serialization_context() + context._register_cloudpickle_serializer(cls, serializer, deserializer)