mirror of
https://github.com/wassname/ray.git
synced 2026-07-06 05:16:30 +08:00
Make numpy arrays immutable (#183)
* Make numpy arrays immutable in numbuf * Move break statement outside of brackets * Simplify test case * Simplify test case
This commit is contained in:
committed by
Philipp Moritz
parent
651aa6007a
commit
cac473b557
@@ -27,8 +27,7 @@ namespace numbuf {
|
||||
num_dims, dim.data(), NPY_##TYPE, reinterpret_cast<void*>(data)); \
|
||||
if (base != Py_None) { PyArray_SetBaseObject((PyArrayObject*)*out, base); } \
|
||||
Py_XINCREF(base); \
|
||||
} \
|
||||
return Status::OK();
|
||||
} break;
|
||||
|
||||
Status DeserializeArray(
|
||||
std::shared_ptr<Array> array, int32_t offset, PyObject* base, PyObject** out) {
|
||||
@@ -57,6 +56,11 @@ Status DeserializeArray(
|
||||
default:
|
||||
DCHECK(false) << "arrow type not recognized: " << content->value_type()->type;
|
||||
}
|
||||
/* Mark the array as immutable. */
|
||||
PyObject* flags = PyObject_GetAttrString(*out, "flags");
|
||||
DCHECK(flags != NULL) << "Could not mark Numpy array immutable";
|
||||
int flag_set = PyObject_SetAttrString(flags, "writeable", Py_False);
|
||||
DCHECK(flag_set == 0) << "Could not mark Numpy array immutable";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
@@ -118,5 +118,13 @@ class SerializationTests(unittest.TestCase):
|
||||
result = numbuf.deserialize_list(array)
|
||||
assert_equal(result[0], obj)
|
||||
|
||||
def testObjectArrayImmutable(self):
|
||||
obj = np.zeros([10])
|
||||
schema, size, serialized = numbuf.serialize_list([obj])
|
||||
result = numbuf.deserialize_list(serialized)
|
||||
assert_equal(result[0], obj)
|
||||
with self.assertRaises(ValueError):
|
||||
result[0][0] = 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user