Cross language serialization for primitive types (#7711)

* Cross language serialization for Java and Python

* Use strict types when serializing in Python

* Handle recursive objects in Python; Pin msgpack >= 0.6.0, < 1.0.0

* Disable gc for optimizing msgpack loads

* Fix merge bug

* Java calls to Python use returnType; Fix ClassLoaderTest

* Fix RayMethodsTest

* Fix checkstyle

* Fix lint

* prepare_args raises an exception when trying to transfer a non-deserializable object to another language

* Fix CrossLanguageInvocationTest.java; Python msgpack treats float as double

* Minor fixes

* Fix compile error on linux

* Fix lint in java/BUILD.bazel

* Fix test_failure

* Fix lint

* Class<?> to Class<T>; Refine metadata bytes.

* Rename FST to Fst; sort java dependencies

* Change Class<?>[] to Optional<Class<?>>; sort requirements in setup.py

* Improve CrossLanguageInvocationTest

* Refactor MessagePackSerializer.java

* Refactor MessagePackSerializer.java; Refine CrossLanguageInvocationTest.java

* Remove unnecessary dependencies for Java; Add getReturnType() for RayFunction in Java

* Fix bug

* Remove custom cross language type support

* Replace Serializer.Meta with MutableBoolean

* Remove @SuppressWarnings support from checkstyle.xml; Add null test in CrossLanguageInvocationTest.java

* Refine MessagePackSerializer.pack

* Ray.get supports RayObject as input

* Improve comments and error info

* Remove classLoader argument from serializer

* Separate msgpack from pickle5 in Python

* Pair<byte[], MutableBoolean> to Pair<byte[], Boolean>

* Remove public static <T> T get(RayObject<T> object), use RayObject.get() instead

* Refine test

* small fixes

Co-authored-by: 刘宝 <po.lb@antfin.com>
Co-authored-by: Hao Chen <chenh1024@gmail.com>
This commit is contained in:
fyrestone
2020-04-08 21:10:57 +08:00
committed by GitHub
parent e8c19aba41
commit fc6259a656
42 changed files with 1057 additions and 313 deletions
+81 -67
View File
@@ -15,7 +15,15 @@ from ray.exceptions import (
RayWorkerError,
UnreconstructableError,
)
from ray._raylet import Pickle5Writer, unpack_pickle5_buffers
from ray._raylet import (
split_buffer,
unpack_pickle5_buffers,
Pickle5Writer,
Pickle5SerializedObject,
MessagePackSerializer,
MessagePackSerializedObject,
RawSerializedObject,
)
logger = logging.getLogger(__name__)
@@ -34,51 +42,6 @@ class DeserializationError(Exception):
pass
class SerializedObject:
    """Base class for the serialized form of a value.

    Stores the metadata and the list of contained object IDs passed to
    the constructor and exposes both as read-only properties.
    ``total_bytes`` is not implemented here and must be provided by
    subclasses.
    """

    def __init__(self, metadata, contained_object_ids=None):
        """Store the metadata and the contained object IDs.

        Args:
            metadata: Value exposed unchanged through ``metadata``.
            contained_object_ids: Value exposed through
                ``contained_object_ids``. Any falsy value (``None``, an
                empty list, ...) is replaced by a new empty list.
        """
        self._metadata = metadata
        self._contained_object_ids = contained_object_ids or []

    @property
    def total_bytes(self):
        """Not implemented in the base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @property
    def metadata(self):
        """The metadata given at construction time."""
        return self._metadata

    @property
    def contained_object_ids(self):
        """The contained object IDs given at construction (or ``[]``)."""
        return self._contained_object_ids
class Pickle5SerializedObject(SerializedObject):
    """Serialized object made of in-band data plus a writer object.

    The total size is obtained from ``writer.get_total_bytes(inband)``
    on first access and cached afterwards.
    """

    def __init__(self, metadata, inband, writer, contained_object_ids):
        """Store the in-band data and the writer.

        Args:
            metadata: Forwarded to ``SerializedObject.__init__``.
            inband: In-band data; passed to ``writer.get_total_bytes``
                when the size is first requested.
            writer: Object providing ``get_total_bytes(inband)``.
            contained_object_ids: Forwarded to
                ``SerializedObject.__init__``.
        """
        super(Pickle5SerializedObject, self).__init__(metadata,
                                                      contained_object_ids)
        self.inband = inband
        self.writer = writer
        # Cached total bytes; None until ``total_bytes`` is first read.
        self._total_bytes = None

    @property
    def total_bytes(self):
        """Value of ``writer.get_total_bytes(inband)``, computed once."""
        if self._total_bytes is None:
            self._total_bytes = self.writer.get_total_bytes(self.inband)
        return self._total_bytes
class RawSerializedObject(SerializedObject):
    """Serialized object that wraps a value as-is.

    The metadata is always ``ray_constants.RAW_BUFFER_METADATA`` and no
    contained object IDs are passed to the base class.
    """

    def __init__(self, value):
        """Store ``value``.

        Args:
            value: The wrapped value; it must support ``len()`` for
                ``total_bytes`` to work.
        """
        super(RawSerializedObject,
              self).__init__(ray_constants.RAW_BUFFER_METADATA)
        self.value = value

    @property
    def total_bytes(self):
        """``len(self.value)``."""
        return len(self.value)
def _try_to_compute_deterministic_class_id(cls, depth=5):
"""Attempt to produce a deterministic class ID for a given class.
@@ -265,23 +228,51 @@ class SerializationContext:
raise DeserializationError()
return obj
def _deserialize_msgpack_data(self, data, metadata):
    """Deserialize a buffer holding a msgpack part and a pickle5 part.

    Args:
        data: Buffer that ``split_buffer`` separates into a msgpack
            part and a pickle5 part.
        metadata: Metadata of the object. When it equals
            ``ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE`` the
            pickle5 part is not deserialized.

    Returns:
        The object produced by ``MessagePackSerializer.loads``.

    Raises:
        DeserializationError: If ``MessagePackSerializer.loads`` raises
            any ``Exception``.
    """
    msgpack_data, pickle5_data = split_buffer(data)

    if metadata == ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE:
        # The pickle5 part is ignored for this metadata value.
        python_objects = []
    else:
        python_objects = self._deserialize_pickle5_data(pickle5_data)

    def _python_deserializer(index):
        # Map an index handed over by the msgpack loader to the
        # corresponding entry of the pickle5-deserialized objects.
        return python_objects[index]

    # Only the msgpack load itself is converted to DeserializationError.
    try:
        obj = MessagePackSerializer.loads(msgpack_data,
                                          _python_deserializer)
    except Exception:
        raise DeserializationError()
    return obj
def _deserialize_object(self, data, metadata, object_id):
if metadata:
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
return self._deserialize_pickle5_data(data)
if metadata in [
ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE,
ray_constants.OBJECT_METADATA_TYPE_PYTHON
]:
return self._deserialize_msgpack_data(data, metadata)
# Check if the object should be returned as raw bytes.
if metadata == ray_constants.RAW_BUFFER_METADATA:
if metadata == ray_constants.OBJECT_METADATA_TYPE_RAW:
if data is None:
return b""
return data.to_pybytes()
# Otherwise, return an exception object based on
# the error type.
error_type = int(metadata)
try:
error_type = int(metadata)
except Exception:
raise Exception(
"Can't deserialize object: {}, metadata: {}".format(
object_id, metadata))
# RayTaskError is serialized with pickle5 in the data field.
# TODO (kfstorm): exception serialization should be language
# independent.
if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"):
obj = self._deserialize_pickle5_data(data)
obj = self._deserialize_msgpack_data(data, metadata)
assert isinstance(obj, RayTaskError)
return obj
elif error_type == ErrorType.Value("WORKER_DIED"):
@@ -347,6 +338,43 @@ class SerializationContext:
return results
def _serialize_to_pickle5(self, metadata, value):
    """Serialize ``value`` with pickle protocol 5.

    Buffers reported by pickle are handed to
    ``Pickle5Writer.buffer_callback``.

    Args:
        metadata: Metadata stored on the returned object.
        value: The object to pickle.

    Returns:
        A ``Pickle5SerializedObject`` built from the metadata, the
        pickled bytes, the writer, and the result of
        ``get_and_clear_contained_object_ids()``.

    Raises:
        Exception: Whatever ``pickle.dumps`` raises is re-raised
            unchanged after the contained object IDs are cleared.
    """
    writer = Pickle5Writer()
    # TODO(swang): Check that contained_object_ids is empty.
    try:
        self.set_in_band_serialization()
        inband = pickle.dumps(
            value, protocol=5, buffer_callback=writer.buffer_callback)
    except Exception:
        # Clear the contained object IDs (result discarded) before
        # propagating; a bare ``raise`` keeps the original traceback.
        self.get_and_clear_contained_object_ids()
        raise
    finally:
        # Always switch back, whether or not pickling succeeded.
        self.set_out_of_band_serialization()
    return Pickle5SerializedObject(
        metadata, inband, writer,
        self.get_and_clear_contained_object_ids())
def _serialize_to_msgpack(self, metadata, value):
    """Serialize ``value`` with msgpack, using pickle5 as a fallback.

    ``MessagePackSerializer.dumps`` is given a callback; each object it
    passes to the callback is collected in a list and replaced by its
    index in that list. If anything was collected, the list is
    serialized with ``_serialize_to_pickle5``.

    Args:
        metadata: Metadata to use when objects were collected for
            pickle5. If nothing was collected it is replaced by
            ``ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE``.
        value: The object to serialize.

    Returns:
        A ``MessagePackSerializedObject`` built from the final
        metadata, the msgpack data, and the pickle5 serialized object
        (``None`` when nothing was collected).
    """
    python_objects = []

    def _python_serializer(o):
        # Record the object and return the index it was stored at.
        index = len(python_objects)
        python_objects.append(o)
        return index

    msgpack_data = MessagePackSerializer.dumps(value, _python_serializer)

    if python_objects:
        pickle5_serialized_object = (
            self._serialize_to_pickle5(metadata, python_objects))
    else:
        # Nothing was handed to the callback, so no pickle5 part exists.
        metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE
        pickle5_serialized_object = None

    return MessagePackSerializedObject(metadata, msgpack_data,
                                       pickle5_serialized_object)
def serialize(self, value):
"""Serialize an object.
@@ -365,23 +393,9 @@ class SerializationContext:
metadata = str(ErrorType.Value(
"TASK_EXECUTION_EXCEPTION")).encode("ascii")
else:
metadata = ray_constants.PICKLE5_BUFFER_METADATA
metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON
writer = Pickle5Writer()
# TODO(swang): Check that contained_object_ids is empty.
try:
self.set_in_band_serialization()
inband = pickle.dumps(
value, protocol=5, buffer_callback=writer.buffer_callback)
except Exception as e:
self.get_and_clear_contained_object_ids()
raise e
finally:
self.set_out_of_band_serialization()
return Pickle5SerializedObject(
metadata, inband, writer,
self.get_and_clear_contained_object_ids())
return self._serialize_to_msgpack(metadata, value)
def register_custom_serializer(self,
cls,