diff --git a/python/ray/cloudpickle/cloudpickle_fast.py b/python/ray/cloudpickle/cloudpickle_fast.py index 75fd89e6b..ccfee6c7a 100644 --- a/python/ray/cloudpickle/cloudpickle_fast.py +++ b/python/ray/cloudpickle/cloudpickle_fast.py @@ -541,8 +541,8 @@ class CloudPickler(Pickler): # This is a patch for python3.5 if isinstance(obj, numpy.ndarray): if (self.proto < 5 or - (not obj.flags.c_contiguous and - not obj.flags.f_contiguous) or + (not obj.flags.c_contiguous and not obj.flags.f_contiguous) or + (issubclass(type(obj), numpy.ndarray) and type(obj) is not numpy.ndarray) or obj.dtype == "O" or obj.itemsize == 0): return NotImplemented return _numpy_ndarray_reduce(obj) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 45d6ae9e6..c6ecd4bbf 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -378,6 +378,56 @@ def test_complex_serialization_with_pickle(shutdown_only): complex_serialization(use_pickle=True) +def test_numpy_serialization(ray_start_regular): + array = np.zeros(314) + from ray.cloudpickle import dumps + buffers = [] + inband = dumps(array, protocol=5, buffer_callback=buffers.append) + assert len(inband) < array.nbytes + assert len(buffers) == 1 + + +def test_numpy_subclass_serialization(ray_start_regular): + class MyNumpyConstant(np.ndarray): + def __init__(self, value): + super().__init__() + self.constant = value + + def __str__(self): + print(self.constant) + + constant = MyNumpyConstant(123) + + def explode(x): + raise RuntimeError("Expected error.") + + ray.register_custom_serializer( + type(constant), serializer=explode, deserializer=explode) + + try: + ray.put(constant) + assert False, "Should never get here!" + except (RuntimeError, IndexError): + print("Correct behavior, proof that customer serializer was used.") + + +def test_numpy_subclass_serialization_pickle(ray_start_regular): + class MyNumpyConstant(np.ndarray): + def __init__(self, value): + super().__init__() + self.constant = value + + def __str__(self): + print(self.constant) + + constant = MyNumpyConstant(123) + ray.register_custom_serializer(type(constant), use_pickle=True) + + repr_orig = repr(constant) + repr_ser = repr(ray.get(ray.put(constant))) + assert repr_orig == repr_ser + + def test_function_descriptor(): python_descriptor = ray._raylet.PythonFunctionDescriptor( "module_name", "function_name", "class_name", "function_hash")