Fix the numpy ndarray subclass serialization bug (#7392)

This commit is contained in:
Siyuan (Ryans) Zhuang
2020-03-01 23:05:59 -08:00
committed by GitHub
parent 48cdca843f
commit 0792b5cb93
2 changed files with 52 additions and 2 deletions
+2 -2
View File
@@ -541,8 +541,8 @@ class CloudPickler(Pickler):
# This is a patch for python3.5
if isinstance(obj, numpy.ndarray):
if (self.proto < 5 or
(not obj.flags.c_contiguous and
not obj.flags.f_contiguous) or
(not obj.flags.c_contiguous and not obj.flags.f_contiguous) or
(issubclass(type(obj), numpy.ndarray) and type(obj) is not numpy.ndarray) or
obj.dtype == "O" or obj.itemsize == 0):
return NotImplemented
return _numpy_ndarray_reduce(obj)
+50
View File
@@ -378,6 +378,56 @@ def test_complex_serialization_with_pickle(shutdown_only):
complex_serialization(use_pickle=True)
def test_numpy_serialization(ray_start_regular):
    """Check that cloudpickle hands a numpy array's data to the buffer callback.

    The array is dumped with pickle protocol 5 and a ``buffer_callback``;
    the test then verifies that exactly one buffer was handed to the
    callback and that the in-band byte stream is smaller than the raw
    array data.
    """
    from ray.cloudpickle import dumps

    out_of_band = []
    zeros = np.zeros(314)
    payload = dumps(zeros, protocol=5, buffer_callback=out_of_band.append)

    # The pickled stream itself must not contain the full array payload.
    assert zeros.nbytes > len(payload)
    # The callback must have been invoked exactly once.
    assert len(out_of_band) == 1
def test_numpy_subclass_serialization(ray_start_regular):
    """Check that a registered custom serializer is used for ndarray subclasses.

    A serializer/deserializer pair that always raises ``RuntimeError`` is
    registered for a ``np.ndarray`` subclass. ``ray.put`` on an instance
    is then expected to raise, which shows the custom serializer was
    invoked instead of the object being serialized some other way.
    """

    class MyNumpyConstant(np.ndarray):
        """ndarray subclass that stores an extra ``constant`` attribute."""

        def __init__(self, value):
            super().__init__()
            self.constant = value

        def __str__(self):
            # ``__str__`` must return a string; printing the value and
            # implicitly returning None would make ``str(obj)`` raise
            # TypeError.
            return str(self.constant)

    constant = MyNumpyConstant(123)

    def explode(x):
        raise RuntimeError("Expected error.")

    ray.register_custom_serializer(
        type(constant), serializer=explode, deserializer=explode)

    try:
        ray.put(constant)
    except (RuntimeError, IndexError):
        print("Correct behavior, proof that customer serializer was used.")
    else:
        # Raised explicitly rather than via ``assert False`` so the failure
        # is still reported when assertions are disabled (``python -O``).
        raise AssertionError("Should never get here!")
def test_numpy_subclass_serialization_pickle(ray_start_regular):
    """Round-trip an ndarray subclass through ray using pickle serialization.

    The subclass is registered with ``use_pickle=True``; the ``repr`` of the
    object returned by ``ray.get(ray.put(...))`` must equal the ``repr`` of
    the original instance.
    """

    class MyNumpyConstant(np.ndarray):
        """ndarray subclass that stores an extra ``constant`` attribute."""

        def __init__(self, value):
            super().__init__()
            self.constant = value

        def __str__(self):
            # ``__str__`` must return a string; printing the value and
            # implicitly returning None would make ``str(obj)`` raise
            # TypeError.
            return str(self.constant)

    constant = MyNumpyConstant(123)
    ray.register_custom_serializer(type(constant), use_pickle=True)

    repr_orig = repr(constant)
    repr_ser = repr(ray.get(ray.put(constant)))
    assert repr_orig == repr_ser
def test_function_descriptor():
python_descriptor = ray._raylet.PythonFunctionDescriptor(
"module_name", "function_name", "class_name", "function_hash")