From 97c6369b495e1e84b4a8511e92bfd7709745eec9 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sat, 21 Oct 2017 21:24:10 -0700 Subject: [PATCH] Update arrow to include custom serializer for pytorch and register default serialization handlers. (#1152) * Update arrow to include custom serializer for pytorch. * Call pyarrow function for registering default custom serialization handlers. * Change class ID used in serialization context for object IDs. --- python/ray/worker.py | 79 +-------------------------- src/thirdparty/download_thirdparty.sh | 2 +- 2 files changed, 4 insertions(+), 77 deletions(-) diff --git a/python/ray/worker.py b/python/ray/worker.py index cf2ad083d..a2f7f4878 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1009,6 +1009,8 @@ def _initialize_serialization(worker=global_worker): serialize several exception classes that we define for error handling. """ worker.serialization_context = pyarrow.SerializationContext() + pyarrow.register_default_serialization_handlers( + worker.serialization_context) # Define a custom serializer and deserializer for handling Object IDs. def objectid_custom_serializer(obj): @@ -1018,85 +1020,10 @@ def _initialize_serialization(worker=global_worker): return ray.local_scheduler.ObjectID(serialized_obj) worker.serialization_context.register_type( - ray.local_scheduler.ObjectID, 20 * b"\x00", pickle=False, + ray.local_scheduler.ObjectID, "ray.ObjectID", pickle=False, custom_serializer=objectid_custom_serializer, custom_deserializer=objectid_custom_deserializer) - # Define a custom serializer and deserializer for handling numpy arrays - # that contain objects. - def array_custom_serializer(obj): - return obj.tolist(), obj.dtype.str - - def array_custom_deserializer(serialized_obj): - return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1])) - - worker.serialization_context.register_type( - np.ndarray, 20 * b"\x01", pickle=False, - custom_serializer=array_custom_serializer, - custom_deserializer=array_custom_deserializer) - - def ordered_dict_custom_serializer(obj): - return list(obj.keys()), list(obj.values()) - - def ordered_dict_custom_deserializer(obj): - return collections.OrderedDict(zip(obj[0], obj[1])) - - worker.serialization_context.register_type( - collections.OrderedDict, 20 * b"\x02", pickle=False, - custom_serializer=ordered_dict_custom_serializer, - custom_deserializer=ordered_dict_custom_deserializer) - - def default_dict_custom_serializer(obj): - return list(obj.keys()), list(obj.values()), obj.default_factory - - def default_dict_custom_deserializer(obj): - return collections.defaultdict(obj[2], zip(obj[0], obj[1])) - - worker.serialization_context.register_type( - collections.defaultdict, 20 * b"\x03", pickle=False, - custom_serializer=default_dict_custom_serializer, - custom_deserializer=default_dict_custom_deserializer) - - def _serialize_pandas_series(s): - import pandas as pd - # TODO: serializing Series without extra copy - serialized = pyarrow.serialize_pandas(pd.DataFrame({s.name: s})) - return { - 'type': 'Series', - 'data': serialized.to_pybytes() - } - - def _serialize_pandas_dataframe(df): - return { - 'type': 'DataFrame', - 'data': pyarrow.serialize_pandas(df).to_pybytes() - } - - def _deserialize_callback_pandas(data): - deserialized = pyarrow.deserialize_pandas(data['data']) - type_ = data['type'] - if type_ == 'Series': - return deserialized[deserialized.columns[0]] - elif type_ == 'DataFrame': - return deserialized - else: - raise ValueError(type_) - - try: - import pandas as pd - worker.serialization_context.register_type( - pd.Series, 'pandas.Series', - custom_serializer=_serialize_pandas_series, - custom_deserializer=_deserialize_callback_pandas) - - worker.serialization_context.register_type( - pd.DataFrame, 'pandas.DataFrame', - custom_serializer=_serialize_pandas_dataframe, - custom_deserializer=_deserialize_callback_pandas) - except ImportError: - # no pandas - pass - if worker.mode in [SCRIPT_MODE, SILENT_MODE]: # These should only be called on the driver because _register_class # will export the class to all of the workers. diff --git a/src/thirdparty/download_thirdparty.sh b/src/thirdparty/download_thirdparty.sh index 8c2086b9d..8113d255b 100755 --- a/src/thirdparty/download_thirdparty.sh +++ b/src/thirdparty/download_thirdparty.sh @@ -13,4 +13,4 @@ fi cd $TP_DIR/arrow git fetch origin master -git checkout a8f518588fda471b2e3cc8e0f0064e7c4bb99899 +git checkout 05788d035f4aa918d80c9db7a1bf74fe38309c60