mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 19:14:35 +08:00
throw proper error if numpy array that contains object is serialized
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
#include "numpy.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include <numbuf/tensor.h>
|
||||
|
||||
using namespace arrow;
|
||||
@@ -95,7 +97,9 @@ Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder) {
|
||||
RETURN_NOT_OK(builder.AppendTensor(dims, reinterpret_cast<double*>(data)));
|
||||
break;
|
||||
default:
|
||||
DCHECK(false) << "numpy data type not recognized: " << dtype;
|
||||
std::stringstream stream;
|
||||
stream << "numpy data type not recognized: " << dtype;
|
||||
return Status::NotImplemented(stream.str());
|
||||
}
|
||||
Py_XDECREF(contiguous);
|
||||
return Status::OK();
|
||||
|
||||
@@ -47,6 +47,13 @@ class SerializationTests(unittest.TestCase):
|
||||
for t in ["int8", "uint8", "int16", "uint16", "int32", "uint32", "float32", "float64"]:
|
||||
self.numpyTest(t)
|
||||
|
||||
def testNumpyObject(self):
|
||||
a = np.array([np.zeros((2,2))], dtype=object)
|
||||
try:
|
||||
x = self.roundTripTest([a])
|
||||
except:
|
||||
pass
|
||||
|
||||
def testRay(self):
|
||||
for obj in TEST_OBJECTS:
|
||||
self.roundTripTest([obj])
|
||||
|
||||
Reference in New Issue
Block a user