Cross language serialization for primitive types (#7711)

* Cross language serialization for Java and Python

* Use strict types when Python serializing

* Handle recursive objects in Python; Pin msgpack >= 0.6.0, < 1.0.0

* Disable gc for optimizing msgpack loads

* Fix merge bug

* Java call Python use returnType; Fix ClassLoaderTest

* Fix RayMethodsTest

* Fix checkstyle

* Fix lint

* prepare_args raises an exception when trying to transfer a non-deserializable object to another language

* Fix CrossLanguageInvocationTest.java; Python msgpack treats float as double

* Minor fixes

* Fix compile error on linux

* Fix lint in java/BUILD.bazel

* Fix test_failure

* Fix lint

* Class<?> to Class<T>; Refine metadata bytes.

* Rename FST to Fst; sort java dependencies

* Change Class<?>[] to Optional<Class<?>>; sort requirements in setup.py

* Improve CrossLanguageInvocationTest

* Refactor MessagePackSerializer.java

* Refactor MessagePackSerializer.java; Refine CrossLanguageInvocationTest.java

* Remove unnecessary dependencies for Java; Add getReturnType() for RayFunction in Java

* Fix bug

* Remove custom cross language type support

* Replace Serializer.Meta with MutableBoolean

* Remove @SuppressWarnings support from checkstyle.xml; Add null test in CrossLanguageInvocationTest.java

* Refine MessagePackSerializer.pack

* Ray.get support RayObject as input

* Improve comments and error info

* Remove classLoader argument from serializer

* Separate msgpack from pickle5 in Python

* Pair<byte[], MutableBoolean> to Pair<byte[], Boolean>

* Remove public static <T> T get(RayObject<T> object), use RayObject.get() instead

* Refine test

* small fixes

Co-authored-by: 刘宝 <po.lb@antfin.com>
Co-authored-by: Hao Chen <chenh1024@gmail.com>
This commit is contained in:
fyrestone
2020-04-08 21:10:57 +08:00
committed by GitHub
parent e8c19aba41
commit fc6259a656
42 changed files with 1057 additions and 313 deletions
+1 -1
View File
@@ -44,7 +44,7 @@ cdef class Buffer:
def __getbuffer__(self, Py_buffer* buffer, int flags):
buffer.readonly = 0
buffer.buf = <char *>self.buffer.get().Data()
buffer.format = 'b'
buffer.format = 'B'
buffer.internal = NULL
buffer.itemsize = 1
buffer.len = self.size
+222 -18
View File
@@ -1,5 +1,9 @@
from libc.string cimport memcpy
from libc.stdint cimport uintptr_t, uint64_t, INT32_MAX
from libcpp cimport nullptr
import cython
DEF MEMCOPY_THREADS = 6
# This is the default alignment value for len(buffer) < 2048.
DEF kMinorBufferAlign = 8
@@ -9,6 +13,8 @@ DEF kMajorBufferAlign = 64
DEF kMajorBufferSize = 2048
DEF kMemcopyDefaultBlocksize = 64
DEF kMemcopyDefaultThreshold = 1024 * 1024
DEF kLanguageSpecificTypeExtensionId = 101
DEF kMessagePackOffset = 9
cdef extern from "ray/util/memory.h" namespace "ray" nogil:
void parallel_memcopy(uint8_t* dst, const uint8_t* src, int64_t nbytes,
@@ -82,7 +88,7 @@ cdef class SubBuffer:
void *internal
object buffer
def __cinit__(self, Buffer buffer):
def __cinit__(self, object buffer):
# Increase ref count.
self.buffer = buffer
self.suboffsets = NULL
@@ -142,15 +148,68 @@ cdef class SubBuffer:
return self.size
# See 'serialization.proto' for the memory layout in the Plasma buffer.
def unpack_pickle5_buffers(Buffer buf):
cdef class MessagePackSerializer(object):
    """Pack and unpack objects with msgpack.

    Objects handed to the msgpack ``default`` hook are, when a Python
    serializer is supplied, embedded as a msgpack ``ExtType`` tagged with
    ``kLanguageSpecificTypeExtensionId``; ``loads`` reverses that with the
    matching Python deserializer.
    """

    @staticmethod
    def dumps(o, python_serializer=None):
        """Serialize ``o`` to msgpack bytes.

        Args:
            o: The object to serialize.
            python_serializer: Optional callable. Its result is packed with
                msgpack and wrapped in an ``ExtType``. When it is None the
                ``default`` hook returns the object unchanged.

        Returns:
            The value returned by ``msgpack.dumps``.
        """
        def _default(obj):
            if python_serializer is not None:
                return msgpack.ExtType(kLanguageSpecificTypeExtensionId,
                                       msgpack.dumps(python_serializer(obj)))
            return obj

        try:
            # With strict_types=False, lists and tuples would both be packed
            # as a msgpack array and could not be told apart when unpacking.
            return msgpack.dumps(o, default=_default,
                                 use_bin_type=True, strict_types=True)
        except ValueError:
            # msgpack can't handle recursive objects, so we serialize them
            # with the Python serializer, e.g. pickle.
            return msgpack.dumps(_default(o), default=_default,
                                 use_bin_type=True, strict_types=True)

    @classmethod
    def loads(cls, s, python_deserializer=None):
        """Deserialize msgpack bytes ``s``.

        Args:
            s: The msgpack-encoded data.
            python_deserializer: Optional callable applied to the unpacked
                payload of an ``ExtType`` whose code is
                ``kLanguageSpecificTypeExtensionId``.

        Returns:
            The value returned by ``msgpack.loads``.

        Raises:
            Exception: If an extension type is encountered that cannot be
                decoded (unknown code, or no ``python_deserializer`` given).
        """
        def _ext_hook(code, data):
            if (code == kLanguageSpecificTypeExtensionId and
                    python_deserializer is not None):
                return python_deserializer(msgpack.loads(data))
            # Report every extension we cannot decode instead of letting it
            # come back from the hook as a silent None.
            raise Exception('Unrecognized ext type id: {}'.format(code))

        # Performance optimization for msgpack: pause the garbage collector
        # while unpacking. Remember whether it was running so that a caller
        # who had already disabled it does not get it switched back on.
        gc_was_enabled = gc.isenabled()
        gc.disable()
        try:
            return msgpack.loads(s, ext_hook=_ext_hook, raw=False)
        finally:
            if gc_was_enabled:
                gc.enable()
@cython.boundscheck(False)
@cython.wraparound(False)
def split_buffer(Buffer buf):
cdef:
shared_ptr[CBuffer] _buffer = buf.buffer
const uint8_t *data = buf.buffer.get().Data()
size_t size = _buffer.get().Size()
size_t size = buf.buffer.get().Size()
uint8_t[:] bufferview = buf
int64_t msgpack_bytes_length
assert kMessagePackOffset <= size
header_unpacker = msgpack.Unpacker()
header_unpacker.feed(bufferview[:kMessagePackOffset])
msgpack_bytes_length = header_unpacker.unpack()
assert kMessagePackOffset + msgpack_bytes_length <= <int64_t>size
return (bufferview[kMessagePackOffset:
kMessagePackOffset + msgpack_bytes_length],
bufferview[kMessagePackOffset + msgpack_bytes_length:])
# See 'serialization.proto' for the memory layout in the Plasma buffer.
@cython.boundscheck(False)
@cython.wraparound(False)
def unpack_pickle5_buffers(uint8_t[:] bufferview):
cdef:
const uint8_t *data = &bufferview[0]
size_t size = len(bufferview)
CPythonObject python_object
CPythonBuffer *buffer_meta
c_string inband_data
int64_t protobuf_offset
int64_t protobuf_size
int32_t i
@@ -167,14 +226,16 @@ def unpack_pickle5_buffers(Buffer buf):
if not python_object.ParseFromArray(
data + protobuf_offset, <int32_t>protobuf_size):
raise ValueError("Protobuf object is corrupted.")
inband_data.append(<char*>(data + python_object.inband_data_offset()),
<size_t>python_object.inband_data_size())
inband_data_offset = python_object.inband_data_offset()
inband_data = bufferview[
inband_data_offset:
inband_data_offset + python_object.inband_data_size()]
buffers_segment = data + python_object.raw_buffers_offset()
pickled_buffers = []
# Now read buffer meta
for i in range(python_object.buffer_size()):
buffer_meta = <CPythonBuffer *>&python_object.buffer(i)
buffer = SubBuffer(buf)
buffer = SubBuffer(bufferview)
buffer.buf = <void*>(buffers_segment + buffer_meta.address())
buffer.len = buffer_meta.length()
buffer.itemsize = buffer_meta.itemsize()
@@ -207,6 +268,11 @@ cdef class Pickle5Writer:
self._curr_buffer_addr = 0
self._total_bytes = -1
def __dealloc__(self):
    # Every Py_buffer stored in self.buffers has to be released here,
    # otherwise we could experience memory leaks.
    for idx in range(self.buffers.size()):
        cpython.PyBuffer_Release(&self.buffers[idx])
def buffer_callback(self, pickle_buffer):
cdef:
Py_buffer view
@@ -240,14 +306,14 @@ cdef class Pickle5Writer:
self._curr_buffer_addr += view.len
self.buffers.push_back(view)
def get_total_bytes(self, const c_string &inband):
def get_total_bytes(self, const uint8_t[:] inband):
cdef:
size_t protobuf_bytes = 0
uint64_t inband_data_offset = sizeof(int64_t) * 2
uint64_t raw_buffers_offset = padded_length_u64(
inband_data_offset + inband.length(), kMajorBufferAlign)
inband_data_offset + len(inband), kMajorBufferAlign)
self.python_object.set_inband_data_offset(inband_data_offset)
self.python_object.set_inband_data_size(inband.length())
self.python_object.set_inband_data_size(len(inband))
self.python_object.set_raw_buffers_offset(raw_buffers_offset)
self.python_object.set_raw_buffers_size(self._curr_buffer_addr)
# Since calculating the output size is expensive, we will
@@ -265,9 +331,11 @@ cdef class Pickle5Writer:
self._total_bytes = self._protobuf_offset + protobuf_bytes
return self._total_bytes
cdef void write_to(self, const c_string &inband, shared_ptr[CBuffer] data,
int memcopy_threads):
cdef uint8_t *ptr = data.get().Data()
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void write_to(self, const uint8_t[:] inband, uint8_t[:] data,
int memcopy_threads) nogil:
cdef uint8_t *ptr = &data[0]
cdef int32_t protobuf_size
cdef uint64_t buffer_addr
cdef uint64_t buffer_len
@@ -284,7 +352,7 @@ cdef class Pickle5Writer:
ptr + self._protobuf_offset)
# Write inband data.
memcpy(ptr + self.python_object.inband_data_offset(),
inband.data(), inband.length())
&inband[0], len(inband))
# Write buffer data.
ptr += self.python_object.raw_buffers_offset()
for i in range(self.python_object.buffer_size()):
@@ -298,5 +366,141 @@ cdef class Pickle5Writer:
kMemcopyDefaultBlocksize, memcopy_threads)
else:
memcpy(ptr + buffer_addr, self.buffers[i].buf, buffer_len)
# We must release the buffer, or we could experience memory leaks.
cpython.PyBuffer_Release(&self.buffers[i])
cdef class SerializedObject(object):
    """Base class pairing metadata with the object IDs a value contains.

    ``total_bytes`` and ``write_to`` only raise ``NotImplementedError`` here;
    the subclasses in this file provide the real implementations.
    """
    cdef:
        object _metadata
        object _contained_object_ids

    def __init__(self, metadata, contained_object_ids=None):
        """Store ``metadata`` and the contained object IDs.

        A falsy ``contained_object_ids`` (None or an empty container) is
        replaced by a new empty list.
        """
        self._metadata = metadata
        if contained_object_ids:
            self._contained_object_ids = contained_object_ids
        else:
            self._contained_object_ids = []

    @property
    def total_bytes(self):
        """Size of the serialized form; not implemented on the base class."""
        class_name = type(self).__name__
        raise NotImplementedError(
            "{}.total_bytes not implemented.".format(class_name))

    @property
    def metadata(self):
        """The metadata passed to the constructor."""
        return self._metadata

    @property
    def contained_object_ids(self):
        """The contained object IDs (an empty list if none were given)."""
        return self._contained_object_ids

    @cython.boundscheck(False)
    @cython.wraparound(False)
    cdef void write_to(self, uint8_t[:] buffer) nogil:
        # Not implemented on the base class. Kept as a single raise
        # expression so no Python local is introduced in a nogil method.
        raise NotImplementedError(
            "{}.write_to not implemented.".format(type(self).__name__))
