Change Python's ObjectID to ObjectRef (#9353)

This commit is contained in:
Hao Chen
2020-07-10 17:49:04 +08:00
committed by GitHub
parent 6311e5a947
commit d49dadf891
91 changed files with 959 additions and 907 deletions
+51 -50
View File
@@ -88,7 +88,7 @@ def _try_to_compute_deterministic_class_id(cls, depth=5):
class SerializationContext:
"""Initialize the serialization library.
This defines a custom serializer for object IDs and also tells ray to
This defines a custom serializer for object refs and also tells ray to
serialize several exception classes that we define for error handling.
"""
@@ -99,14 +99,14 @@ class SerializationContext:
def actor_handle_serializer(obj):
serialized, actor_handle_id = obj._serialization_helper()
# Update ref counting for the actor handle
self.add_contained_object_id(actor_handle_id)
self.add_contained_object_ref(actor_handle_id)
return serialized
def actor_handle_deserializer(serialized_obj):
# If this actor handle was stored in another object, then tell the
# core worker.
context = ray.worker.global_worker.get_serialization_context()
outer_id = context.get_outer_object_id()
outer_id = context.get_outer_object_ref()
return ray.actor.ActorHandle._deserialization_helper(
serialized_obj, outer_id)
@@ -121,22 +121,22 @@ class SerializationContext:
def id_deserializer(serialized_obj):
return serialized_obj[0](*serialized_obj[1])
def object_id_serializer(obj):
self.add_contained_object_id(obj)
def object_ref_serializer(obj):
self.add_contained_object_ref(obj)
worker = ray.worker.global_worker
worker.check_connected()
obj, owner_address = (
worker.core_worker.serialize_and_promote_object_id(obj))
worker.core_worker.serialize_and_promote_object_ref(obj))
obj = id_serializer(obj)
return obj, owner_address
def object_id_deserializer(serialized_obj):
obj_id, owner_address = serialized_obj
def object_ref_deserializer(serialized_obj):
obj_ref, owner_address = serialized_obj
# NOTE(swang): Must deserialize the object first before asking
# the core worker to resolve the value. This is to make sure
# that the ref count for the ObjectID is greater than 0 by the
# that the ref count for the ObjectRef is greater than 0 by the
# time the core worker resolves the value of the object.
deserialized_object_id = id_deserializer(obj_id)
deserialized_object_ref = id_deserializer(obj_ref)
# TODO(edoakes): we should be able to just capture a reference
# to 'self' here instead, but this function is itself pickled
# somewhere, which causes an error.
@@ -146,19 +146,19 @@ class SerializationContext:
worker.check_connected()
# UniqueIDs are serialized as
# (class name, (unique bytes,)).
outer_id = context.get_outer_object_id()
# outer_id is None in the case that this ObjectID was closed
outer_id = context.get_outer_object_ref()
# outer_id is None in the case that this ObjectRef was closed
# over in a function or pickled directly using pickle.dumps().
if outer_id is None:
outer_id = ray.ObjectID.nil()
worker.core_worker.deserialize_and_register_object_id(
obj_id[1][0], outer_id, owner_address)
return deserialized_object_id
outer_id = ray.ObjectRef.nil()
worker.core_worker.deserialize_and_register_object_ref(
obj_ref[1][0], outer_id, owner_address)
return deserialized_object_ref
for id_type in ray._raylet._ID_TYPES:
if id_type == ray._raylet.ObjectID:
if id_type == ray._raylet.ObjectRef:
self._register_cloudpickle_serializer(
id_type, object_id_serializer, object_id_deserializer)
id_type, object_ref_serializer, object_ref_deserializer)
else:
self._register_cloudpickle_serializer(id_type, id_serializer,
id_deserializer)
@@ -180,36 +180,36 @@ class SerializationContext:
def set_out_of_band_serialization(self):
self._thread_local.in_band = False
def set_outer_object_id(self, outer_object_id):
self._thread_local.outer_object_id = outer_object_id
def set_outer_object_ref(self, outer_object_ref):
self._thread_local.outer_object_ref = outer_object_ref
def get_outer_object_id(self):
return getattr(self._thread_local, "outer_object_id", None)
def get_outer_object_ref(self):
return getattr(self._thread_local, "outer_object_ref", None)
def get_and_clear_contained_object_ids(self):
if not hasattr(self._thread_local, "object_ids"):
self._thread_local.object_ids = set()
def get_and_clear_contained_object_refs(self):
if not hasattr(self._thread_local, "object_refs"):
self._thread_local.object_refs = set()
return set()
object_ids = self._thread_local.object_ids
self._thread_local.object_ids = set()
return object_ids
object_refs = self._thread_local.object_refs
self._thread_local.object_refs = set()
return object_refs
def add_contained_object_id(self, object_id):
def add_contained_object_ref(self, object_ref):
if self.is_in_band_serialization():
# This object ID is being stored in an object. Add the ID to the
# This object ref is being stored in an object. Add the ID to the
# list of IDs contained in the object so that we keep the inner
# object value alive as long as the outer object is in scope.
if not hasattr(self._thread_local, "object_ids"):
self._thread_local.object_ids = set()
self._thread_local.object_ids.add(object_id)
if not hasattr(self._thread_local, "object_refs"):
self._thread_local.object_refs = set()
self._thread_local.object_refs.add(object_ref)
else:
# If this serialization is out-of-band (e.g., from a call to
# cloudpickle directly or captured in a remote function/actor),
# then pin the object for the lifetime of this worker by adding
# a local reference that won't ever be removed.
ray.worker.global_worker.core_worker.add_object_id_reference(
object_id)
ray.worker.global_worker.core_worker.add_object_ref_reference(
object_ref)
def _deserialize_pickle5_data(self, data):
try:
@@ -242,7 +242,7 @@ class SerializationContext:
raise DeserializationError()
return obj
def _deserialize_object(self, data, metadata, object_id):
def _deserialize_object(self, data, metadata, object_ref):
if metadata:
if metadata in [
ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE,
@@ -261,7 +261,7 @@ class SerializationContext:
except Exception:
raise Exception(
"Can't deserialize object: {}, metadata: {}".format(
object_id, metadata))
object_ref, metadata))
# RayTaskError is serialized with pickle5 in the data field.
# TODO (kfstorm): exception serialization should be language
@@ -277,7 +277,8 @@ class SerializationContext:
elif error_type == ErrorType.Value("TASK_CANCELLED"):
return RayCancellationError()
elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"):
return UnreconstructableError(ray.ObjectID(object_id.binary()))
return UnreconstructableError(
ray.ObjectRef(object_ref.binary()))
else:
assert error_type != ErrorType.Value("OBJECT_IN_PLASMA"), \
"Tried to get object that has been promoted to plasma."
@@ -293,22 +294,22 @@ class SerializationContext:
def deserialize_objects(self,
data_metadata_pairs,
object_ids,
object_refs,
error_timeout=10):
assert len(data_metadata_pairs) == len(object_ids)
assert len(data_metadata_pairs) == len(object_refs)
start_time = time.time()
results = []
warning_sent = False
i = 0
while i < len(object_ids):
object_id = object_ids[i]
while i < len(object_refs):
object_ref = object_refs[i]
data, metadata = data_metadata_pairs[i]
assert self.get_outer_object_id() is None
self.set_outer_object_id(object_id)
assert self.get_outer_object_ref() is None
self.set_outer_object_ref(object_ref)
try:
results.append(
self._deserialize_object(data, metadata, object_id))
self._deserialize_object(data, metadata, object_ref))
i += 1
except DeserializationError:
# Wait a little bit for the import thread to import the class.
@@ -330,27 +331,27 @@ class SerializationContext:
job_id=self.worker.current_job_id)
warning_sent = True
finally:
# Must clear ObjectID to not hold a reference.
self.set_outer_object_id(None)
# Must clear ObjectRef to not hold a reference.
self.set_outer_object_ref(None)
return results
def _serialize_to_pickle5(self, metadata, value):
writer = Pickle5Writer()
# TODO(swang): Check that contained_object_ids is empty.
# TODO(swang): Check that contained_object_refs is empty.
try:
self.set_in_band_serialization()
inband = pickle.dumps(
value, protocol=5, buffer_callback=writer.buffer_callback)
except Exception as e:
self.get_and_clear_contained_object_ids()
self.get_and_clear_contained_object_refs()
raise e
finally:
self.set_out_of_band_serialization()
return Pickle5SerializedObject(
metadata, inband, writer,
self.get_and_clear_contained_object_ids())
self.get_and_clear_contained_object_refs())
def _serialize_to_msgpack(self, metadata, value):
python_objects = []