mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 21:07:06 +08:00
Change Python's ObjectID to ObjectRef (#9353)
This commit is contained in:
+51
-50
@@ -88,7 +88,7 @@ def _try_to_compute_deterministic_class_id(cls, depth=5):
|
||||
class SerializationContext:
|
||||
"""Initialize the serialization library.
|
||||
|
||||
This defines a custom serializer for object IDs and also tells ray to
|
||||
This defines a custom serializer for object refs and also tells ray to
|
||||
serialize several exception classes that we define for error handling.
|
||||
"""
|
||||
|
||||
@@ -99,14 +99,14 @@ class SerializationContext:
|
||||
def actor_handle_serializer(obj):
|
||||
serialized, actor_handle_id = obj._serialization_helper()
|
||||
# Update ref counting for the actor handle
|
||||
self.add_contained_object_id(actor_handle_id)
|
||||
self.add_contained_object_ref(actor_handle_id)
|
||||
return serialized
|
||||
|
||||
def actor_handle_deserializer(serialized_obj):
|
||||
# If this actor handle was stored in another object, then tell the
|
||||
# core worker.
|
||||
context = ray.worker.global_worker.get_serialization_context()
|
||||
outer_id = context.get_outer_object_id()
|
||||
outer_id = context.get_outer_object_ref()
|
||||
return ray.actor.ActorHandle._deserialization_helper(
|
||||
serialized_obj, outer_id)
|
||||
|
||||
@@ -121,22 +121,22 @@ class SerializationContext:
|
||||
def id_deserializer(serialized_obj):
|
||||
return serialized_obj[0](*serialized_obj[1])
|
||||
|
||||
def object_id_serializer(obj):
|
||||
self.add_contained_object_id(obj)
|
||||
def object_ref_serializer(obj):
|
||||
self.add_contained_object_ref(obj)
|
||||
worker = ray.worker.global_worker
|
||||
worker.check_connected()
|
||||
obj, owner_address = (
|
||||
worker.core_worker.serialize_and_promote_object_id(obj))
|
||||
worker.core_worker.serialize_and_promote_object_ref(obj))
|
||||
obj = id_serializer(obj)
|
||||
return obj, owner_address
|
||||
|
||||
def object_id_deserializer(serialized_obj):
|
||||
obj_id, owner_address = serialized_obj
|
||||
def object_ref_deserializer(serialized_obj):
|
||||
obj_ref, owner_address = serialized_obj
|
||||
# NOTE(swang): Must deserialize the object first before asking
|
||||
# the core worker to resolve the value. This is to make sure
|
||||
# that the ref count for the ObjectID is greater than 0 by the
|
||||
# that the ref count for the ObjectRef is greater than 0 by the
|
||||
# time the core worker resolves the value of the object.
|
||||
deserialized_object_id = id_deserializer(obj_id)
|
||||
deserialized_object_ref = id_deserializer(obj_ref)
|
||||
# TODO(edoakes): we should be able to just capture a reference
|
||||
# to 'self' here instead, but this function is itself pickled
|
||||
# somewhere, which causes an error.
|
||||
@@ -146,19 +146,19 @@ class SerializationContext:
|
||||
worker.check_connected()
|
||||
# UniqueIDs are serialized as
|
||||
# (class name, (unique bytes,)).
|
||||
outer_id = context.get_outer_object_id()
|
||||
# outer_id is None in the case that this ObjectID was closed
|
||||
outer_id = context.get_outer_object_ref()
|
||||
# outer_id is None in the case that this ObjectRef was closed
|
||||
# over in a function or pickled directly using pickle.dumps().
|
||||
if outer_id is None:
|
||||
outer_id = ray.ObjectID.nil()
|
||||
worker.core_worker.deserialize_and_register_object_id(
|
||||
obj_id[1][0], outer_id, owner_address)
|
||||
return deserialized_object_id
|
||||
outer_id = ray.ObjectRef.nil()
|
||||
worker.core_worker.deserialize_and_register_object_ref(
|
||||
obj_ref[1][0], outer_id, owner_address)
|
||||
return deserialized_object_ref
|
||||
|
||||
for id_type in ray._raylet._ID_TYPES:
|
||||
if id_type == ray._raylet.ObjectID:
|
||||
if id_type == ray._raylet.ObjectRef:
|
||||
self._register_cloudpickle_serializer(
|
||||
id_type, object_id_serializer, object_id_deserializer)
|
||||
id_type, object_ref_serializer, object_ref_deserializer)
|
||||
else:
|
||||
self._register_cloudpickle_serializer(id_type, id_serializer,
|
||||
id_deserializer)
|
||||
@@ -180,36 +180,36 @@ class SerializationContext:
|
||||
def set_out_of_band_serialization(self):
|
||||
self._thread_local.in_band = False
|
||||
|
||||
def set_outer_object_id(self, outer_object_id):
|
||||
self._thread_local.outer_object_id = outer_object_id
|
||||
def set_outer_object_ref(self, outer_object_ref):
|
||||
self._thread_local.outer_object_ref = outer_object_ref
|
||||
|
||||
def get_outer_object_id(self):
|
||||
return getattr(self._thread_local, "outer_object_id", None)
|
||||
def get_outer_object_ref(self):
|
||||
return getattr(self._thread_local, "outer_object_ref", None)
|
||||
|
||||
def get_and_clear_contained_object_ids(self):
|
||||
if not hasattr(self._thread_local, "object_ids"):
|
||||
self._thread_local.object_ids = set()
|
||||
def get_and_clear_contained_object_refs(self):
|
||||
if not hasattr(self._thread_local, "object_refs"):
|
||||
self._thread_local.object_refs = set()
|
||||
return set()
|
||||
|
||||
object_ids = self._thread_local.object_ids
|
||||
self._thread_local.object_ids = set()
|
||||
return object_ids
|
||||
object_refs = self._thread_local.object_refs
|
||||
self._thread_local.object_refs = set()
|
||||
return object_refs
|
||||
|
||||
def add_contained_object_id(self, object_id):
|
||||
def add_contained_object_ref(self, object_ref):
|
||||
if self.is_in_band_serialization():
|
||||
# This object ID is being stored in an object. Add the ID to the
|
||||
# This object ref is being stored in an object. Add the ID to the
|
||||
# list of IDs contained in the object so that we keep the inner
|
||||
# object value alive as long as the outer object is in scope.
|
||||
if not hasattr(self._thread_local, "object_ids"):
|
||||
self._thread_local.object_ids = set()
|
||||
self._thread_local.object_ids.add(object_id)
|
||||
if not hasattr(self._thread_local, "object_refs"):
|
||||
self._thread_local.object_refs = set()
|
||||
self._thread_local.object_refs.add(object_ref)
|
||||
else:
|
||||
# If this serialization is out-of-band (e.g., from a call to
|
||||
# cloudpickle directly or captured in a remote function/actor),
|
||||
# then pin the object for the lifetime of this worker by adding
|
||||
# a local reference that won't ever be removed.
|
||||
ray.worker.global_worker.core_worker.add_object_id_reference(
|
||||
object_id)
|
||||
ray.worker.global_worker.core_worker.add_object_ref_reference(
|
||||
object_ref)
|
||||
|
||||
def _deserialize_pickle5_data(self, data):
|
||||
try:
|
||||
@@ -242,7 +242,7 @@ class SerializationContext:
|
||||
raise DeserializationError()
|
||||
return obj
|
||||
|
||||
def _deserialize_object(self, data, metadata, object_id):
|
||||
def _deserialize_object(self, data, metadata, object_ref):
|
||||
if metadata:
|
||||
if metadata in [
|
||||
ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE,
|
||||
@@ -261,7 +261,7 @@ class SerializationContext:
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"Can't deserialize object: {}, metadata: {}".format(
|
||||
object_id, metadata))
|
||||
object_ref, metadata))
|
||||
|
||||
# RayTaskError is serialized with pickle5 in the data field.
|
||||
# TODO (kfstorm): exception serialization should be language
|
||||
@@ -277,7 +277,8 @@ class SerializationContext:
|
||||
elif error_type == ErrorType.Value("TASK_CANCELLED"):
|
||||
return RayCancellationError()
|
||||
elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"):
|
||||
return UnreconstructableError(ray.ObjectID(object_id.binary()))
|
||||
return UnreconstructableError(
|
||||
ray.ObjectRef(object_ref.binary()))
|
||||
else:
|
||||
assert error_type != ErrorType.Value("OBJECT_IN_PLASMA"), \
|
||||
"Tried to get object that has been promoted to plasma."
|
||||
@@ -293,22 +294,22 @@ class SerializationContext:
|
||||
|
||||
def deserialize_objects(self,
|
||||
data_metadata_pairs,
|
||||
object_ids,
|
||||
object_refs,
|
||||
error_timeout=10):
|
||||
assert len(data_metadata_pairs) == len(object_ids)
|
||||
assert len(data_metadata_pairs) == len(object_refs)
|
||||
|
||||
start_time = time.time()
|
||||
results = []
|
||||
warning_sent = False
|
||||
i = 0
|
||||
while i < len(object_ids):
|
||||
object_id = object_ids[i]
|
||||
while i < len(object_refs):
|
||||
object_ref = object_refs[i]
|
||||
data, metadata = data_metadata_pairs[i]
|
||||
assert self.get_outer_object_id() is None
|
||||
self.set_outer_object_id(object_id)
|
||||
assert self.get_outer_object_ref() is None
|
||||
self.set_outer_object_ref(object_ref)
|
||||
try:
|
||||
results.append(
|
||||
self._deserialize_object(data, metadata, object_id))
|
||||
self._deserialize_object(data, metadata, object_ref))
|
||||
i += 1
|
||||
except DeserializationError:
|
||||
# Wait a little bit for the import thread to import the class.
|
||||
@@ -330,27 +331,27 @@ class SerializationContext:
|
||||
job_id=self.worker.current_job_id)
|
||||
warning_sent = True
|
||||
finally:
|
||||
# Must clear ObjectID to not hold a reference.
|
||||
self.set_outer_object_id(None)
|
||||
# Must clear ObjectRef to not hold a reference.
|
||||
self.set_outer_object_ref(None)
|
||||
|
||||
return results
|
||||
|
||||
def _serialize_to_pickle5(self, metadata, value):
|
||||
writer = Pickle5Writer()
|
||||
# TODO(swang): Check that contained_object_ids is empty.
|
||||
# TODO(swang): Check that contained_object_refs is empty.
|
||||
try:
|
||||
self.set_in_band_serialization()
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
except Exception as e:
|
||||
self.get_and_clear_contained_object_ids()
|
||||
self.get_and_clear_contained_object_refs()
|
||||
raise e
|
||||
finally:
|
||||
self.set_out_of_band_serialization()
|
||||
|
||||
return Pickle5SerializedObject(
|
||||
metadata, inband, writer,
|
||||
self.get_and_clear_contained_object_ids())
|
||||
self.get_and_clear_contained_object_refs())
|
||||
|
||||
def _serialize_to_msgpack(self, metadata, value):
|
||||
python_objects = []
|
||||
|
||||
Reference in New Issue
Block a user