mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
Cross language serialization for primitive types (#7711)
* Cross language serialization for Java and Python * Use strict types when Python serializing * Handle recursive objects in Python; Pin msgpack >= 0.6.0, < 1.0.0 * Disable gc for optimizing msgpack loads * Fix merge bug * Java call Python use returnType; Fix ClassLoaderTest * Fix RayMethodsTest * Fix checkstyle * Fix lint * prepare_args raises exception if try to transfer a non-deserializable object to another language * Fix CrossLanguageInvocationTest.java, Python msgpack treat float as double * Minor fixes * Fix compile error on linux * Fix lint in java/BUILD.bazel * Fix test_failure * Fix lint * Class<?> to Class<T>; Refine metadata bytes. * Rename FST to Fst; sort java dependencies * Change Class<?>[] to Optional<Class<?>>; sort requirements in setup.py * Improve CrossLanguageInvocationTest * Refactor MessagePackSerializer.java * Refactor MessagePackSerializer.java; Refine CrossLanguageInvocationTest.java * Remove unnecessary dependencies for Java; Add getReturnType() for RayFunction in Java * Fix bug * Remove custom cross language type support * Replace Serializer.Meta with MutableBoolean * Remove @SuppressWarnings support from checkstyle.xml; Add null test in CrossLanguageInvocationTest.java * Refine MessagePackSerializer.pack * Ray.get support RayObject as input * Improve comments and error info * Remove classLoader argument from serializer * Separate msgpack from pickle5 in Python * Pair<byte[], MutableBoolean> to Pair<byte[], Boolean> * Remove public static <T> T get(RayObject<T> object), use RayObject.get() instead * Refine test * small fixes Co-authored-by: 刘宝 <po.lb@antfin.com> Co-authored-by: Hao Chen <chenh1024@gmail.com>
This commit is contained in:
+25
-35
@@ -92,6 +92,8 @@ from ray.exceptions import (
|
||||
RayTimeoutError,
|
||||
)
|
||||
from ray.utils import decode
|
||||
import gc
|
||||
import msgpack
|
||||
|
||||
cimport cpython
|
||||
|
||||
@@ -106,8 +108,6 @@ include "includes/libcoreworker.pxi"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMCOPY_THREADS = 6
|
||||
|
||||
|
||||
def set_internal_config(dict options):
|
||||
cdef:
|
||||
@@ -257,8 +257,9 @@ cdef int prepare_resources(
|
||||
return 0
|
||||
|
||||
|
||||
cdef void prepare_args(
|
||||
CoreWorker core_worker, args, c_vector[CTaskArg] *args_vector):
|
||||
cdef prepare_args(
|
||||
CoreWorker core_worker,
|
||||
Language language, args, c_vector[CTaskArg] *args_vector):
|
||||
cdef:
|
||||
size_t size
|
||||
int64_t put_threshold
|
||||
@@ -274,6 +275,13 @@ cdef void prepare_args(
|
||||
|
||||
else:
|
||||
serialized_arg = worker.get_serialization_context().serialize(arg)
|
||||
metadata = serialized_arg.metadata
|
||||
if language != Language.PYTHON:
|
||||
if metadata not in [
|
||||
ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE,
|
||||
ray_constants.OBJECT_METADATA_TYPE_RAW]:
|
||||
raise Exception("Can't transfer {} data to {}".format(
|
||||
metadata, language))
|
||||
size = serialized_arg.total_bytes
|
||||
|
||||
# TODO(edoakes): any objects containing ObjectIDs are spilled to
|
||||
@@ -283,12 +291,14 @@ cdef void prepare_args(
|
||||
if <int64_t>size <= put_threshold:
|
||||
arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](size))
|
||||
write_serialized_object(serialized_arg, arg_data)
|
||||
if size > 0:
|
||||
(<SerializedObject>serialized_arg).write_to(
|
||||
Buffer.make(arg_data))
|
||||
for object_id in serialized_arg.contained_object_ids:
|
||||
inlined_ids.push_back((<ObjectID>object_id).native())
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByValue(make_shared[CRayObject](
|
||||
arg_data, string_to_buffer(serialized_arg.metadata),
|
||||
arg_data, string_to_buffer(metadata),
|
||||
inlined_ids)))
|
||||
inlined_ids.clear()
|
||||
else:
|
||||
@@ -616,29 +626,6 @@ cdef shared_ptr[CBuffer] string_to_buffer(c_string& c_str):
|
||||
<uint8_t*>(c_str.data()), c_str.size(), True))
|
||||
|
||||
|
||||
cdef write_serialized_object(
|
||||
serialized_object, const shared_ptr[CBuffer]& buf):
|
||||
from ray.serialization import Pickle5SerializedObject, RawSerializedObject
|
||||
|
||||
if isinstance(serialized_object, RawSerializedObject):
|
||||
if buf.get() != NULL and buf.get().Size() > 0:
|
||||
size = serialized_object.total_bytes
|
||||
if MEMCOPY_THREADS > 1 and size > kMemcopyDefaultThreshold:
|
||||
parallel_memcopy(buf.get().Data(),
|
||||
<const uint8_t*> serialized_object.value,
|
||||
size, kMemcopyDefaultBlocksize,
|
||||
MEMCOPY_THREADS)
|
||||
else:
|
||||
memcpy(buf.get().Data(),
|
||||
<const uint8_t*>serialized_object.value, size)
|
||||
|
||||
elif isinstance(serialized_object, Pickle5SerializedObject):
|
||||
(<Pickle5Writer>serialized_object.writer).write_to(
|
||||
serialized_object.inband, buf, MEMCOPY_THREADS)
|
||||
else:
|
||||
raise TypeError("Unsupported serialization type.")
|
||||
|
||||
|
||||
cdef class CoreWorker:
|
||||
|
||||
def __cinit__(self, is_driver, store_socket, raylet_socket,
|
||||
@@ -780,7 +767,9 @@ cdef class CoreWorker:
|
||||
&c_object_id, &data)
|
||||
|
||||
if not object_already_exists:
|
||||
write_serialized_object(serialized_object, data)
|
||||
if total_bytes > 0:
|
||||
(<SerializedObject>serialized_object).write_to(
|
||||
Buffer.make(data))
|
||||
if self.is_local_mode:
|
||||
c_object_id_vector.push_back(c_object_id)
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Put(
|
||||
@@ -875,7 +864,7 @@ cdef class CoreWorker:
|
||||
num_return_vals, c_resources)
|
||||
ray_function = CRayFunction(
|
||||
language.lang, function_descriptor.descriptor)
|
||||
prepare_args(self, args, &args_vector)
|
||||
prepare_args(self, language, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().SubmitTask(
|
||||
@@ -908,7 +897,7 @@ cdef class CoreWorker:
|
||||
prepare_resources(placement_resources, &c_placement_resources)
|
||||
ray_function = CRayFunction(
|
||||
language.lang, function_descriptor.descriptor)
|
||||
prepare_args(self, args, &args_vector)
|
||||
prepare_args(self, language, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().CreateActor(
|
||||
@@ -944,7 +933,7 @@ cdef class CoreWorker:
|
||||
task_options = CTaskOptions(num_return_vals, c_resources)
|
||||
ray_function = CRayFunction(
|
||||
language.lang, function_descriptor.descriptor)
|
||||
prepare_args(self, args, &args_vector)
|
||||
prepare_args(self, language, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(
|
||||
@@ -1133,8 +1122,9 @@ cdef class CoreWorker:
|
||||
for i, serialized_object in enumerate(serialized_objects):
|
||||
# A nullptr is returned if the object already exists.
|
||||
if returns[0][i].get() != NULL:
|
||||
write_serialized_object(
|
||||
serialized_object, returns[0][i].get().GetData())
|
||||
if returns[0][i].get().HasData():
|
||||
(<SerializedObject>serialized_object).write_to(
|
||||
Buffer.make(returns[0][i].get().GetData()))
|
||||
if self.is_local_mode:
|
||||
return_ids_vector.push_back(return_ids[i])
|
||||
check_status(
|
||||
|
||||
@@ -44,7 +44,7 @@ cdef class Buffer:
|
||||
def __getbuffer__(self, Py_buffer* buffer, int flags):
|
||||
buffer.readonly = 0
|
||||
buffer.buf = <char *>self.buffer.get().Data()
|
||||
buffer.format = 'b'
|
||||
buffer.format = 'B'
|
||||
buffer.internal = NULL
|
||||
buffer.itemsize = 1
|
||||
buffer.len = self.size
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from libc.string cimport memcpy
|
||||
from libc.stdint cimport uintptr_t, uint64_t, INT32_MAX
|
||||
from libcpp cimport nullptr
|
||||
import cython
|
||||
|
||||
DEF MEMCOPY_THREADS = 6
|
||||
|
||||
# This is the default alignment value for len(buffer) < 2048.
|
||||
DEF kMinorBufferAlign = 8
|
||||
@@ -9,6 +13,8 @@ DEF kMajorBufferAlign = 64
|
||||
DEF kMajorBufferSize = 2048
|
||||
DEF kMemcopyDefaultBlocksize = 64
|
||||
DEF kMemcopyDefaultThreshold = 1024 * 1024
|
||||
DEF kLanguageSpecificTypeExtensionId = 101
|
||||
DEF kMessagePackOffset = 9
|
||||
|
||||
cdef extern from "ray/util/memory.h" namespace "ray" nogil:
|
||||
void parallel_memcopy(uint8_t* dst, const uint8_t* src, int64_t nbytes,
|
||||
@@ -82,7 +88,7 @@ cdef class SubBuffer:
|
||||
void *internal
|
||||
object buffer
|
||||
|
||||
def __cinit__(self, Buffer buffer):
|
||||
def __cinit__(self, object buffer):
|
||||
# Increase ref count.
|
||||
self.buffer = buffer
|
||||
self.suboffsets = NULL
|
||||
@@ -142,15 +148,68 @@ cdef class SubBuffer:
|
||||
return self.size
|
||||
|
||||
|
||||
# See 'serialization.proto' for the memory layout in the Plasma buffer.
|
||||
def unpack_pickle5_buffers(Buffer buf):
|
||||
cdef class MessagePackSerializer(object):
|
||||
@staticmethod
|
||||
def dumps(o, python_serializer=None):
|
||||
def _default(obj):
|
||||
if python_serializer is not None:
|
||||
return msgpack.ExtType(kLanguageSpecificTypeExtensionId,
|
||||
msgpack.dumps(python_serializer(obj)))
|
||||
return obj
|
||||
try:
|
||||
# If we let strict_types is False, then whether list or tuple will
|
||||
# be packed to a message pack array. So, they can't be
|
||||
# distinguished when unpacking.
|
||||
return msgpack.dumps(o, default=_default,
|
||||
use_bin_type=True, strict_types=True)
|
||||
except ValueError as ex:
|
||||
# msgpack can't handle recursive objects, so we serialize them by
|
||||
# python serializer, e.g. pickle.
|
||||
return msgpack.dumps(_default(o), default=_default,
|
||||
use_bin_type=True, strict_types=True)
|
||||
|
||||
@classmethod
|
||||
def loads(cls, s, python_deserializer=None):
|
||||
def _ext_hook(code, data):
|
||||
if code == kLanguageSpecificTypeExtensionId:
|
||||
if python_deserializer is not None:
|
||||
return python_deserializer(msgpack.loads(data))
|
||||
raise Exception('Unrecognized ext type id: {}'.format(code))
|
||||
try:
|
||||
gc.disable() # Performance optimization for msgpack.
|
||||
return msgpack.loads(s, ext_hook=_ext_hook, raw=False)
|
||||
finally:
|
||||
gc.enable()
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
def split_buffer(Buffer buf):
|
||||
cdef:
|
||||
shared_ptr[CBuffer] _buffer = buf.buffer
|
||||
const uint8_t *data = buf.buffer.get().Data()
|
||||
size_t size = _buffer.get().Size()
|
||||
size_t size = buf.buffer.get().Size()
|
||||
uint8_t[:] bufferview = buf
|
||||
int64_t msgpack_bytes_length
|
||||
|
||||
assert kMessagePackOffset <= size
|
||||
header_unpacker = msgpack.Unpacker()
|
||||
header_unpacker.feed(bufferview[:kMessagePackOffset])
|
||||
msgpack_bytes_length = header_unpacker.unpack()
|
||||
assert kMessagePackOffset + msgpack_bytes_length <= <int64_t>size
|
||||
return (bufferview[kMessagePackOffset:
|
||||
kMessagePackOffset + msgpack_bytes_length],
|
||||
bufferview[kMessagePackOffset + msgpack_bytes_length:])
|
||||
|
||||
|
||||
# See 'serialization.proto' for the memory layout in the Plasma buffer.
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
def unpack_pickle5_buffers(uint8_t[:] bufferview):
|
||||
cdef:
|
||||
const uint8_t *data = &bufferview[0]
|
||||
size_t size = len(bufferview)
|
||||
CPythonObject python_object
|
||||
CPythonBuffer *buffer_meta
|
||||
c_string inband_data
|
||||
int64_t protobuf_offset
|
||||
int64_t protobuf_size
|
||||
int32_t i
|
||||
@@ -167,14 +226,16 @@ def unpack_pickle5_buffers(Buffer buf):
|
||||
if not python_object.ParseFromArray(
|
||||
data + protobuf_offset, <int32_t>protobuf_size):
|
||||
raise ValueError("Protobuf object is corrupted.")
|
||||
inband_data.append(<char*>(data + python_object.inband_data_offset()),
|
||||
<size_t>python_object.inband_data_size())
|
||||
inband_data_offset = python_object.inband_data_offset()
|
||||
inband_data = bufferview[
|
||||
inband_data_offset:
|
||||
inband_data_offset + python_object.inband_data_size()]
|
||||
buffers_segment = data + python_object.raw_buffers_offset()
|
||||
pickled_buffers = []
|
||||
# Now read buffer meta
|
||||
for i in range(python_object.buffer_size()):
|
||||
buffer_meta = <CPythonBuffer *>&python_object.buffer(i)
|
||||
buffer = SubBuffer(buf)
|
||||
buffer = SubBuffer(bufferview)
|
||||
buffer.buf = <void*>(buffers_segment + buffer_meta.address())
|
||||
buffer.len = buffer_meta.length()
|
||||
buffer.itemsize = buffer_meta.itemsize()
|
||||
@@ -207,6 +268,11 @@ cdef class Pickle5Writer:
|
||||
self._curr_buffer_addr = 0
|
||||
self._total_bytes = -1
|
||||
|
||||
def __dealloc__(self):
|
||||
# We must release the buffer, or we could experience memory leaks.
|
||||
for i in range(self.buffers.size()):
|
||||
cpython.PyBuffer_Release(&self.buffers[i])
|
||||
|
||||
def buffer_callback(self, pickle_buffer):
|
||||
cdef:
|
||||
Py_buffer view
|
||||
@@ -240,14 +306,14 @@ cdef class Pickle5Writer:
|
||||
self._curr_buffer_addr += view.len
|
||||
self.buffers.push_back(view)
|
||||
|
||||
def get_total_bytes(self, const c_string &inband):
|
||||
def get_total_bytes(self, const uint8_t[:] inband):
|
||||
cdef:
|
||||
size_t protobuf_bytes = 0
|
||||
uint64_t inband_data_offset = sizeof(int64_t) * 2
|
||||
uint64_t raw_buffers_offset = padded_length_u64(
|
||||
inband_data_offset + inband.length(), kMajorBufferAlign)
|
||||
inband_data_offset + len(inband), kMajorBufferAlign)
|
||||
self.python_object.set_inband_data_offset(inband_data_offset)
|
||||
self.python_object.set_inband_data_size(inband.length())
|
||||
self.python_object.set_inband_data_size(len(inband))
|
||||
self.python_object.set_raw_buffers_offset(raw_buffers_offset)
|
||||
self.python_object.set_raw_buffers_size(self._curr_buffer_addr)
|
||||
# Since calculating the output size is expensive, we will
|
||||
@@ -265,9 +331,11 @@ cdef class Pickle5Writer:
|
||||
self._total_bytes = self._protobuf_offset + protobuf_bytes
|
||||
return self._total_bytes
|
||||
|
||||
cdef void write_to(self, const c_string &inband, shared_ptr[CBuffer] data,
|
||||
int memcopy_threads):
|
||||
cdef uint8_t *ptr = data.get().Data()
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void write_to(self, const uint8_t[:] inband, uint8_t[:] data,
|
||||
int memcopy_threads) nogil:
|
||||
cdef uint8_t *ptr = &data[0]
|
||||
cdef int32_t protobuf_size
|
||||
cdef uint64_t buffer_addr
|
||||
cdef uint64_t buffer_len
|
||||
@@ -284,7 +352,7 @@ cdef class Pickle5Writer:
|
||||
ptr + self._protobuf_offset)
|
||||
# Write inband data.
|
||||
memcpy(ptr + self.python_object.inband_data_offset(),
|
||||
inband.data(), inband.length())
|
||||
&inband[0], len(inband))
|
||||
# Write buffer data.
|
||||
ptr += self.python_object.raw_buffers_offset()
|
||||
for i in range(self.python_object.buffer_size()):
|
||||
@@ -298,5 +366,141 @@ cdef class Pickle5Writer:
|
||||
kMemcopyDefaultBlocksize, memcopy_threads)
|
||||
else:
|
||||
memcpy(ptr + buffer_addr, self.buffers[i].buf, buffer_len)
|
||||
# We must release the buffer, or we could experience memory leaks.
|
||||
cpython.PyBuffer_Release(&self.buffers[i])
|
||||
|
||||
|
||||
cdef class SerializedObject(object):
|
||||
cdef:
|
||||
object _metadata
|
||||
object _contained_object_ids
|
||||
|
||||
def __init__(self, metadata, contained_object_ids=None):
|
||||
self._metadata = metadata
|
||||
self._contained_object_ids = contained_object_ids or []
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
raise NotImplementedError("{}.total_bytes not implemented.".format(
|
||||
type(self).__name__))
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return self._metadata
|
||||
|
||||
@property
|
||||
def contained_object_ids(self):
|
||||
return self._contained_object_ids
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void write_to(self, uint8_t[:] buffer) nogil:
|
||||
raise NotImplementedError("{}.write_to not implemented.".format(
|
||||
type(self).__name__))
|
||||
|
||||
|
||||
cdef class Pickle5SerializedObject(SerializedObject):
|
||||
cdef:
|
||||
const uint8_t[:] inband
|
||||
Pickle5Writer writer
|
||||
object _total_bytes
|
||||
|
||||
def __init__(self, metadata, inband, Pickle5Writer writer,
|
||||
contained_object_ids):
|
||||
super(Pickle5SerializedObject, self).__init__(metadata,
|
||||
contained_object_ids)
|
||||
self.inband = inband
|
||||
self.writer = writer
|
||||
# cached total bytes
|
||||
self._total_bytes = None
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
if self._total_bytes is None:
|
||||
self._total_bytes = self.writer.get_total_bytes(self.inband)
|
||||
return self._total_bytes
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void write_to(self, uint8_t[:] buffer) nogil:
|
||||
self.writer.write_to(self.inband, buffer, MEMCOPY_THREADS)
|
||||
|
||||
|
||||
cdef class MessagePackSerializedObject(SerializedObject):
|
||||
cdef:
|
||||
SerializedObject nest_serialized_object
|
||||
object msgpack_header
|
||||
object msgpack_data
|
||||
int64_t _msgpack_header_bytes
|
||||
int64_t _msgpack_data_bytes
|
||||
int64_t _total_bytes
|
||||
const uint8_t *msgpack_header_ptr
|
||||
const uint8_t *msgpack_data_ptr
|
||||
|
||||
def __init__(self, metadata, msgpack_data,
|
||||
SerializedObject nest_serialized_object=None):
|
||||
if nest_serialized_object:
|
||||
contained_object_ids = nest_serialized_object.contained_object_ids
|
||||
total_bytes = nest_serialized_object.total_bytes
|
||||
else:
|
||||
contained_object_ids = []
|
||||
total_bytes = 0
|
||||
super(MessagePackSerializedObject, self).__init__(metadata,
|
||||
contained_object_ids)
|
||||
self.nest_serialized_object = nest_serialized_object
|
||||
self.msgpack_header = msgpack_header = msgpack.dumps(len(msgpack_data))
|
||||
self.msgpack_data = msgpack_data
|
||||
self._msgpack_header_bytes = len(msgpack_header)
|
||||
self._msgpack_data_bytes = len(msgpack_data)
|
||||
self._total_bytes = (kMessagePackOffset +
|
||||
self._msgpack_data_bytes +
|
||||
total_bytes)
|
||||
self.msgpack_header_ptr = <const uint8_t*>msgpack_header
|
||||
self.msgpack_data_ptr = <const uint8_t*>msgpack_data
|
||||
assert self._msgpack_header_bytes <= kMessagePackOffset
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
return self._total_bytes
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void write_to(self, uint8_t[:] buffer) nogil:
|
||||
cdef uint8_t *ptr = &buffer[0]
|
||||
|
||||
# Write msgpack data first.
|
||||
memcpy(ptr, self.msgpack_header_ptr, self._msgpack_header_bytes)
|
||||
memcpy(ptr + kMessagePackOffset,
|
||||
self.msgpack_data_ptr, self._msgpack_data_bytes)
|
||||
|
||||
if self.nest_serialized_object is not None:
|
||||
self.nest_serialized_object.write_to(
|
||||
buffer[kMessagePackOffset + self._msgpack_data_bytes:])
|
||||
|
||||
|
||||
cdef class RawSerializedObject(SerializedObject):
|
||||
cdef:
|
||||
object value
|
||||
const uint8_t *value_ptr
|
||||
int64_t _total_bytes
|
||||
|
||||
def __init__(self, value):
|
||||
super(RawSerializedObject,
|
||||
self).__init__(ray_constants.OBJECT_METADATA_TYPE_RAW)
|
||||
self.value = value
|
||||
self.value_ptr = <const uint8_t*> value
|
||||
self._total_bytes = len(value)
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
return self._total_bytes
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void write_to(self, uint8_t[:] buffer) nogil:
|
||||
if (MEMCOPY_THREADS > 1 and
|
||||
self._total_bytes > kMemcopyDefaultThreshold):
|
||||
parallel_memcopy(&buffer[0],
|
||||
self.value_ptr,
|
||||
self._total_bytes, kMemcopyDefaultBlocksize,
|
||||
MEMCOPY_THREADS)
|
||||
else:
|
||||
memcpy(&buffer[0], self.value_ptr, self._total_bytes)
|
||||
|
||||
@@ -180,13 +180,12 @@ PROCESS_TYPE_GCS_SERVER = "gcs_server"
|
||||
|
||||
LOG_MONITOR_MAX_OPEN_FILES = 200
|
||||
|
||||
# A constant used as object metadata to indicate the object is raw binary.
|
||||
RAW_BUFFER_METADATA = b"RAW"
|
||||
# A constant used as object metadata to indicate the object is pickled. This
|
||||
# format is only ever used for Python inline task argument values.
|
||||
PICKLE_BUFFER_METADATA = b"PICKLE"
|
||||
# A constant used as object metadata to indicate the object is pickle5 format.
|
||||
PICKLE5_BUFFER_METADATA = b"PICKLE5"
|
||||
# A constant used as object metadata to indicate the object is cross language.
|
||||
OBJECT_METADATA_TYPE_CROSS_LANGUAGE = b"XLANG"
|
||||
# A constant used as object metadata to indicate the object is python specific.
|
||||
OBJECT_METADATA_TYPE_PYTHON = b"PYTHON"
|
||||
# A constant used as object metadata to indicate the object is raw bytes.
|
||||
OBJECT_METADATA_TYPE_RAW = b"RAW"
|
||||
|
||||
AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request"
|
||||
|
||||
|
||||
+81
-67
@@ -15,7 +15,15 @@ from ray.exceptions import (
|
||||
RayWorkerError,
|
||||
UnreconstructableError,
|
||||
)
|
||||
from ray._raylet import Pickle5Writer, unpack_pickle5_buffers
|
||||
from ray._raylet import (
|
||||
split_buffer,
|
||||
unpack_pickle5_buffers,
|
||||
Pickle5Writer,
|
||||
Pickle5SerializedObject,
|
||||
MessagePackSerializer,
|
||||
MessagePackSerializedObject,
|
||||
RawSerializedObject,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,51 +42,6 @@ class DeserializationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SerializedObject:
|
||||
def __init__(self, metadata, contained_object_ids=None):
|
||||
self._metadata = metadata
|
||||
self._contained_object_ids = contained_object_ids or []
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return self._metadata
|
||||
|
||||
@property
|
||||
def contained_object_ids(self):
|
||||
return self._contained_object_ids
|
||||
|
||||
|
||||
class Pickle5SerializedObject(SerializedObject):
|
||||
def __init__(self, metadata, inband, writer, contained_object_ids):
|
||||
super(Pickle5SerializedObject, self).__init__(metadata,
|
||||
contained_object_ids)
|
||||
self.inband = inband
|
||||
self.writer = writer
|
||||
# cached total bytes
|
||||
self._total_bytes = None
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
if self._total_bytes is None:
|
||||
self._total_bytes = self.writer.get_total_bytes(self.inband)
|
||||
return self._total_bytes
|
||||
|
||||
|
||||
class RawSerializedObject(SerializedObject):
|
||||
def __init__(self, value):
|
||||
super(RawSerializedObject,
|
||||
self).__init__(ray_constants.RAW_BUFFER_METADATA)
|
||||
self.value = value
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
return len(self.value)
|
||||
|
||||
|
||||
def _try_to_compute_deterministic_class_id(cls, depth=5):
|
||||
"""Attempt to produce a deterministic class ID for a given class.
|
||||
|
||||
@@ -265,23 +228,51 @@ class SerializationContext:
|
||||
raise DeserializationError()
|
||||
return obj
|
||||
|
||||
def _deserialize_msgpack_data(self, data, metadata):
|
||||
msgpack_data, pickle5_data = split_buffer(data)
|
||||
|
||||
if metadata == ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE:
|
||||
python_objects = []
|
||||
else:
|
||||
python_objects = self._deserialize_pickle5_data(pickle5_data)
|
||||
|
||||
try:
|
||||
|
||||
def _python_deserializer(index):
|
||||
return python_objects[index]
|
||||
|
||||
obj = MessagePackSerializer.loads(msgpack_data,
|
||||
_python_deserializer)
|
||||
except Exception:
|
||||
raise DeserializationError()
|
||||
return obj
|
||||
|
||||
def _deserialize_object(self, data, metadata, object_id):
|
||||
if metadata:
|
||||
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
|
||||
return self._deserialize_pickle5_data(data)
|
||||
if metadata in [
|
||||
ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE,
|
||||
ray_constants.OBJECT_METADATA_TYPE_PYTHON
|
||||
]:
|
||||
return self._deserialize_msgpack_data(data, metadata)
|
||||
# Check if the object should be returned as raw bytes.
|
||||
if metadata == ray_constants.RAW_BUFFER_METADATA:
|
||||
if metadata == ray_constants.OBJECT_METADATA_TYPE_RAW:
|
||||
if data is None:
|
||||
return b""
|
||||
return data.to_pybytes()
|
||||
# Otherwise, return an exception object based on
|
||||
# the error type.
|
||||
error_type = int(metadata)
|
||||
try:
|
||||
error_type = int(metadata)
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"Can't deserialize object: {}, metadata: {}".format(
|
||||
object_id, metadata))
|
||||
|
||||
# RayTaskError is serialized with pickle5 in the data field.
|
||||
# TODO (kfstorm): exception serialization should be language
|
||||
# independent.
|
||||
if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"):
|
||||
obj = self._deserialize_pickle5_data(data)
|
||||
obj = self._deserialize_msgpack_data(data, metadata)
|
||||
assert isinstance(obj, RayTaskError)
|
||||
return obj
|
||||
elif error_type == ErrorType.Value("WORKER_DIED"):
|
||||
@@ -347,6 +338,43 @@ class SerializationContext:
|
||||
|
||||
return results
|
||||
|
||||
def _serialize_to_pickle5(self, metadata, value):
|
||||
writer = Pickle5Writer()
|
||||
# TODO(swang): Check that contained_object_ids is empty.
|
||||
try:
|
||||
self.set_in_band_serialization()
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
except Exception as e:
|
||||
self.get_and_clear_contained_object_ids()
|
||||
raise e
|
||||
finally:
|
||||
self.set_out_of_band_serialization()
|
||||
|
||||
return Pickle5SerializedObject(
|
||||
metadata, inband, writer,
|
||||
self.get_and_clear_contained_object_ids())
|
||||
|
||||
def _serialize_to_msgpack(self, metadata, value):
|
||||
python_objects = []
|
||||
|
||||
def _python_serializer(o):
|
||||
index = len(python_objects)
|
||||
python_objects.append(o)
|
||||
return index
|
||||
|
||||
msgpack_data = MessagePackSerializer.dumps(value, _python_serializer)
|
||||
|
||||
if python_objects:
|
||||
pickle5_serialized_object = \
|
||||
self._serialize_to_pickle5(metadata, python_objects)
|
||||
else:
|
||||
metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE
|
||||
pickle5_serialized_object = None
|
||||
|
||||
return MessagePackSerializedObject(metadata, msgpack_data,
|
||||
pickle5_serialized_object)
|
||||
|
||||
def serialize(self, value):
|
||||
"""Serialize an object.
|
||||
|
||||
@@ -365,23 +393,9 @@ class SerializationContext:
|
||||
metadata = str(ErrorType.Value(
|
||||
"TASK_EXECUTION_EXCEPTION")).encode("ascii")
|
||||
else:
|
||||
metadata = ray_constants.PICKLE5_BUFFER_METADATA
|
||||
metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON
|
||||
|
||||
writer = Pickle5Writer()
|
||||
# TODO(swang): Check that contained_object_ids is empty.
|
||||
try:
|
||||
self.set_in_band_serialization()
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
except Exception as e:
|
||||
self.get_and_clear_contained_object_ids()
|
||||
raise e
|
||||
finally:
|
||||
self.set_out_of_band_serialization()
|
||||
|
||||
return Pickle5SerializedObject(
|
||||
metadata, inband, writer,
|
||||
self.get_and_clear_contained_object_ids())
|
||||
return self._serialize_to_msgpack(metadata, value)
|
||||
|
||||
def register_custom_serializer(self,
|
||||
cls,
|
||||
|
||||
@@ -13,3 +13,13 @@ def test_cross_language_raise_kwargs(shutdown_only):
|
||||
|
||||
with pytest.raises(Exception, match="kwargs"):
|
||||
ray.java_actor_class("a").remote(x="arg1")
|
||||
|
||||
|
||||
def test_cross_language_raise_exception(shutdown_only):
|
||||
ray.init(load_code_from_local=True, include_java=True)
|
||||
|
||||
class PythonObject(object):
|
||||
pass
|
||||
|
||||
with pytest.raises(Exception, match="transfer"):
|
||||
ray.java_function("a", "b").remote(PythonObject())
|
||||
|
||||
+13
-3
@@ -172,9 +172,19 @@ def find_version(*filepath):
|
||||
|
||||
|
||||
requires = [
|
||||
"numpy >= 1.16", "filelock", "jsonschema", "click", "colorama", "pyyaml",
|
||||
"redis >= 3.3.2", "protobuf >= 3.8.0", "py-spy >= 0.2.0", "aiohttp",
|
||||
"google", "grpcio"
|
||||
"aiohttp",
|
||||
"click",
|
||||
"colorama",
|
||||
"filelock",
|
||||
"google",
|
||||
"grpcio",
|
||||
"jsonschema",
|
||||
"msgpack >= 0.6.0, < 1.0.0",
|
||||
"numpy >= 1.16",
|
||||
"protobuf >= 3.8.0",
|
||||
"py-spy >= 0.2.0",
|
||||
"pyyaml",
|
||||
"redis >= 3.3.2",
|
||||
]
|
||||
|
||||
setup(
|
||||
|
||||
Reference in New Issue
Block a user