mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 17:49:47 +08:00
[Serialization] New custom serialization API (#13291)
* new serialization API with doc & test * add more notes * refine notes * doc
This commit is contained in:
committed by
GitHub
parent
07e97fe4c2
commit
d1e9887be2
+101
-19
@@ -17,9 +17,9 @@ Each node has its own object store. When data is put into the object store, it d
|
||||
Overview
|
||||
--------
|
||||
|
||||
Ray has decided to use a customed `Pickle protocol version 5 <https://www.python.org/dev/peps/pep-0574/>`_ backport to replace the original PyArrow serializer. This gets rid of several previous limitations (e.g. cannot serialize recursive objects).
|
||||
Ray has decided to use a customized `Pickle protocol version 5 <https://www.python.org/dev/peps/pep-0574/>`_ backport to replace the original PyArrow serializer. This gets rid of several previous limitations (e.g. cannot serialize recursive objects).
|
||||
|
||||
Ray is currently compatible with Pickle protocol version 5, while Ray supports serialization of a wilder range of objects (e.g. lambda & nested functions, dynamic classes) with the support of cloudpickle.
|
||||
Ray is currently compatible with Pickle protocol version 5, while Ray supports serialization of a wider range of objects (e.g. lambda & nested functions, dynamic classes) with the help of cloudpickle.
|
||||
|
||||
Numpy Arrays
|
||||
------------
|
||||
@@ -34,23 +34,6 @@ Serialization notes
|
||||
|
||||
- Ray is currently using Pickle protocol version 5. The default pickle protocol used by most python distributions is protocol 3. Protocol 4 & 5 are more efficient than protocol 3 for larger objects.
|
||||
|
||||
- Ray may create extra copies of simple native objects (e.g. list, and this is also the default behavior of Pickle Protocol 4 & 5), but recursive objects are treated carefully without any issues:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
l1 = [0]
|
||||
l2 = [l1, l1]
|
||||
l3 = ray.get(ray.put(l2))
|
||||
|
||||
assert l2[0] is l2[1]
|
||||
assert l3[0] is l3[1] # will raise AssertionError for protocol 4 & 5, but not protocol 3
|
||||
|
||||
l = []
|
||||
l.append(l)
|
||||
|
||||
# Try to put this list that recursively contains itself in the object store.
|
||||
ray.put(l) # ok
|
||||
|
||||
- For non-native objects, Ray will always keep a single copy even if it is referenced multiple times in an object:
|
||||
|
||||
.. code-block:: python
|
||||
@@ -64,6 +47,105 @@ Serialization notes
|
||||
|
||||
- Lock objects are mostly unserializable, because copying a lock is meaningless and could cause serious concurrency problems. You may have to come up with a workaround if your object contains a lock.
|
||||
|
||||
Customized Serialization
|
||||
________________________
|
||||
|
||||
Sometimes you may want to customize your serialization process because
|
||||
the default serializer used by Ray (pickle5 + cloudpickle) does
|
||||
not work for you (it fails to serialize some objects, is too slow for certain objects, etc.).
|
||||
|
||||
There are at least 3 ways to define your custom serialization process:
|
||||
|
||||
1. If you want to customize the serialization of a type of objects,
|
||||
and you have access to the code, you can define a ``__reduce__``
|
||||
function inside the corresponding class. This is commonly done
|
||||
by most Python libraries. Example code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ray
|
||||
import sqlite3
|
||||
|
||||
ray.init()
|
||||
|
||||
class DBConnection:
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.conn = sqlite3.connect(path)
|
||||
|
||||
# without '__reduce__', the instance is unserializable.
|
||||
def __reduce__(self):
|
||||
deserializer = DBConnection
|
||||
serialized_data = (self.path,)
|
||||
return deserializer, serialized_data
|
||||
|
||||
original = DBConnection("/tmp/db")
|
||||
print(original.conn)
|
||||
|
||||
copied = ray.get(ray.put(original))
|
||||
print(copied.conn)
|
||||
|
||||
2. If you want to customize the serialization of a type of objects,
|
||||
but you cannot access or modify the corresponding class, you can
|
||||
register the class with the serializer you use:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ray
|
||||
import threading
|
||||
|
||||
class A:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
self.lock = threading.Lock() # could not be serialized!
|
||||
|
||||
ray.get(ray.put(A(1))) # fail!
|
||||
|
||||
def custom_serializer(a):
|
||||
return a.x
|
||||
|
||||
def custom_deserializer(b):
|
||||
return A(b)
|
||||
|
||||
# Register serializer and deserializer for class A:
|
||||
ray.util.register_serializer(
|
||||
A, serializer=custom_serializer, deserializer=custom_deserializer)
|
||||
ray.get(ray.put(A(1))) # success!
|
||||
|
||||
NOTE: Serializers are managed locally for each Ray worker. So for every Ray worker,
|
||||
if you want to use the serializer, you need to register the serializer.
|
||||
If you register a new serializer for a class, the new serializer would replace
|
||||
the old serializer immediately in the worker. This API is also idempotent: there are
|
||||
no side effects caused by re-registering the same serializer.
|
||||
|
||||
3. If you want to customize the serialization of one specific object
|
||||
rather than every instance of its class, you can wrap it in a helper class:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import threading
|
||||
|
||||
class A:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
self.lock = threading.Lock() # could not be serialized!
|
||||
|
||||
ray.get(ray.put(A(1))) # fail!
|
||||
|
||||
class SerializationHelperForA:
|
||||
"""A helper class for serialization."""
|
||||
def __init__(self, a):
|
||||
self.a = a
|
||||
|
||||
def __reduce__(self):
|
||||
return A, (self.a.x,)
|
||||
|
||||
ray.get(ray.put(SerializationHelperForA(A(1)))) # success!
|
||||
# the serializer only works for a specific object, not all A
|
||||
# instances, so we still expect failure here.
|
||||
ray.get(ray.put(A(1))) # still fail!
|
||||
|
||||
|
||||
Troubleshooting
|
||||
---------------
|
||||
|
||||
|
||||
+8
-170
@@ -1,12 +1,9 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
|
||||
import ray.cloudpickle as pickle
|
||||
from ray import ray_constants, JobID
|
||||
from ray import ray_constants
|
||||
import ray.utils
|
||||
from ray.utils import _random_string
|
||||
from ray.gcs_utils import ErrorType
|
||||
from ray.exceptions import (
|
||||
RayError,
|
||||
@@ -30,62 +27,10 @@ from ray._raylet import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayNotDictionarySerializable(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# This exception is used to represent situations where cloudpickle fails to
|
||||
# pickle an object (cloudpickle can fail in many different ways).
|
||||
class CloudPickleError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DeserializationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _try_to_compute_deterministic_class_id(cls, depth=5):
|
||||
"""Attempt to produce a deterministic class ID for a given class.
|
||||
|
||||
The goal here is for the class ID to be the same when this is run on
|
||||
different worker processes. Pickling, loading, and pickling again seems to
|
||||
produce more consistent results than simply pickling. This is a bit crazy
|
||||
and could cause problems, in which case we should revert it and figure out
|
||||
something better.
|
||||
|
||||
Args:
|
||||
cls: The class to produce an ID for.
|
||||
depth: The number of times to repeatedly try to load and dump the
|
||||
string while trying to reach a fixed point.
|
||||
|
||||
Returns:
|
||||
A class ID for this class. We attempt to make the class ID the same
|
||||
when this function is run on different workers, but that is not
|
||||
guaranteed.
|
||||
|
||||
Raises:
|
||||
Exception: This could raise an exception if cloudpickle raises an
|
||||
exception.
|
||||
"""
|
||||
# Pickling, loading, and pickling again seems to produce more consistent
|
||||
# results than simply pickling. This is a bit
|
||||
class_id = pickle.dumps(cls)
|
||||
for _ in range(depth):
|
||||
new_class_id = pickle.dumps(pickle.loads(class_id))
|
||||
if new_class_id == class_id:
|
||||
# We appear to have reached a fix point, so use this as the ID.
|
||||
return hashlib.shake_128(new_class_id).digest(
|
||||
ray_constants.ID_SIZE)
|
||||
class_id = new_class_id
|
||||
|
||||
# We have not reached a fixed point, so we may end up with a different
|
||||
# class ID for this custom class on each worker, which could lead to the
|
||||
# same class definition being exported many many times.
|
||||
logger.warning(
|
||||
f"WARNING: Could not produce a deterministic class ID for class {cls}")
|
||||
return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE)
|
||||
|
||||
|
||||
def object_ref_deserializer(reduced_obj_ref, owner_address):
|
||||
# NOTE(suquark): This function should be a global function so
|
||||
# cloudpickle can access it directly. Otherwise cloudpickle
|
||||
@@ -153,8 +98,6 @@ class SerializationContext:
|
||||
worker.core_worker.serialize_and_promote_object_ref(obj))
|
||||
return object_ref_deserializer, (obj.__reduce__(), owner_address)
|
||||
|
||||
# Because objects have default __reduce__ method, we only need to
|
||||
# treat ObjectRef specifically.
|
||||
self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
|
||||
|
||||
def _register_cloudpickle_reducer(self, cls, reducer):
|
||||
@@ -291,48 +234,17 @@ class SerializationContext:
|
||||
# throws an exception.
|
||||
return PlasmaObjectNotAvailable
|
||||
|
||||
def deserialize_objects(self,
|
||||
data_metadata_pairs,
|
||||
object_refs,
|
||||
error_timeout=10):
|
||||
def deserialize_objects(self, data_metadata_pairs, object_refs):
|
||||
assert len(data_metadata_pairs) == len(object_refs)
|
||||
|
||||
start_time = time.time()
|
||||
results = []
|
||||
warning_sent = False
|
||||
i = 0
|
||||
while i < len(object_refs):
|
||||
object_ref = object_refs[i]
|
||||
data, metadata = data_metadata_pairs[i]
|
||||
for object_ref, (data, metadata) in zip(object_refs,
|
||||
data_metadata_pairs):
|
||||
assert self.get_outer_object_ref() is None
|
||||
self.set_outer_object_ref(object_ref)
|
||||
try:
|
||||
results.append(
|
||||
self._deserialize_object(data, metadata, object_ref))
|
||||
i += 1
|
||||
except DeserializationError:
|
||||
# Wait a little bit for the import thread to import the class.
|
||||
# If we currently have the worker lock, we need to release it
|
||||
# so that the import thread can acquire it.
|
||||
time.sleep(0.01)
|
||||
|
||||
if time.time() - start_time > error_timeout:
|
||||
warning_message = ("This worker or driver is waiting to "
|
||||
"receive a class definition so that it "
|
||||
"can deserialize an object from the "
|
||||
"object store. This may be fine, or it "
|
||||
"may be a bug.")
|
||||
if not warning_sent:
|
||||
ray.utils.push_error_to_driver(
|
||||
self,
|
||||
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
|
||||
warning_message,
|
||||
job_id=self.worker.current_job_id)
|
||||
warning_sent = True
|
||||
finally:
|
||||
# Must clear ObjectRef to not hold a reference.
|
||||
self.set_outer_object_ref(None)
|
||||
|
||||
results.append(
|
||||
self._deserialize_object(data, metadata, object_ref))
|
||||
# Must clear ObjectRef to not hold a reference.
|
||||
self.set_outer_object_ref(None)
|
||||
return results
|
||||
|
||||
def _serialize_to_pickle5(self, metadata, value):
|
||||
@@ -405,77 +317,3 @@ class SerializationContext:
|
||||
return RawSerializedObject(value)
|
||||
else:
|
||||
return self._serialize_to_msgpack(value)
|
||||
|
||||
def register_custom_serializer(self,
|
||||
cls,
|
||||
serializer,
|
||||
deserializer,
|
||||
local=False,
|
||||
job_id=None,
|
||||
class_id=None):
|
||||
"""Enable serialization and deserialization for a particular class.
|
||||
|
||||
This method runs the register_class function defined below on
|
||||
every worker, which will enable ray to properly serialize and
|
||||
deserialize objects of this class.
|
||||
|
||||
Args:
|
||||
cls (type): The class that ray should use this custom serializer
|
||||
for.
|
||||
serializer: The custom serializer to use.
|
||||
deserializer: The custom deserializer to use.
|
||||
local: True if the serializers should only be registered on the
|
||||
current worker. This should usually be False.
|
||||
job_id: ID of the job that we want to register the class for.
|
||||
class_id (str): Unique ID of the class. Autogenerated if None.
|
||||
|
||||
Raises:
|
||||
RayNotDictionarySerializable: Raised if use_dict is true and cls
|
||||
cannot be efficiently serialized by Ray.
|
||||
ValueError: Raised if ray could not autogenerate a class_id.
|
||||
"""
|
||||
assert serializer is not None and deserializer is not None, (
|
||||
"Must provide serializer and deserializer.")
|
||||
|
||||
if class_id is None:
|
||||
if not local:
|
||||
# In this case, the class ID will be used to deduplicate the
|
||||
# class across workers. Note that cloudpickle unfortunately
|
||||
# does not produce deterministic strings, so these IDs could
|
||||
# be different on different workers. We could use something
|
||||
# weaker like cls.__name__, however that would run the risk
|
||||
# of having collisions.
|
||||
# TODO(rkn): We should improve this.
|
||||
try:
|
||||
# Attempt to produce a class ID that will be the same on
|
||||
# each worker. However, determinism is not guaranteed,
|
||||
# and the result may be different on different workers.
|
||||
class_id = _try_to_compute_deterministic_class_id(cls)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"Failed to use pickle in generating a unique id"
|
||||
f"for '{cls}'. Provide a unique class_id.")
|
||||
else:
|
||||
# In this case, the class ID only needs to be meaningful on
|
||||
# this worker and not across workers.
|
||||
class_id = _random_string()
|
||||
|
||||
# Make sure class_id is a string.
|
||||
class_id = ray.utils.binary_to_hex(class_id)
|
||||
|
||||
if job_id is None:
|
||||
job_id = self.worker.current_job_id
|
||||
assert isinstance(job_id, JobID)
|
||||
|
||||
def register_class_for_serialization(worker_info):
|
||||
context = worker_info["worker"].get_serialization_context(job_id)
|
||||
context._register_cloudpickle_serializer(cls, serializer,
|
||||
deserializer)
|
||||
|
||||
if not local:
|
||||
self.worker.run_function_on_all_workers(
|
||||
register_class_for_serialization)
|
||||
else:
|
||||
# Since we are pickling objects of this class, we don't actually
|
||||
# need to ship the class definition.
|
||||
register_class_for_serialization({"worker": self.worker})
|
||||
|
||||
@@ -598,6 +598,25 @@ def test_buffer_alignment(ray_start_shared_local_modes):
|
||||
assert y.ctypes.data % 8 == 0
|
||||
|
||||
|
||||
def test_custom_serializer(ray_start_shared_local_modes):
|
||||
import threading
|
||||
|
||||
class A:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def custom_serializer(a):
|
||||
return a.x
|
||||
|
||||
def custom_deserializer(x):
|
||||
return A(x)
|
||||
|
||||
ray.util.register_serializer(
|
||||
A, serializer=custom_serializer, deserializer=custom_deserializer)
|
||||
ray.get(ray.put(A(1)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -6,6 +6,7 @@ from ray.util.debug import log_once, disable_log_once_globally, \
|
||||
from ray.util.placement_group import (placement_group, placement_group_table,
|
||||
remove_placement_group)
|
||||
from ray.util import rpdb as pdb
|
||||
from ray.util.serialization import register_serializer
|
||||
|
||||
from ray.util.client_connect import connect, disconnect
|
||||
|
||||
@@ -23,4 +24,5 @@ __all__ = [
|
||||
"collective",
|
||||
"connect",
|
||||
"disconnect",
|
||||
"register_serializer",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
import ray
|
||||
|
||||
|
||||
def register_serializer(cls, *, serializer, deserializer):
|
||||
"""Use the given serializer to serialize instances of type ``cls``,
|
||||
and use the deserializer to deserialize the serialized object.
|
||||
|
||||
Args:
|
||||
cls: A Python class/type.
|
||||
serializer (callable): A function that converts an instance of
|
||||
type ``cls`` into a serializable object (e.g. python dict
|
||||
of basic objects).
|
||||
deserializer (callable): A function that constructs the
|
||||
instance of type ``cls`` from the serialized object.
|
||||
This function itself must be serializable.
|
||||
"""
|
||||
context = ray.worker.global_worker.get_serialization_context()
|
||||
context._register_cloudpickle_serializer(cls, serializer, deserializer)
|
||||
Reference in New Issue
Block a user