mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 00:55:31 +08:00
[Serialization] API for deregistering serializers; code & doc cleanup (#13471)
* make methods private, remove confusion brackets and usages * unregister serializer; fix doc * Cleanup doc * rename unregister -> deregister
This commit is contained in:
committed by
GitHub
parent
b20a38febb
commit
0b598c0f05
+1
-1
@@ -937,7 +937,7 @@ class ActorHandle:
|
||||
def __reduce__(self):
|
||||
"""This code path is used by pickling but not by Ray forking."""
|
||||
state = self._serialization_helper()
|
||||
return ActorHandle._deserialization_helper, (state)
|
||||
return ActorHandle._deserialization_helper, state
|
||||
|
||||
|
||||
def modify_class(cls):
|
||||
|
||||
@@ -31,7 +31,7 @@ class DeserializationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def object_ref_deserializer(reduced_obj_ref, owner_address):
|
||||
def _object_ref_deserializer(binary, owner_address):
|
||||
# NOTE(suquark): This function should be a global function so
|
||||
# cloudpickle can access it directly. Otherwise couldpickle
|
||||
# has to dump the whole function definition, which is inefficient.
|
||||
@@ -40,9 +40,7 @@ def object_ref_deserializer(reduced_obj_ref, owner_address):
|
||||
# the core worker to resolve the value. This is to make sure
|
||||
# that the ref count for the ObjectRef is greater than 0 by the
|
||||
# time the core worker resolves the value of the object.
|
||||
|
||||
# UniqueIDs are serialized as (class name, (unique bytes,)).
|
||||
obj_ref = reduced_obj_ref[0](*reduced_obj_ref[1])
|
||||
obj_ref = ray.ObjectRef(binary)
|
||||
|
||||
# TODO(edoakes): we should be able to just capture a reference
|
||||
# to 'self' here instead, but this function is itself pickled
|
||||
@@ -61,7 +59,7 @@ def object_ref_deserializer(reduced_obj_ref, owner_address):
|
||||
return obj_ref
|
||||
|
||||
|
||||
def actor_handle_deserializer(serialized_obj):
|
||||
def _actor_handle_deserializer(serialized_obj):
|
||||
# If this actor handle was stored in another object, then tell the
|
||||
# core worker.
|
||||
context = ray.worker.global_worker.get_serialization_context()
|
||||
@@ -85,7 +83,7 @@ class SerializationContext:
|
||||
serialized, actor_handle_id = obj._serialization_helper()
|
||||
# Update ref counting for the actor handle
|
||||
self.add_contained_object_ref(actor_handle_id)
|
||||
return actor_handle_deserializer, (serialized, )
|
||||
return _actor_handle_deserializer, (serialized, )
|
||||
|
||||
self._register_cloudpickle_reducer(ray.actor.ActorHandle,
|
||||
actor_handle_reducer)
|
||||
@@ -96,13 +94,16 @@ class SerializationContext:
|
||||
worker.check_connected()
|
||||
obj, owner_address = (
|
||||
worker.core_worker.serialize_and_promote_object_ref(obj))
|
||||
return object_ref_deserializer, (obj.__reduce__(), owner_address)
|
||||
return _object_ref_deserializer, (obj.binary(), owner_address)
|
||||
|
||||
self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
|
||||
|
||||
def _register_cloudpickle_reducer(self, cls, reducer):
|
||||
pickle.CloudPickler.dispatch[cls] = reducer
|
||||
|
||||
def _unregister_cloudpickle_reducer(self, cls):
|
||||
pickle.CloudPickler.dispatch.pop(cls, None)
|
||||
|
||||
def _register_cloudpickle_serializer(self, cls, custom_serializer,
|
||||
custom_deserializer):
|
||||
def _CloudPicklerReducer(obj):
|
||||
@@ -198,7 +199,7 @@ class SerializationContext:
|
||||
elif metadata_fields[
|
||||
0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE:
|
||||
obj = self._deserialize_msgpack_data(data, metadata_fields)
|
||||
return actor_handle_deserializer(obj)
|
||||
return _actor_handle_deserializer(obj)
|
||||
# Otherwise, return an exception object based on
|
||||
# the error type.
|
||||
try:
|
||||
|
||||
@@ -616,6 +616,13 @@ def test_custom_serializer(ray_start_shared_local_modes):
|
||||
A, serializer=custom_serializer, deserializer=custom_deserializer)
|
||||
ray.get(ray.put(A(1)))
|
||||
|
||||
ray.util.deregister_serializer(A)
|
||||
with pytest.raises(Exception):
|
||||
ray.get(ray.put(A(1)))
|
||||
|
||||
# deregister again takes no effects
|
||||
ray.util.deregister_serializer(A)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
@@ -6,7 +6,7 @@ from ray.util.debug import log_once, disable_log_once_globally, \
|
||||
from ray.util.placement_group import (placement_group, placement_group_table,
|
||||
remove_placement_group)
|
||||
from ray.util import rpdb as pdb
|
||||
from ray.util.serialization import register_serializer
|
||||
from ray.util.serialization import register_serializer, deregister_serializer
|
||||
|
||||
from ray.util.client_connect import connect, disconnect
|
||||
|
||||
@@ -25,4 +25,5 @@ __all__ = [
|
||||
"connect",
|
||||
"disconnect",
|
||||
"register_serializer",
|
||||
"deregister_serializer",
|
||||
]
|
||||
|
||||
@@ -16,3 +16,14 @@ def register_serializer(cls, *, serializer, deserializer):
|
||||
"""
|
||||
context = ray.worker.global_worker.get_serialization_context()
|
||||
context._register_cloudpickle_serializer(cls, serializer, deserializer)
|
||||
|
||||
|
||||
def deregister_serializer(cls):
|
||||
"""Deregister the serializer associated with the type ``cls``.
|
||||
There is no effect if the serializer is unavailable.
|
||||
|
||||
Args:
|
||||
cls: A Python class/type.
|
||||
"""
|
||||
context = ray.worker.global_worker.get_serialization_context()
|
||||
context._unregister_cloudpickle_reducer(cls)
|
||||
|
||||
Reference in New Issue
Block a user