mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:34:48 +08:00
Fix the numpy ndarray subclass serialization bug (#7392)
This commit is contained in:
committed by
GitHub
parent
48cdca843f
commit
0792b5cb93
@@ -541,8 +541,8 @@ class CloudPickler(Pickler):
|
||||
# This is a patch for python3.5
|
||||
if isinstance(obj, numpy.ndarray):
|
||||
if (self.proto < 5 or
|
||||
(not obj.flags.c_contiguous and
|
||||
not obj.flags.f_contiguous) or
|
||||
(not obj.flags.c_contiguous and not obj.flags.f_contiguous) or
|
||||
(issubclass(type(obj), numpy.ndarray) and type(obj) is not numpy.ndarray) or
|
||||
obj.dtype == "O" or obj.itemsize == 0):
|
||||
return NotImplemented
|
||||
return _numpy_ndarray_reduce(obj)
|
||||
|
||||
@@ -378,6 +378,56 @@ def test_complex_serialization_with_pickle(shutdown_only):
|
||||
complex_serialization(use_pickle=True)
|
||||
|
||||
|
||||
def test_numpy_serialization(ray_start_regular):
|
||||
array = np.zeros(314)
|
||||
from ray.cloudpickle import dumps
|
||||
buffers = []
|
||||
inband = dumps(array, protocol=5, buffer_callback=buffers.append)
|
||||
assert len(inband) < array.nbytes
|
||||
assert len(buffers) == 1
|
||||
|
||||
|
||||
def test_numpy_subclass_serialization(ray_start_regular):
|
||||
class MyNumpyConstant(np.ndarray):
|
||||
def __init__(self, value):
|
||||
super().__init__()
|
||||
self.constant = value
|
||||
|
||||
def __str__(self):
|
||||
print(self.constant)
|
||||
|
||||
constant = MyNumpyConstant(123)
|
||||
|
||||
def explode(x):
|
||||
raise RuntimeError("Expected error.")
|
||||
|
||||
ray.register_custom_serializer(
|
||||
type(constant), serializer=explode, deserializer=explode)
|
||||
|
||||
try:
|
||||
ray.put(constant)
|
||||
assert False, "Should never get here!"
|
||||
except (RuntimeError, IndexError):
|
||||
print("Correct behavior, proof that customer serializer was used.")
|
||||
|
||||
|
||||
def test_numpy_subclass_serialization_pickle(ray_start_regular):
|
||||
class MyNumpyConstant(np.ndarray):
|
||||
def __init__(self, value):
|
||||
super().__init__()
|
||||
self.constant = value
|
||||
|
||||
def __str__(self):
|
||||
print(self.constant)
|
||||
|
||||
constant = MyNumpyConstant(123)
|
||||
ray.register_custom_serializer(type(constant), use_pickle=True)
|
||||
|
||||
repr_orig = repr(constant)
|
||||
repr_ser = repr(ray.get(ray.put(constant)))
|
||||
assert repr_orig == repr_ser
|
||||
|
||||
|
||||
def test_function_descriptor():
|
||||
python_descriptor = ray._raylet.PythonFunctionDescriptor(
|
||||
"module_name", "function_name", "class_name", "function_hash")
|
||||
|
||||
Reference in New Issue
Block a user