diff --git a/python/ray/serialization.py b/python/ray/serialization.py index c871c3e56..1064a9cbd 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -85,6 +85,45 @@ def _try_to_compute_deterministic_class_id(cls, depth=5): return hashlib.sha1(new_class_id).digest() +def object_ref_deserializer(reduced_obj_ref, owner_address): + # NOTE(suquark): This function should be a global function so + # cloudpickle can access it directly. Otherwise couldpickle + # has to dump the whole function definition, which is inefficient. + + # NOTE(swang): Must deserialize the object first before asking + # the core worker to resolve the value. This is to make sure + # that the ref count for the ObjectRef is greater than 0 by the + # time the core worker resolves the value of the object. + + # UniqueIDs are serialized as (class name, (unique bytes,)). + obj_ref = reduced_obj_ref[0](*reduced_obj_ref[1]) + + # TODO(edoakes): we should be able to just capture a reference + # to 'self' here instead, but this function is itself pickled + # somewhere, which causes an error. + if owner_address: + worker = ray.worker.global_worker + worker.check_connected() + context = worker.get_serialization_context() + outer_id = context.get_outer_object_ref() + # outer_id is None in the case that this ObjectRef was closed + # over in a function or pickled directly using pickle.dumps(). + if outer_id is None: + outer_id = ray.ObjectRef.nil() + worker.core_worker.deserialize_and_register_object_ref( + obj_ref.binary(), outer_id, owner_address) + return obj_ref + + +def actor_handle_deserializer(serialized_obj): + # If this actor handle was stored in another object, then tell the + # core worker. + context = ray.worker.global_worker.get_serialization_context() + outer_id = context.get_outer_object_ref() + return ray.actor.ActorHandle._deserialization_helper( + serialized_obj, outer_id) + + class SerializationContext: """Initialize the serialization library. @@ -96,72 +135,29 @@ class SerializationContext: self.worker = worker self._thread_local = threading.local() - def actor_handle_serializer(obj): + def actor_handle_reducer(obj): serialized, actor_handle_id = obj._serialization_helper() # Update ref counting for the actor handle self.add_contained_object_ref(actor_handle_id) - return serialized + return actor_handle_deserializer, (serialized, ) - def actor_handle_deserializer(serialized_obj): - # If this actor handle was stored in another object, then tell the - # core worker. - context = ray.worker.global_worker.get_serialization_context() - outer_id = context.get_outer_object_ref() - return ray.actor.ActorHandle._deserialization_helper( - serialized_obj, outer_id) + self._register_cloudpickle_reducer(ray.actor.ActorHandle, + actor_handle_reducer) - self._register_cloudpickle_serializer( - ray.actor.ActorHandle, - custom_serializer=actor_handle_serializer, - custom_deserializer=actor_handle_deserializer) - - def id_serializer(obj): - return obj.__reduce__() - - def id_deserializer(serialized_obj): - return serialized_obj[0](*serialized_obj[1]) - - def object_ref_serializer(obj): + def object_ref_reducer(obj): self.add_contained_object_ref(obj) worker = ray.worker.global_worker worker.check_connected() obj, owner_address = ( worker.core_worker.serialize_and_promote_object_ref(obj)) - obj = id_serializer(obj) - return obj, owner_address + return object_ref_deserializer, (obj.__reduce__(), owner_address) - def object_ref_deserializer(serialized_obj): - obj_ref, owner_address = serialized_obj - # NOTE(swang): Must deserialize the object first before asking - # the core worker to resolve the value. This is to make sure - # that the ref count for the ObjectRef is greater than 0 by the - # time the core worker resolves the value of the object. - deserialized_object_ref = id_deserializer(obj_ref) - # TODO(edoakes): we should be able to just capture a reference - # to 'self' here instead, but this function is itself pickled - # somewhere, which causes an error. - context = ray.worker.global_worker.get_serialization_context() - if owner_address: - worker = ray.worker.global_worker - worker.check_connected() - # UniqueIDs are serialized as - # (class name, (unique bytes,)). - outer_id = context.get_outer_object_ref() - # outer_id is None in the case that this ObjectRef was closed - # over in a function or pickled directly using pickle.dumps(). - if outer_id is None: - outer_id = ray.ObjectRef.nil() - worker.core_worker.deserialize_and_register_object_ref( - obj_ref[1][0], outer_id, owner_address) - return deserialized_object_ref + # Because objects have default __reduce__ method, we only need to + # treat ObjectRef specifically. + self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer) - for id_type in ray._raylet._ID_TYPES: - if id_type == ray._raylet.ObjectRef: - self._register_cloudpickle_serializer( - id_type, object_ref_serializer, object_ref_deserializer) - else: - self._register_cloudpickle_serializer(id_type, id_serializer, - id_deserializer) + def _register_cloudpickle_reducer(self, cls, reducer): + pickle.CloudPickler.dispatch[cls] = reducer def _register_cloudpickle_serializer(self, cls, custom_serializer, custom_deserializer):