Update arrow to include custom serializer for pytorch and register default serialization handlers. (#1152)

* Update arrow to include custom serializer for pytorch.

* Call pyarrow function for registering default custom serialization handlers.

* Change class ID used in serialization context for object IDs.
This commit is contained in:
Robert Nishihara
2017-10-21 21:24:10 -07:00
committed by Philipp Moritz
parent 684e62e784
commit 97c6369b49
2 changed files with 4 additions and 77 deletions
+3 -76
View File
@@ -1009,6 +1009,8 @@ def _initialize_serialization(worker=global_worker):
serialize several exception classes that we define for error handling.
"""
worker.serialization_context = pyarrow.SerializationContext()
pyarrow.register_default_serialization_handlers(
worker.serialization_context)
# Define a custom serializer and deserializer for handling Object IDs.
def objectid_custom_serializer(obj):
@@ -1018,85 +1020,10 @@ def _initialize_serialization(worker=global_worker):
return ray.local_scheduler.ObjectID(serialized_obj)
worker.serialization_context.register_type(
ray.local_scheduler.ObjectID, 20 * b"\x00", pickle=False,
ray.local_scheduler.ObjectID, "ray.ObjectID", pickle=False,
custom_serializer=objectid_custom_serializer,
custom_deserializer=objectid_custom_deserializer)
# Define a custom serializer and deserializer for handling numpy arrays
# that contain objects.
def array_custom_serializer(obj):
return obj.tolist(), obj.dtype.str
def array_custom_deserializer(serialized_obj):
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
worker.serialization_context.register_type(
np.ndarray, 20 * b"\x01", pickle=False,
custom_serializer=array_custom_serializer,
custom_deserializer=array_custom_deserializer)
def ordered_dict_custom_serializer(obj):
return list(obj.keys()), list(obj.values())
def ordered_dict_custom_deserializer(obj):
return collections.OrderedDict(zip(obj[0], obj[1]))
worker.serialization_context.register_type(
collections.OrderedDict, 20 * b"\x02", pickle=False,
custom_serializer=ordered_dict_custom_serializer,
custom_deserializer=ordered_dict_custom_deserializer)
def default_dict_custom_serializer(obj):
return list(obj.keys()), list(obj.values()), obj.default_factory
def default_dict_custom_deserializer(obj):
return collections.defaultdict(obj[2], zip(obj[0], obj[1]))
worker.serialization_context.register_type(
collections.defaultdict, 20 * b"\x03", pickle=False,
custom_serializer=default_dict_custom_serializer,
custom_deserializer=default_dict_custom_deserializer)
def _serialize_pandas_series(s):
import pandas as pd
# TODO: serializing Series without extra copy
serialized = pyarrow.serialize_pandas(pd.DataFrame({s.name: s}))
return {
'type': 'Series',
'data': serialized.to_pybytes()
}
def _serialize_pandas_dataframe(df):
return {
'type': 'DataFrame',
'data': pyarrow.serialize_pandas(df).to_pybytes()
}
def _deserialize_callback_pandas(data):
deserialized = pyarrow.deserialize_pandas(data['data'])
type_ = data['type']
if type_ == 'Series':
return deserialized[deserialized.columns[0]]
elif type_ == 'DataFrame':
return deserialized
else:
raise ValueError(type_)
try:
import pandas as pd
worker.serialization_context.register_type(
pd.Series, 'pandas.Series',
custom_serializer=_serialize_pandas_series,
custom_deserializer=_deserialize_callback_pandas)
worker.serialization_context.register_type(
pd.DataFrame, 'pandas.DataFrame',
custom_serializer=_serialize_pandas_dataframe,
custom_deserializer=_deserialize_callback_pandas)
except ImportError:
# no pandas
pass
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
# These should only be called on the driver because _register_class
# will export the class to all of the workers.