mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 14:32:01 +08:00
new arrow serialization code (serialize python objects recursively) (#284)
This commit is contained in:
committed by
Robert Nishihara
parent
5ff00e0e81
commit
4a0f35b042
@@ -15,6 +15,7 @@ import serialization
|
||||
import ray.internal.graph_pb2
|
||||
import ray.graph
|
||||
import services
|
||||
import libnumbuf
|
||||
|
||||
class RayFailedObject(object):
|
||||
"""An object used internally to represent a task that threw an exception.
|
||||
@@ -286,9 +287,28 @@ class Worker(object):
|
||||
objref (ray.ObjRef): The object reference of the value to be put.
|
||||
value (serializable object): The value to put in the object store.
|
||||
"""
|
||||
if serialization.is_arrow_serializable(value):
|
||||
ray.lib.put_arrow(self.handle, objref, value)
|
||||
else:
|
||||
try:
|
||||
# We put the value into a list here because in arrow the concept of
|
||||
# "serializing a single object" does not exits.
|
||||
schema, size, serialized = libnumbuf.serialize_list([value])
|
||||
# TODO(pcm): Right now, metadata is serialized twice, change that in the future
|
||||
# in the following line, the "8" is for storing the metadata size,
|
||||
# the len(schema) is for storing the metadata and the 4096 is for storing
|
||||
# the metadata in the batch (see INITIAL_METADATA_SIZE in arrow)
|
||||
size = size + 8 + len(schema) + 4096
|
||||
buff, segmentid = ray.lib.allocate_buffer(self.handle, objref, size)
|
||||
# write the metadata length
|
||||
np.frombuffer(buff, dtype="int64", count=1)[0] = len(schema)
|
||||
# metadata buffer
|
||||
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=len(schema))
|
||||
# write the metadata
|
||||
metadata[:] = schema
|
||||
data = np.frombuffer(buff, dtype="byte")[8 + len(schema):]
|
||||
metadata_offset = libnumbuf.write_to_buffer(serialized, memoryview(data))
|
||||
ray.lib.finish_buffer(self.handle, objref, segmentid, metadata_offset)
|
||||
except:
|
||||
# At the moment, custom object and objects that contain object references take this path
|
||||
# TODO(pcm): Make sure that these are the only objects getting serialized to protobuf
|
||||
object_capsule, contained_objrefs = serialization.serialize(self.handle, value) # contained_objrefs is a list of the objrefs contained in object_capsule
|
||||
ray.lib.put_object(self.handle, objref, object_capsule, contained_objrefs)
|
||||
|
||||
@@ -302,10 +322,22 @@ class Worker(object):
|
||||
objref (ray.ObjRef): The object reference of the value to retrieve.
|
||||
"""
|
||||
if ray.lib.is_arrow(self.handle, objref):
|
||||
result, segmentid = ray.lib.get_arrow(self.handle, objref)
|
||||
## this is the new codepath
|
||||
buff, segmentid, metadata_offset = ray.libraylib.get_buffer(self.handle, objref)
|
||||
metadata_size = np.frombuffer(buff, dtype="int64", count=1)[0]
|
||||
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size)
|
||||
data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:]
|
||||
serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset)
|
||||
deserialized = libnumbuf.deserialize_list(serialized)
|
||||
# Unwrap the object from the list (it was wrapped put_object)
|
||||
assert len(deserialized) == 1
|
||||
result = deserialized[0]
|
||||
## this is the old codepath
|
||||
# result, segmentid = ray.lib.get_arrow(self.handle, objref)
|
||||
else:
|
||||
object_capsule, segmentid = ray.lib.get_object(self.handle, objref)
|
||||
result = serialization.deserialize(self.handle, object_capsule)
|
||||
|
||||
if isinstance(result, int):
|
||||
result = serialization.Int(result)
|
||||
elif isinstance(result, long):
|
||||
|
||||
Reference in New Issue
Block a user