mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:01:24 +08:00
479 lines
19 KiB
Python
479 lines
19 KiB
Python
import hashlib
|
|
import logging
|
|
import time
|
|
import threading
|
|
|
|
import ray.cloudpickle as pickle
|
|
from ray import ray_constants, JobID
|
|
import ray.utils
|
|
from ray.utils import _random_string
|
|
from ray.gcs_utils import ErrorType
|
|
from ray.exceptions import (
|
|
RayError,
|
|
PlasmaObjectNotAvailable,
|
|
RayTaskError,
|
|
RayActorError,
|
|
TaskCancelledError,
|
|
WorkerCrashedError,
|
|
ObjectLostError,
|
|
)
|
|
from ray._raylet import (
|
|
split_buffer,
|
|
unpack_pickle5_buffers,
|
|
Pickle5Writer,
|
|
Pickle5SerializedObject,
|
|
MessagePackSerializer,
|
|
MessagePackSerializedObject,
|
|
RawSerializedObject,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RayNotDictionarySerializable(Exception):
    """Raised when a class cannot be efficiently serialized as a dict."""
|
|
|
|
|
|
class CloudPickleError(Exception):
    """Signals that cloudpickle failed to pickle an object.

    cloudpickle can fail in many different ways, so this single type is used
    to represent all of them.
    """
|
|
|
|
|
|
class DeserializationError(Exception):
    """Raised when a stored object's payload cannot be decoded."""
|
|
|
|
|
|
def _try_to_compute_deterministic_class_id(cls, depth=5):
|
|
"""Attempt to produce a deterministic class ID for a given class.
|
|
|
|
The goal here is for the class ID to be the same when this is run on
|
|
different worker processes. Pickling, loading, and pickling again seems to
|
|
produce more consistent results than simply pickling. This is a bit crazy
|
|
and could cause problems, in which case we should revert it and figure out
|
|
something better.
|
|
|
|
Args:
|
|
cls: The class to produce an ID for.
|
|
depth: The number of times to repeatedly try to load and dump the
|
|
string while trying to reach a fixed point.
|
|
|
|
Returns:
|
|
A class ID for this class. We attempt to make the class ID the same
|
|
when this function is run on different workers, but that is not
|
|
guaranteed.
|
|
|
|
Raises:
|
|
Exception: This could raise an exception if cloudpickle raises an
|
|
exception.
|
|
"""
|
|
# Pickling, loading, and pickling again seems to produce more consistent
|
|
# results than simply pickling. This is a bit
|
|
class_id = pickle.dumps(cls)
|
|
for _ in range(depth):
|
|
new_class_id = pickle.dumps(pickle.loads(class_id))
|
|
if new_class_id == class_id:
|
|
# We appear to have reached a fix point, so use this as the ID.
|
|
return hashlib.sha1(new_class_id).digest()
|
|
class_id = new_class_id
|
|
|
|
# We have not reached a fixed point, so we may end up with a different
|
|
# class ID for this custom class on each worker, which could lead to the
|
|
# same class definition being exported many many times.
|
|
logger.warning(
|
|
f"WARNING: Could not produce a deterministic class ID for class {cls}")
|
|
return hashlib.sha1(new_class_id).digest()
|
|
|
|
|
|
def object_ref_deserializer(reduced_obj_ref, owner_address):
    """Rebuild an ObjectRef and, if it has an owner, register it with it.

    NOTE(suquark): This function should be a global function so cloudpickle
    can access it directly. Otherwise cloudpickle has to dump the whole
    function definition, which is inefficient.

    NOTE(swang): The ref must be reconstructed before asking the core worker
    to resolve the value, so that the ref count for the ObjectRef is greater
    than 0 by the time the core worker resolves the value of the object.

    TODO(edoakes): we should be able to just capture a reference to 'self'
    here instead, but this function is itself pickled somewhere, which causes
    an error.

    Args:
        reduced_obj_ref: The reduced form of the ref; element 0 is called with
            element 1 unpacked as positional arguments. UniqueIDs are
            serialized as (class name, (unique bytes,)).
        owner_address: Address of the owning worker. When falsy, the rebuilt
            ref is returned without contacting the core worker.

    Returns:
        The reconstructed ref.
    """
    rebuild, rebuild_args = reduced_obj_ref[0], reduced_obj_ref[1]
    obj_ref = rebuild(*rebuild_args)

    if not owner_address:
        return obj_ref

    global_worker = ray.worker.global_worker
    global_worker.check_connected()
    serialization_context = global_worker.get_serialization_context()
    outer_ref = serialization_context.get_outer_object_ref()
    # outer_ref is None in the case that this ObjectRef was closed over in a
    # function or pickled directly using pickle.dumps().
    if outer_ref is None:
        outer_ref = ray.ObjectRef.nil()
    global_worker.core_worker.deserialize_and_register_object_ref(
        obj_ref.binary(), outer_ref, owner_address)
    return obj_ref
|
|
|
|
|
|
def actor_handle_deserializer(serialized_obj):
    """Reconstruct an ActorHandle from its serialized form.

    The current thread's outer object ref (None if the handle is not being
    deserialized as part of another object) is forwarded so the core worker
    can track that the handle was stored inside another object.

    Args:
        serialized_obj: The serialized handle payload produced by
            ``ActorHandle._serialization_helper``.

    Returns:
        The value returned by ``ActorHandle._deserialization_helper``.
    """
    global_worker = ray.worker.global_worker
    outer_ref = global_worker.get_serialization_context().get_outer_object_ref()
    return ray.actor.ActorHandle._deserialization_helper(
        serialized_obj, outer_ref)
|
|
|
|
|
|
class SerializationContext:
    """Initialize the serialization library.

    This defines a custom serializer for object refs and also tells ray to
    serialize several exception classes that we define for error handling.

    On construction it registers cloudpickle reducers for
    ``ray.actor.ActorHandle`` and ``ray.ObjectRef`` so that refs embedded in
    pickled values are recorded for reference counting. ``serialize`` and
    ``deserialize_objects`` handle three payload kinds:
    - raw bytes;
    - msgpack (cross-language);
    - msgpack plus a pickle5 side payload (Python-only).
    Error objects are decoded from their metadata.

    Per-thread state lives in a ``threading.local``: the in-band flag, the
    outer object ref, and the set of contained refs.
    """

    def __init__(self, worker):
        """Create the context and register the ActorHandle/ObjectRef reducers.

        Args:
            worker: The owning worker. Its ``current_job_id`` and
                ``run_function_on_all_workers`` are used by other methods.
        """
        self.worker = worker
        self._thread_local = threading.local()

        def actor_handle_reducer(obj):
            serialized, actor_handle_id = obj._serialization_helper()
            # Update ref counting for the actor handle
            self.add_contained_object_ref(actor_handle_id)
            return actor_handle_deserializer, (serialized, )

        self._register_cloudpickle_reducer(ray.actor.ActorHandle,
                                           actor_handle_reducer)

        def object_ref_reducer(obj):
            self.add_contained_object_ref(obj)
            global_worker = ray.worker.global_worker
            global_worker.check_connected()
            obj, owner_address = (
                global_worker.core_worker.serialize_and_promote_object_ref(obj))
            return object_ref_deserializer, (obj.__reduce__(), owner_address)

        # Because objects have default __reduce__ method, we only need to
        # treat ObjectRef specifically.
        self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)

    def _register_cloudpickle_reducer(self, cls, reducer):
        """Install ``reducer`` for ``cls`` in cloudpickle's dispatch table.

        The table is read from the ``CloudPickler`` class itself, so the
        registration is process-wide and the latest one for ``cls`` wins.
        """
        pickle.CloudPickler.dispatch[cls] = reducer

    def _register_cloudpickle_serializer(self, cls, custom_serializer,
                                         custom_deserializer):
        """Register a serializer/deserializer pair for ``cls``.

        Instances are pickled as ``custom_deserializer(custom_serializer(obj))``.
        """

        def _CloudPicklerReducer(obj):
            return custom_deserializer, (custom_serializer(obj), )

        # construct a reducer
        pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer

    def is_in_band_serialization(self):
        """Return True if this thread is currently serializing in-band."""
        return getattr(self._thread_local, "in_band", False)

    def set_in_band_serialization(self):
        """Mark this thread as serializing a value stored in an object."""
        self._thread_local.in_band = True

    def set_out_of_band_serialization(self):
        """Mark this thread as not serializing in-band."""
        self._thread_local.in_band = False

    def set_outer_object_ref(self, outer_object_ref):
        """Set the ref of the object currently being deserialized."""
        self._thread_local.outer_object_ref = outer_object_ref

    def get_outer_object_ref(self):
        """Return the outer ref for this thread, or None if unset."""
        return getattr(self._thread_local, "outer_object_ref", None)

    def get_and_clear_contained_object_refs(self):
        """Return and reset the refs collected on this thread."""
        if not hasattr(self._thread_local, "object_refs"):
            self._thread_local.object_refs = set()
            return set()

        object_refs = self._thread_local.object_refs
        self._thread_local.object_refs = set()
        return object_refs

    def add_contained_object_ref(self, object_ref):
        """Record ``object_ref`` as referenced by the value being serialized."""
        if self.is_in_band_serialization():
            # This object ref is being stored in an object. Add the ID to the
            # list of IDs contained in the object so that we keep the inner
            # object value alive as long as the outer object is in scope.
            if not hasattr(self._thread_local, "object_refs"):
                self._thread_local.object_refs = set()
            self._thread_local.object_refs.add(object_ref)
        else:
            # If this serialization is out-of-band (e.g., from a call to
            # cloudpickle directly or captured in a remote function/actor),
            # then pin the object for the lifetime of this worker by adding
            # a local reference that won't ever be removed.
            ray.worker.global_worker.core_worker.add_object_ref_reference(
                object_ref)

    def _deserialize_pickle5_data(self, data):
        """Unpickle a pickle5 payload (in-band bytes plus buffers).

        Raises:
            DeserializationError: If cloudpickle raises ``PicklingError``.
        """
        try:
            in_band, buffers = unpack_pickle5_buffers(data)
            if len(buffers) > 0:
                obj = pickle.loads(in_band, buffers=buffers)
            else:
                obj = pickle.loads(in_band)
        # cloudpickle does not provide error types
        # NOTE(review): only PicklingError is mapped; other errors propagate.
        # Broadening this would make deserialize_objects retry them forever.
        except pickle.pickle.PicklingError as e:
            raise DeserializationError() from e
        return obj

    def _deserialize_msgpack_data(self, data, metadata):
        """Decode a msgpack payload, resolving embedded Python objects.

        Args:
            data: Buffer holding msgpack data followed by a pickle5 payload.
            metadata: The object's metadata. The pickle5 part is only read
                when it equals ``OBJECT_METADATA_TYPE_PYTHON``.

        Raises:
            DeserializationError: If msgpack decoding fails.
        """
        msgpack_data, pickle5_data = split_buffer(data)

        if metadata == ray_constants.OBJECT_METADATA_TYPE_PYTHON:
            python_objects = self._deserialize_pickle5_data(pickle5_data)
        else:
            python_objects = []

        def _python_deserializer(index):
            return python_objects[index]

        try:
            obj = MessagePackSerializer.loads(msgpack_data,
                                              _python_deserializer)
        except Exception as e:
            raise DeserializationError() from e
        return obj

    def _deserialize_object(self, data, metadata, object_ref):
        """Decode one object from its data and metadata.

        Returns:
            The decoded value. Error metadata yields an exception instance.
            If both data and metadata are empty, returns the
            ``PlasmaObjectNotAvailable`` class.

        Raises:
            ValueError: If ``data`` is non-empty but ``metadata`` is empty.
            Exception: If the metadata is not a known type or integer.
            AssertionError: For unexpected integer error types.
        """
        if metadata:
            if metadata in [
                    ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE,
                    ray_constants.OBJECT_METADATA_TYPE_PYTHON
            ]:
                return self._deserialize_msgpack_data(data, metadata)
            # Check if the object should be returned as raw bytes.
            if metadata == ray_constants.OBJECT_METADATA_TYPE_RAW:
                if data is None:
                    return b""
                return data.to_pybytes()
            elif metadata == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE:
                obj = self._deserialize_msgpack_data(data, metadata)
                return actor_handle_deserializer(obj)
            # Otherwise, return an exception object based on
            # the error type.
            try:
                error_type = int(metadata)
            except Exception:
                raise Exception(f"Can't deserialize object: {object_ref}, "
                                f"metadata: {metadata}")

            # RayTaskError is serialized with pickle5 in the data field.
            # TODO (kfstorm): exception serialization should be language
            # independent.
            if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"):
                obj = self._deserialize_msgpack_data(data, metadata)
                return RayError.from_bytes(obj)
            elif error_type == ErrorType.Value("WORKER_DIED"):
                return WorkerCrashedError()
            elif error_type == ErrorType.Value("ACTOR_DIED"):
                return RayActorError()
            elif error_type == ErrorType.Value("TASK_CANCELLED"):
                return TaskCancelledError()
            elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"):
                return ObjectLostError(ray.ObjectRef(object_ref.binary()))
            # Explicit raises (not `assert`) so an unexpected type cannot be
            # silently returned as None when Python runs with -O.
            if error_type == ErrorType.Value("OBJECT_IN_PLASMA"):
                raise AssertionError(
                    "Tried to get object that has been promoted to plasma.")
            raise AssertionError("Unrecognized error type " + str(error_type))
        elif data:
            raise ValueError("non-null object should always have metadata")
        else:
            # Object isn't available in plasma. This should never be returned
            # to the user. We should only reach this line if this object was
            # deserialized as part of a list, and another object in the list
            # throws an exception.
            return PlasmaObjectNotAvailable

    def deserialize_objects(self,
                            data_metadata_pairs,
                            object_refs,
                            error_timeout=10):
        """Deserialize objects, retrying those that raise DeserializationError.

        A failed object is retried every 10 ms until it succeeds; there is no
        upper bound on retries. Once ``error_timeout`` seconds have passed, a
        single warning is pushed to the driver.

        Args:
            data_metadata_pairs: ``(data, metadata)`` pairs, one per ref.
            object_refs: The refs being deserialized, aligned with the pairs.
            error_timeout: Seconds before the warning is pushed.

        Returns:
            A list of deserialized values, in input order.
        """
        assert len(data_metadata_pairs) == len(object_refs)

        start_time = time.time()
        results = []
        warning_sent = False
        i = 0
        while i < len(object_refs):
            object_ref = object_refs[i]
            data, metadata = data_metadata_pairs[i]
            assert self.get_outer_object_ref() is None
            self.set_outer_object_ref(object_ref)
            try:
                results.append(
                    self._deserialize_object(data, metadata, object_ref))
                i += 1
            except DeserializationError:
                # The class needed to decode this object may not have been
                # imported yet; wait a little and retry the same object.
                time.sleep(0.01)

                if (not warning_sent
                        and time.time() - start_time > error_timeout):
                    warning_message = ("This worker or driver is waiting to "
                                       "receive a class definition so that it "
                                       "can deserialize an object from the "
                                       "object store. This may be fine, or it "
                                       "may be a bug.")
                    # push_error_to_driver expects the worker, not this
                    # serialization context.
                    ray.utils.push_error_to_driver(
                        self.worker,
                        ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
                        warning_message,
                        job_id=self.worker.current_job_id)
                    warning_sent = True
            finally:
                # Must clear ObjectRef to not hold a reference.
                self.set_outer_object_ref(None)

        return results

    def _serialize_to_pickle5(self, metadata, value):
        """Pickle ``value`` in-band with protocol 5, capturing contained refs."""
        writer = Pickle5Writer()
        # TODO(swang): Check that contained_object_refs is empty.
        try:
            self.set_in_band_serialization()
            inband = pickle.dumps(
                value, protocol=5, buffer_callback=writer.buffer_callback)
        except Exception:
            # Discard refs gathered by the failed attempt so they do not leak
            # into the next serialization on this thread.
            self.get_and_clear_contained_object_refs()
            raise
        finally:
            self.set_out_of_band_serialization()

        return Pickle5SerializedObject(
            metadata, inband, writer,
            self.get_and_clear_contained_object_refs())

    def _serialize_to_msgpack(self, value):
        """Serialize with msgpack; non-msgpack objects go to a pickle5 part."""
        # Only RayTaskError is possible to be serialized here. We don't
        # need to deal with other exception types here.
        contained_object_refs = []

        if isinstance(value, RayTaskError):
            metadata = str(
                ErrorType.Value("TASK_EXECUTION_EXCEPTION")).encode("ascii")
            value = value.to_bytes()
        elif isinstance(value, ray.actor.ActorHandle):
            # TODO(fyresone): ActorHandle should be serialized via the
            # custom type feature of cross-language.
            serialized, actor_handle_id = value._serialization_helper()
            # Update ref counting for the actor handle
            contained_object_refs.append(actor_handle_id)
            metadata = ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE
            value = serialized
        else:
            metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE

        python_objects = []

        def _python_serializer(o):
            index = len(python_objects)
            python_objects.append(o)
            return index

        msgpack_data = MessagePackSerializer.dumps(value, _python_serializer)

        if python_objects:
            metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON
            pickle5_serialized_object = \
                self._serialize_to_pickle5(metadata, python_objects)
        else:
            pickle5_serialized_object = None

        return MessagePackSerializedObject(metadata, msgpack_data,
                                           contained_object_refs,
                                           pickle5_serialized_object)

    def serialize(self, value):
        """Serialize an object.

        Args:
            value: The value to serialize. ``bytes`` values are stored raw
                (readable from Java); everything else goes through msgpack.
        """
        if isinstance(value, bytes):
            # If the object is a byte array, skip serializing it and
            # use a special metadata to indicate it's raw binary. So
            # that this object can also be read by Java.
            return RawSerializedObject(value)
        else:
            return self._serialize_to_msgpack(value)

    def register_custom_serializer(self,
                                   cls,
                                   serializer,
                                   deserializer,
                                   local=False,
                                   job_id=None,
                                   class_id=None):
        """Enable serialization and deserialization for a particular class.

        This method runs the nested register_class_for_serialization function
        on every worker (or only this one if ``local``). That enables ray to
        properly serialize and deserialize objects of this class.

        Args:
            cls (type): The class that ray should use this custom serializer
                for.
            serializer: The custom serializer to use.
            deserializer: The custom deserializer to use.
            local: True if the serializers should only be registered on the
                current worker. This should usually be False.
            job_id: ID of the job that we want to register the class for.
            class_id (str): Unique ID of the class. Autogenerated if None.

        Raises:
            AssertionError: If serializer or deserializer is None, or job_id
                is not a JobID.
            ValueError: Raised if ray could not autogenerate a class_id.
        """
        assert serializer is not None and deserializer is not None, (
            "Must provide serializer and deserializer.")

        if class_id is None:
            if not local:
                # In this case, the class ID will be used to deduplicate the
                # class across workers. Note that cloudpickle unfortunately
                # does not produce deterministic strings, so these IDs could
                # be different on different workers. We could use something
                # weaker like cls.__name__, however that would run the risk
                # of having collisions.
                # TODO(rkn): We should improve this.
                try:
                    # Attempt to produce a class ID that will be the same on
                    # each worker. However, determinism is not guaranteed,
                    # and the result may be different on different workers.
                    class_id = _try_to_compute_deterministic_class_id(cls)
                except Exception as e:
                    raise ValueError(
                        "Failed to use pickle in generating a unique id "
                        f"for '{cls}'. Provide a unique class_id.") from e
            else:
                # In this case, the class ID only needs to be meaningful on
                # this worker and not across workers.
                class_id = _random_string()

            # Make sure class_id is a string.
            class_id = ray.utils.binary_to_hex(class_id)
        # NOTE(review): class_id is not used by the registration below.

        if job_id is None:
            job_id = self.worker.current_job_id
        assert isinstance(job_id, JobID)

        def register_class_for_serialization(worker_info):
            context = worker_info["worker"].get_serialization_context(job_id)
            context._register_cloudpickle_serializer(cls, serializer,
                                                     deserializer)

        if not local:
            self.worker.run_function_on_all_workers(
                register_class_for_serialization)
        else:
            # Since we are pickling objects of this class, we don't actually
            # need to ship the class definition.
            register_class_for_serialization({"worker": self.worker})
|