mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 03:21:06 +08:00
Cross language serialization for primitive types (#7711)
* Cross language serialization for Java and Python * Use strict types when Python serializing * Handle recursive objects in Python; Pin msgpack >= 0.6.0, < 1.0.0 * Disable gc for optimizing msgpack loads * Fix merge bug * Java call Python use returnType; Fix ClassLoaderTest * Fix RayMethodsTest * Fix checkstyle * Fix lint * prepare_args raises exception if try to transfer a non-deserializable object to another language * Fix CrossLanguageInvocationTest.java, Python msgpack treat float as double * Minor fixes * Fix compile error on linux * Fix lint in java/BUILD.bazel * Fix test_failure * Fix lint * Class<?> to Class<T>; Refine metadata bytes. * Rename FST to Fst; sort java dependencies * Change Class<?>[] to Optional<Class<?>>; sort requirements in setup.py * Improve CrossLanguageInvocationTest * Refactor MessagePackSerializer.java * Refactor MessagePackSerializer.java; Refine CrossLanguageInvocationTest.java * Remove unnecessary dependencies for Java; Add getReturnType() for RayFunction in Java * Fix bug * Remove custom cross language type support * Replace Serializer.Meta with MutableBoolean * Remove @SuppressWarnings support from checkstyle.xml; Add null test in CrossLanguageInvocationTest.java * Refine MessagePackSerializer.pack * Ray.get support RayObject as input * Improve comments and error info * Remove classLoader argument from serializer * Separate msgpack from pickle5 in Python * Pair<byte[], MutableBoolean> to Pair<byte[], Boolean> * Remove public static <T> T get(RayObject<T> object), use RayObject.get() instead * Refine test * small fixes Co-authored-by: 刘宝 <po.lb@antfin.com> Co-authored-by: Hao Chen <chenh1024@gmail.com>
This commit is contained in:
+81
-67
@@ -15,7 +15,15 @@ from ray.exceptions import (
|
||||
RayWorkerError,
|
||||
UnreconstructableError,
|
||||
)
|
||||
from ray._raylet import Pickle5Writer, unpack_pickle5_buffers
|
||||
from ray._raylet import (
|
||||
split_buffer,
|
||||
unpack_pickle5_buffers,
|
||||
Pickle5Writer,
|
||||
Pickle5SerializedObject,
|
||||
MessagePackSerializer,
|
||||
MessagePackSerializedObject,
|
||||
RawSerializedObject,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,51 +42,6 @@ class DeserializationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SerializedObject:
|
||||
def __init__(self, metadata, contained_object_ids=None):
|
||||
self._metadata = metadata
|
||||
self._contained_object_ids = contained_object_ids or []
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return self._metadata
|
||||
|
||||
@property
|
||||
def contained_object_ids(self):
|
||||
return self._contained_object_ids
|
||||
|
||||
|
||||
class Pickle5SerializedObject(SerializedObject):
|
||||
def __init__(self, metadata, inband, writer, contained_object_ids):
|
||||
super(Pickle5SerializedObject, self).__init__(metadata,
|
||||
contained_object_ids)
|
||||
self.inband = inband
|
||||
self.writer = writer
|
||||
# cached total bytes
|
||||
self._total_bytes = None
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
if self._total_bytes is None:
|
||||
self._total_bytes = self.writer.get_total_bytes(self.inband)
|
||||
return self._total_bytes
|
||||
|
||||
|
||||
class RawSerializedObject(SerializedObject):
|
||||
def __init__(self, value):
|
||||
super(RawSerializedObject,
|
||||
self).__init__(ray_constants.RAW_BUFFER_METADATA)
|
||||
self.value = value
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
return len(self.value)
|
||||
|
||||
|
||||
def _try_to_compute_deterministic_class_id(cls, depth=5):
|
||||
"""Attempt to produce a deterministic class ID for a given class.
|
||||
|
||||
@@ -265,23 +228,51 @@ class SerializationContext:
|
||||
raise DeserializationError()
|
||||
return obj
|
||||
|
||||
def _deserialize_msgpack_data(self, data, metadata):
|
||||
msgpack_data, pickle5_data = split_buffer(data)
|
||||
|
||||
if metadata == ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE:
|
||||
python_objects = []
|
||||
else:
|
||||
python_objects = self._deserialize_pickle5_data(pickle5_data)
|
||||
|
||||
try:
|
||||
|
||||
def _python_deserializer(index):
|
||||
return python_objects[index]
|
||||
|
||||
obj = MessagePackSerializer.loads(msgpack_data,
|
||||
_python_deserializer)
|
||||
except Exception:
|
||||
raise DeserializationError()
|
||||
return obj
|
||||
|
||||
def _deserialize_object(self, data, metadata, object_id):
|
||||
if metadata:
|
||||
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
|
||||
return self._deserialize_pickle5_data(data)
|
||||
if metadata in [
|
||||
ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE,
|
||||
ray_constants.OBJECT_METADATA_TYPE_PYTHON
|
||||
]:
|
||||
return self._deserialize_msgpack_data(data, metadata)
|
||||
# Check if the object should be returned as raw bytes.
|
||||
if metadata == ray_constants.RAW_BUFFER_METADATA:
|
||||
if metadata == ray_constants.OBJECT_METADATA_TYPE_RAW:
|
||||
if data is None:
|
||||
return b""
|
||||
return data.to_pybytes()
|
||||
# Otherwise, return an exception object based on
|
||||
# the error type.
|
||||
error_type = int(metadata)
|
||||
try:
|
||||
error_type = int(metadata)
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"Can't deserialize object: {}, metadata: {}".format(
|
||||
object_id, metadata))
|
||||
|
||||
# RayTaskError is serialized with pickle5 in the data field.
|
||||
# TODO (kfstorm): exception serialization should be language
|
||||
# independent.
|
||||
if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"):
|
||||
obj = self._deserialize_pickle5_data(data)
|
||||
obj = self._deserialize_msgpack_data(data, metadata)
|
||||
assert isinstance(obj, RayTaskError)
|
||||
return obj
|
||||
elif error_type == ErrorType.Value("WORKER_DIED"):
|
||||
@@ -347,6 +338,43 @@ class SerializationContext:
|
||||
|
||||
return results
|
||||
|
||||
def _serialize_to_pickle5(self, metadata, value):
|
||||
writer = Pickle5Writer()
|
||||
# TODO(swang): Check that contained_object_ids is empty.
|
||||
try:
|
||||
self.set_in_band_serialization()
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
except Exception as e:
|
||||
self.get_and_clear_contained_object_ids()
|
||||
raise e
|
||||
finally:
|
||||
self.set_out_of_band_serialization()
|
||||
|
||||
return Pickle5SerializedObject(
|
||||
metadata, inband, writer,
|
||||
self.get_and_clear_contained_object_ids())
|
||||
|
||||
def _serialize_to_msgpack(self, metadata, value):
|
||||
python_objects = []
|
||||
|
||||
def _python_serializer(o):
|
||||
index = len(python_objects)
|
||||
python_objects.append(o)
|
||||
return index
|
||||
|
||||
msgpack_data = MessagePackSerializer.dumps(value, _python_serializer)
|
||||
|
||||
if python_objects:
|
||||
pickle5_serialized_object = \
|
||||
self._serialize_to_pickle5(metadata, python_objects)
|
||||
else:
|
||||
metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE
|
||||
pickle5_serialized_object = None
|
||||
|
||||
return MessagePackSerializedObject(metadata, msgpack_data,
|
||||
pickle5_serialized_object)
|
||||
|
||||
def serialize(self, value):
|
||||
"""Serialize an object.
|
||||
|
||||
@@ -365,23 +393,9 @@ class SerializationContext:
|
||||
metadata = str(ErrorType.Value(
|
||||
"TASK_EXECUTION_EXCEPTION")).encode("ascii")
|
||||
else:
|
||||
metadata = ray_constants.PICKLE5_BUFFER_METADATA
|
||||
metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON
|
||||
|
||||
writer = Pickle5Writer()
|
||||
# TODO(swang): Check that contained_object_ids is empty.
|
||||
try:
|
||||
self.set_in_band_serialization()
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
except Exception as e:
|
||||
self.get_and_clear_contained_object_ids()
|
||||
raise e
|
||||
finally:
|
||||
self.set_out_of_band_serialization()
|
||||
|
||||
return Pickle5SerializedObject(
|
||||
metadata, inband, writer,
|
||||
self.get_and_clear_contained_object_ids())
|
||||
return self._serialize_to_msgpack(metadata, value)
|
||||
|
||||
def register_custom_serializer(self,
|
||||
cls,
|
||||
|
||||
Reference in New Issue
Block a user