diff --git a/python/ray/worker.py b/python/ray/worker.py index ee7d6b472..7aa1843c6 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1053,6 +1053,28 @@ def _initialize_serialization(worker=global_worker): custom_serializer=array_custom_serializer, custom_deserializer=array_custom_deserializer) + def ordered_dict_custom_serializer(obj): + return list(obj.keys()), list(obj.values()) + + def ordered_dict_custom_deserializer(obj): + return collections.OrderedDict(zip(obj[0], obj[1])) + + worker.serialization_context.register_type( + collections.OrderedDict, 20 * b"\x02", pickle=False, + custom_serializer=ordered_dict_custom_serializer, + custom_deserializer=ordered_dict_custom_deserializer) + + def default_dict_custom_serializer(obj): + return list(obj.keys()), list(obj.values()), obj.default_factory + + def default_dict_custom_deserializer(obj): + return collections.defaultdict(obj[2], zip(obj[0], obj[1])) + + worker.serialization_context.register_type( + collections.defaultdict, 20 * b"\x03", pickle=False, + custom_serializer=default_dict_custom_serializer, + custom_deserializer=default_dict_custom_deserializer) + if worker.mode in [SCRIPT_MODE, SILENT_MODE]: # These should only be called on the driver because _register_class # will export the class to all of the workers. diff --git a/src/thirdparty/download_thirdparty.sh b/src/thirdparty/download_thirdparty.sh index 2fc086de0..385343bb2 100755 --- a/src/thirdparty/download_thirdparty.sh +++ b/src/thirdparty/download_thirdparty.sh @@ -13,4 +13,4 @@ fi cd $TP_DIR/arrow git fetch origin master -git checkout 84e5e02fbf412c979387b0a53b0ad0c6d5c5e790 +git checkout 49e02d27227332b06528816bbf73e434a4e1ebcb diff --git a/test/runtest.py b/test/runtest.py index 8502548f2..a340f80bc 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -7,7 +7,7 @@ import string import sys import time import unittest -from collections import defaultdict, namedtuple +from collections import defaultdict, namedtuple, OrderedDict import numpy as np @@ -354,10 +354,18 @@ class APITest(unittest.TestCase): ray.get(ray.put(TempClass())) - # Note that the below actually returns a dictionary and not a - # defaultdict. This is a bug - # (https://github.com/ray-project/ray/issues/512). - ray.get(ray.put(defaultdict(lambda: 0))) + # Test subtypes of dictionaries. + value_before = OrderedDict([("hello", 1), ("world", 2)]) + object_id = ray.put(value_before) + self.assertEqual(value_before, ray.get(object_id)) + + value_before = defaultdict(lambda: 0, [("hello", 1), ("world", 2)]) + object_id = ray.put(value_before) + self.assertEqual(value_before, ray.get(object_id)) + + value_before = defaultdict(lambda: [], [("hello", 1), ("world", 2)]) + object_id = ray.put(value_before) + self.assertEqual(value_before, ray.get(object_id)) # Test passing custom classes into remote functions from the driver. @ray.remote