mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
[Serialization] New custom serialization API (#13291)
* new serialization API with doc & test * add more notes * refine notes * doc
This commit is contained in:
committed by
GitHub
parent
07e97fe4c2
commit
d1e9887be2
+8
-170
@@ -1,12 +1,9 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
|
||||
import ray.cloudpickle as pickle
|
||||
from ray import ray_constants, JobID
|
||||
from ray import ray_constants
|
||||
import ray.utils
|
||||
from ray.utils import _random_string
|
||||
from ray.gcs_utils import ErrorType
|
||||
from ray.exceptions import (
|
||||
RayError,
|
||||
@@ -30,62 +27,10 @@ from ray._raylet import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayNotDictionarySerializable(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# This exception is used to represent situations where cloudpickle fails to
|
||||
# pickle an object (cloudpickle can fail in many different ways).
|
||||
class CloudPickleError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DeserializationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _try_to_compute_deterministic_class_id(cls, depth=5):
|
||||
"""Attempt to produce a deterministic class ID for a given class.
|
||||
|
||||
The goal here is for the class ID to be the same when this is run on
|
||||
different worker processes. Pickling, loading, and pickling again seems to
|
||||
produce more consistent results than simply pickling. This is a bit crazy
|
||||
and could cause problems, in which case we should revert it and figure out
|
||||
something better.
|
||||
|
||||
Args:
|
||||
cls: The class to produce an ID for.
|
||||
depth: The number of times to repeatedly try to load and dump the
|
||||
string while trying to reach a fixed point.
|
||||
|
||||
Returns:
|
||||
A class ID for this class. We attempt to make the class ID the same
|
||||
when this function is run on different workers, but that is not
|
||||
guaranteed.
|
||||
|
||||
Raises:
|
||||
Exception: This could raise an exception if cloudpickle raises an
|
||||
exception.
|
||||
"""
|
||||
# Pickling, loading, and pickling again seems to produce more consistent
|
||||
# results than simply pickling. This is a bit
|
||||
class_id = pickle.dumps(cls)
|
||||
for _ in range(depth):
|
||||
new_class_id = pickle.dumps(pickle.loads(class_id))
|
||||
if new_class_id == class_id:
|
||||
# We appear to have reached a fix point, so use this as the ID.
|
||||
return hashlib.shake_128(new_class_id).digest(
|
||||
ray_constants.ID_SIZE)
|
||||
class_id = new_class_id
|
||||
|
||||
# We have not reached a fixed point, so we may end up with a different
|
||||
# class ID for this custom class on each worker, which could lead to the
|
||||
# same class definition being exported many many times.
|
||||
logger.warning(
|
||||
f"WARNING: Could not produce a deterministic class ID for class {cls}")
|
||||
return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE)
|
||||
|
||||
|
||||
def object_ref_deserializer(reduced_obj_ref, owner_address):
|
||||
# NOTE(suquark): This function should be a global function so
|
||||
# cloudpickle can access it directly. Otherwise couldpickle
|
||||
@@ -153,8 +98,6 @@ class SerializationContext:
|
||||
worker.core_worker.serialize_and_promote_object_ref(obj))
|
||||
return object_ref_deserializer, (obj.__reduce__(), owner_address)
|
||||
|
||||
# Because objects have default __reduce__ method, we only need to
|
||||
# treat ObjectRef specifically.
|
||||
self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
|
||||
|
||||
def _register_cloudpickle_reducer(self, cls, reducer):
|
||||
@@ -291,48 +234,17 @@ class SerializationContext:
|
||||
# throws an exception.
|
||||
return PlasmaObjectNotAvailable
|
||||
|
||||
def deserialize_objects(self,
|
||||
data_metadata_pairs,
|
||||
object_refs,
|
||||
error_timeout=10):
|
||||
def deserialize_objects(self, data_metadata_pairs, object_refs):
|
||||
assert len(data_metadata_pairs) == len(object_refs)
|
||||
|
||||
start_time = time.time()
|
||||
results = []
|
||||
warning_sent = False
|
||||
i = 0
|
||||
while i < len(object_refs):
|
||||
object_ref = object_refs[i]
|
||||
data, metadata = data_metadata_pairs[i]
|
||||
for object_ref, (data, metadata) in zip(object_refs,
|
||||
data_metadata_pairs):
|
||||
assert self.get_outer_object_ref() is None
|
||||
self.set_outer_object_ref(object_ref)
|
||||
try:
|
||||
results.append(
|
||||
self._deserialize_object(data, metadata, object_ref))
|
||||
i += 1
|
||||
except DeserializationError:
|
||||
# Wait a little bit for the import thread to import the class.
|
||||
# If we currently have the worker lock, we need to release it
|
||||
# so that the import thread can acquire it.
|
||||
time.sleep(0.01)
|
||||
|
||||
if time.time() - start_time > error_timeout:
|
||||
warning_message = ("This worker or driver is waiting to "
|
||||
"receive a class definition so that it "
|
||||
"can deserialize an object from the "
|
||||
"object store. This may be fine, or it "
|
||||
"may be a bug.")
|
||||
if not warning_sent:
|
||||
ray.utils.push_error_to_driver(
|
||||
self,
|
||||
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
|
||||
warning_message,
|
||||
job_id=self.worker.current_job_id)
|
||||
warning_sent = True
|
||||
finally:
|
||||
# Must clear ObjectRef to not hold a reference.
|
||||
self.set_outer_object_ref(None)
|
||||
|
||||
results.append(
|
||||
self._deserialize_object(data, metadata, object_ref))
|
||||
# Must clear ObjectRef to not hold a reference.
|
||||
self.set_outer_object_ref(None)
|
||||
return results
|
||||
|
||||
def _serialize_to_pickle5(self, metadata, value):
|
||||
@@ -405,77 +317,3 @@ class SerializationContext:
|
||||
return RawSerializedObject(value)
|
||||
else:
|
||||
return self._serialize_to_msgpack(value)
|
||||
|
||||
def register_custom_serializer(self,
|
||||
cls,
|
||||
serializer,
|
||||
deserializer,
|
||||
local=False,
|
||||
job_id=None,
|
||||
class_id=None):
|
||||
"""Enable serialization and deserialization for a particular class.
|
||||
|
||||
This method runs the register_class function defined below on
|
||||
every worker, which will enable ray to properly serialize and
|
||||
deserialize objects of this class.
|
||||
|
||||
Args:
|
||||
cls (type): The class that ray should use this custom serializer
|
||||
for.
|
||||
serializer: The custom serializer to use.
|
||||
deserializer: The custom deserializer to use.
|
||||
local: True if the serializers should only be registered on the
|
||||
current worker. This should usually be False.
|
||||
job_id: ID of the job that we want to register the class for.
|
||||
class_id (str): Unique ID of the class. Autogenerated if None.
|
||||
|
||||
Raises:
|
||||
RayNotDictionarySerializable: Raised if use_dict is true and cls
|
||||
cannot be efficiently serialized by Ray.
|
||||
ValueError: Raised if ray could not autogenerate a class_id.
|
||||
"""
|
||||
assert serializer is not None and deserializer is not None, (
|
||||
"Must provide serializer and deserializer.")
|
||||
|
||||
if class_id is None:
|
||||
if not local:
|
||||
# In this case, the class ID will be used to deduplicate the
|
||||
# class across workers. Note that cloudpickle unfortunately
|
||||
# does not produce deterministic strings, so these IDs could
|
||||
# be different on different workers. We could use something
|
||||
# weaker like cls.__name__, however that would run the risk
|
||||
# of having collisions.
|
||||
# TODO(rkn): We should improve this.
|
||||
try:
|
||||
# Attempt to produce a class ID that will be the same on
|
||||
# each worker. However, determinism is not guaranteed,
|
||||
# and the result may be different on different workers.
|
||||
class_id = _try_to_compute_deterministic_class_id(cls)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"Failed to use pickle in generating a unique id"
|
||||
f"for '{cls}'. Provide a unique class_id.")
|
||||
else:
|
||||
# In this case, the class ID only needs to be meaningful on
|
||||
# this worker and not across workers.
|
||||
class_id = _random_string()
|
||||
|
||||
# Make sure class_id is a string.
|
||||
class_id = ray.utils.binary_to_hex(class_id)
|
||||
|
||||
if job_id is None:
|
||||
job_id = self.worker.current_job_id
|
||||
assert isinstance(job_id, JobID)
|
||||
|
||||
def register_class_for_serialization(worker_info):
|
||||
context = worker_info["worker"].get_serialization_context(job_id)
|
||||
context._register_cloudpickle_serializer(cls, serializer,
|
||||
deserializer)
|
||||
|
||||
if not local:
|
||||
self.worker.run_function_on_all_workers(
|
||||
register_class_for_serialization)
|
||||
else:
|
||||
# Since we are pickling objects of this class, we don't actually
|
||||
# need to ship the class definition.
|
||||
register_class_for_serialization({"worker": self.worker})
|
||||
|
||||
@@ -598,6 +598,25 @@ def test_buffer_alignment(ray_start_shared_local_modes):
|
||||
assert y.ctypes.data % 8 == 0
|
||||
|
||||
|
||||
def test_custom_serializer(ray_start_shared_local_modes):
|
||||
import threading
|
||||
|
||||
class A:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def custom_serializer(a):
|
||||
return a.x
|
||||
|
||||
def custom_deserializer(x):
|
||||
return A(x)
|
||||
|
||||
ray.util.register_serializer(
|
||||
A, serializer=custom_serializer, deserializer=custom_deserializer)
|
||||
ray.get(ray.put(A(1)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -6,6 +6,7 @@ from ray.util.debug import log_once, disable_log_once_globally, \
|
||||
from ray.util.placement_group import (placement_group, placement_group_table,
|
||||
remove_placement_group)
|
||||
from ray.util import rpdb as pdb
|
||||
from ray.util.serialization import register_serializer
|
||||
|
||||
from ray.util.client_connect import connect, disconnect
|
||||
|
||||
@@ -23,4 +24,5 @@ __all__ = [
|
||||
"collective",
|
||||
"connect",
|
||||
"disconnect",
|
||||
"register_serializer",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
import ray
|
||||
|
||||
|
||||
def register_serializer(cls, *, serializer, deserializer):
|
||||
"""Use the given serializer to serialize instances of type ``cls``,
|
||||
and use the deserializer to deserialize the serialized object.
|
||||
|
||||
Args:
|
||||
cls: A Python class/type.
|
||||
serializer (callable): A function that converts an instances of
|
||||
type ``cls`` into a serializable object (e.g. python dict
|
||||
of basic objects).
|
||||
deserializer (callable): A function that constructs the
|
||||
instance of type ``cls`` from the serialized object.
|
||||
This function itself must be serializable.
|
||||
"""
|
||||
context = ray.worker.global_worker.get_serialization_context()
|
||||
context._register_cloudpickle_serializer(cls, serializer, deserializer)
|
||||
Reference in New Issue
Block a user