cdef class Pickle5SerializedObject(SerializedObject):
    """Serialized object whose payload is written by a ``Pickle5Writer``."""
    cdef:
        const uint8_t[:] inband
        Pickle5Writer writer
        object _total_bytes

    def __init__(self, metadata, inband, Pickle5Writer writer,
                 contained_object_ids):
        """Keep the in-band data and the writer that will emit it."""
        super(Pickle5SerializedObject, self).__init__(
            metadata, contained_object_ids)
        self.inband = inband
        self.writer = writer
        # None means the total size has not been computed yet.
        self._total_bytes = None

    @property
    def total_bytes(self):
        """Total size reported by the writer, computed once and cached."""
        size = self._total_bytes
        if size is None:
            size = self.writer.get_total_bytes(self.inband)
            self._total_bytes = size
        return size

    @cython.boundscheck(False)
    @cython.wraparound(False)
    cdef void write_to(self, uint8_t[:] buffer) nogil:
        # Delegate to the writer, using the module-wide thread count.
        self.writer.write_to(self.inband, buffer, MEMCOPY_THREADS)
cdef class MessagePackSerializedObject(SerializedObject):
    """msgpack payload, optionally followed by a nested serialized object.

    Layout written by ``write_to``: the msgpack-encoded length of the data
    at offset 0, the msgpack data at ``kMessagePackOffset``, then the nested
    object (if any) directly after the data.
    """
    cdef:
        SerializedObject nest_serialized_object
        object msgpack_header
        object msgpack_data
        int64_t _msgpack_header_bytes
        int64_t _msgpack_data_bytes
        int64_t _total_bytes
        const uint8_t *msgpack_header_ptr
        const uint8_t *msgpack_data_ptr

    def __init__(self, metadata, msgpack_data,
                 SerializedObject nest_serialized_object=None):
        """Record the msgpack data and pre-compute sizes and raw pointers."""
        if nest_serialized_object:
            nested_ids = nest_serialized_object.contained_object_ids
            nested_bytes = nest_serialized_object.total_bytes
        else:
            nested_ids = []
            nested_bytes = 0
        super(MessagePackSerializedObject, self).__init__(
            metadata, nested_ids)
        self.nest_serialized_object = nest_serialized_object

        # The header is the msgpack encoding of the data length.
        header = msgpack.dumps(len(msgpack_data))
        self.msgpack_header = header
        self.msgpack_data = msgpack_data
        self._msgpack_header_bytes = len(header)
        self._msgpack_data_bytes = len(msgpack_data)
        self._total_bytes = (
            kMessagePackOffset + self._msgpack_data_bytes + nested_bytes)

        # Raw pointers used by write_to; the objects they point into are
        # referenced by the msgpack_header / msgpack_data attributes above.
        self.msgpack_header_ptr = <const uint8_t*>header
        self.msgpack_data_ptr = <const uint8_t*>msgpack_data
        assert self._msgpack_header_bytes <= kMessagePackOffset

    @property
    def total_bytes(self):
        """Header region + msgpack data + nested object size, in bytes."""
        return self._total_bytes

    @cython.boundscheck(False)
    @cython.wraparound(False)
    cdef void write_to(self, uint8_t[:] buffer) nogil:
        cdef uint8_t *dst = &buffer[0]
        # Write msgpack data first: header at offset 0, payload at the
        # fixed kMessagePackOffset.
        memcpy(dst, self.msgpack_header_ptr, self._msgpack_header_bytes)
        memcpy(dst + kMessagePackOffset,
               self.msgpack_data_ptr, self._msgpack_data_bytes)
        # The nested object, when present, fills the rest of the buffer.
        if self.nest_serialized_object is not None:
            self.nest_serialized_object.write_to(
                buffer[kMessagePackOffset + self._msgpack_data_bytes:])
