[Serialization] API for deregistering serializers; code & doc cleanup (#13471)

* make methods private, remove confusion brackets and usages

* unregister serializer; fix doc

* Cleanup doc

* rename unregister -> deregister
This commit is contained in:
Siyuan (Ryans) Zhuang
2021-01-29 10:27:05 -08:00
committed by GitHub
parent b20a38febb
commit 0b598c0f05
6 changed files with 85 additions and 56 deletions
+1 -1
View File
@@ -937,7 +937,7 @@ class ActorHandle:
def __reduce__(self):
"""This code path is used by pickling but not by Ray forking."""
state = self._serialization_helper()
return ActorHandle._deserialization_helper, (state)
return ActorHandle._deserialization_helper, state
def modify_class(cls):
+9 -8
View File
@@ -31,7 +31,7 @@ class DeserializationError(Exception):
pass
def object_ref_deserializer(reduced_obj_ref, owner_address):
def _object_ref_deserializer(binary, owner_address):
# NOTE(suquark): This function should be a global function so
# cloudpickle can access it directly. Otherwise couldpickle
# has to dump the whole function definition, which is inefficient.
@@ -40,9 +40,7 @@ def object_ref_deserializer(reduced_obj_ref, owner_address):
# the core worker to resolve the value. This is to make sure
# that the ref count for the ObjectRef is greater than 0 by the
# time the core worker resolves the value of the object.
# UniqueIDs are serialized as (class name, (unique bytes,)).
obj_ref = reduced_obj_ref[0](*reduced_obj_ref[1])
obj_ref = ray.ObjectRef(binary)
# TODO(edoakes): we should be able to just capture a reference
# to 'self' here instead, but this function is itself pickled
@@ -61,7 +59,7 @@ def object_ref_deserializer(reduced_obj_ref, owner_address):
return obj_ref
def actor_handle_deserializer(serialized_obj):
def _actor_handle_deserializer(serialized_obj):
# If this actor handle was stored in another object, then tell the
# core worker.
context = ray.worker.global_worker.get_serialization_context()
@@ -85,7 +83,7 @@ class SerializationContext:
serialized, actor_handle_id = obj._serialization_helper()
# Update ref counting for the actor handle
self.add_contained_object_ref(actor_handle_id)
return actor_handle_deserializer, (serialized, )
return _actor_handle_deserializer, (serialized, )
self._register_cloudpickle_reducer(ray.actor.ActorHandle,
actor_handle_reducer)
@@ -96,13 +94,16 @@ class SerializationContext:
worker.check_connected()
obj, owner_address = (
worker.core_worker.serialize_and_promote_object_ref(obj))
return object_ref_deserializer, (obj.__reduce__(), owner_address)
return _object_ref_deserializer, (obj.binary(), owner_address)
self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
def _register_cloudpickle_reducer(self, cls, reducer):
pickle.CloudPickler.dispatch[cls] = reducer
def _unregister_cloudpickle_reducer(self, cls):
pickle.CloudPickler.dispatch.pop(cls, None)
def _register_cloudpickle_serializer(self, cls, custom_serializer,
custom_deserializer):
def _CloudPicklerReducer(obj):
@@ -198,7 +199,7 @@ class SerializationContext:
elif metadata_fields[
0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE:
obj = self._deserialize_msgpack_data(data, metadata_fields)
return actor_handle_deserializer(obj)
return _actor_handle_deserializer(obj)
# Otherwise, return an exception object based on
# the error type.
try:
+7
View File
@@ -616,6 +616,13 @@ def test_custom_serializer(ray_start_shared_local_modes):
A, serializer=custom_serializer, deserializer=custom_deserializer)
ray.get(ray.put(A(1)))
ray.util.deregister_serializer(A)
with pytest.raises(Exception):
ray.get(ray.put(A(1)))
# deregister again takes no effects
ray.util.deregister_serializer(A)
if __name__ == "__main__":
import pytest
+2 -1
View File
@@ -6,7 +6,7 @@ from ray.util.debug import log_once, disable_log_once_globally, \
from ray.util.placement_group import (placement_group, placement_group_table,
remove_placement_group)
from ray.util import rpdb as pdb
from ray.util.serialization import register_serializer
from ray.util.serialization import register_serializer, deregister_serializer
from ray.util.client_connect import connect, disconnect
@@ -25,4 +25,5 @@ __all__ = [
"connect",
"disconnect",
"register_serializer",
"deregister_serializer",
]
+11
View File
@@ -16,3 +16,14 @@ def register_serializer(cls, *, serializer, deserializer):
"""
context = ray.worker.global_worker.get_serialization_context()
context._register_cloudpickle_serializer(cls, serializer, deserializer)
def deregister_serializer(cls):
"""Deregister the serializer associated with the type ``cls``.
There is no effect if the serializer is unavailable.
Args:
cls: A Python class/type.
"""
context = ray.worker.global_worker.get_serialization_context()
context._unregister_cloudpickle_reducer(cls)