cdef class RawSerializedObject(SerializedObject):
    """Serialized object whose payload is copied to the buffer verbatim."""
    cdef:
        object value
        const uint8_t *value_ptr
        int64_t _total_bytes

    def __init__(self, value):
        """Wrap ``value``, tagging it with the raw-object metadata."""
        super(RawSerializedObject, self).__init__(
            ray_constants.OBJECT_METADATA_TYPE_RAW)
        # The value attribute holds a reference to the object that
        # value_ptr points into.
        self.value = value
        self.value_ptr = <const uint8_t*> value
        self._total_bytes = len(value)

    @property
    def total_bytes(self):
        """Length of the wrapped value."""
        return self._total_bytes

    @cython.boundscheck(False)
    @cython.wraparound(False)
    cdef void write_to(self, uint8_t[:] buffer) nogil:
        cdef uint8_t *dst = &buffer[0]
        # Payloads above the threshold are copied with several threads;
        # everything else uses a plain memcpy.
        if (MEMCOPY_THREADS > 1 and
                self._total_bytes > kMemcopyDefaultThreshold):
            parallel_memcopy(dst, self.value_ptr, self._total_bytes,
                             kMemcopyDefaultBlocksize, MEMCOPY_THREADS)
        else:
            memcpy(dst, self.value_ptr, self._total_bytes)