diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 976121455..3f92dc015 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -69,6 +69,9 @@ class Worker(object): result = serialization.Str(result) elif isinstance(result, np.ndarray): result = result.view(serialization.NDArray) + elif isinstance(result, np.generic): + return result + # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now) elif result == None: return None # can't subclass None and don't need to because there is a global None # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now) diff --git a/protos/types.proto b/protos/types.proto index 2491a28e3..8de582055 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -95,6 +95,7 @@ message TaskStatus { message Array { repeated uint64 shape = 1; sint64 dtype = 2; + bool is_scalar = 8; repeated double double_data = 3; repeated float float_data = 4; repeated sint64 int_data = 5; diff --git a/src/raylib.cc b/src/raylib.cc index 4c77b5703..3ac9b971a 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -199,6 +199,15 @@ void set_dict_item_and_transfer_ownership(PyObject* dict, PyObject* key, PyObjec // Serialization +#define RAYLIB_SERIALIZE_NPY(TYPE, npy_type, proto_type) \ + case NPY_##TYPE: { \ + npy_type* buffer = (npy_type*) PyArray_DATA(array); \ + for (npy_intp i = 0; i < size; ++i) { \ + data->add_##proto_type##_data(buffer[i]); \ + } \ + } \ + break; + // serialize will serialize the python object val into the protocol buffer // object obj, returns 0 if successful and something else if not // FIXME(pcm): This currently only works for contiguous arrays @@ -263,9 +272,17 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vectormutable_objref_data(); data->set_data(objref); objrefs.push_back(objref); - } else if (PyArray_Check(val)) { - PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) val); + } else if (PyArray_Check(val) || PyArray_CheckScalar(val)) { // Python int and float already handled Array* data = obj->mutable_array_data(); + PyArrayObject* array; // will be deallocated at the end + if (PyArray_IsScalar(val, Generic)) { + data->set_is_scalar(true); + PyArray_Descr* descr = PyArray_DescrFromScalar(val); // new reference + array = (PyArrayObject*) PyArray_FromScalar(val, descr); // steals the new reference + } else { // val is a numpy array + array = PyArray_GETCONTIGUOUS((PyArrayObject*) val); + } + npy_intp size = PyArray_SIZE(array); for (int i = 0; i < PyArray_NDIM(array); ++i) { data->add_shape(PyArray_DIM(array, i)); @@ -273,48 +290,16 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vectorset_dtype(typ); switch (typ) { - case NPY_FLOAT: { - npy_float* buffer = (npy_float*) PyArray_DATA(array); - for (npy_intp i = 0; i < size; ++i) { - data->add_float_data(buffer[i]); - } - } - break; - case NPY_DOUBLE: { - npy_double* buffer = (npy_double*) PyArray_DATA(array); - for (npy_intp i = 0; i < size; ++i) { - data->add_double_data(buffer[i]); - } - } - break; - case NPY_INT8: { - npy_int8* buffer = (npy_int8*) PyArray_DATA(array); - for (npy_intp i = 0; i < size; ++i) { - data->add_int_data(buffer[i]); - } - } - break; - case NPY_INT64: { - npy_int64* buffer = (npy_int64*) PyArray_DATA(array); - for (npy_intp i = 0; i < size; ++i) { - data->add_int_data(buffer[i]); - } - } - break; - case NPY_UINT8: { - npy_uint8* buffer = (npy_uint8*) PyArray_DATA(array); - for (npy_intp i = 0; i < size; ++i) { - data->add_uint_data(buffer[i]); - } - } - break; - case NPY_UINT64: { - npy_uint64* buffer = (npy_uint64*) PyArray_DATA(array); - for (npy_intp i = 0; i < size; ++i) { - data->add_uint_data(buffer[i]); - } - } - break; + RAYLIB_SERIALIZE_NPY(FLOAT, npy_float, float) + RAYLIB_SERIALIZE_NPY(DOUBLE, npy_double, double) + RAYLIB_SERIALIZE_NPY(INT8, npy_int8, int) + RAYLIB_SERIALIZE_NPY(INT16, npy_int16, int) + RAYLIB_SERIALIZE_NPY(INT32, npy_int32, int) + RAYLIB_SERIALIZE_NPY(INT64, npy_int64, int) + RAYLIB_SERIALIZE_NPY(UINT8, npy_uint8, uint) + RAYLIB_SERIALIZE_NPY(UINT16, npy_uint16, uint) + RAYLIB_SERIALIZE_NPY(UINT32, npy_uint32, uint) + RAYLIB_SERIALIZE_NPY(UINT64, npy_uint64, uint) case NPY_OBJECT: { // FIXME(pcm): Support arbitrary python objects, not only objrefs PyArrayIterObject* iter = (PyArrayIterObject*) PyArray_IterNew((PyObject*)array); while (PyArray_ITER_NOTDONE(iter)) { @@ -345,6 +330,16 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector &objrefs) { if (obj.has_int_data()) { @@ -399,72 +394,35 @@ PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector 0) { // TODO: handle empty array - npy_intp size = array.double_data_size(); - npy_double* buffer = (npy_double*) PyArray_DATA(pyarray); - for (npy_intp i = 0; i < size; ++i) { - buffer[i] = array.double_data(i); - } - } else if (array.float_data_size() > 0) { - npy_intp size = array.float_data_size(); - npy_float* buffer = (npy_float*) PyArray_DATA(pyarray); - for (npy_intp i = 0; i < size; ++i) { - buffer[i] = array.float_data(i); - } - } else if (array.int_data_size() > 0) { - npy_intp size = array.int_data_size(); - switch (array.dtype()) { - case NPY_INT8: { - npy_int8* buffer = (npy_int8*) PyArray_DATA(pyarray); - for (npy_intp i = 0; i < size; ++i) { - buffer[i] = array.int_data(i); - } + switch (array.dtype()) { + RAYLIB_DESERIALIZE_NPY(FLOAT, npy_float, float) + RAYLIB_DESERIALIZE_NPY(DOUBLE, npy_double, double) + RAYLIB_DESERIALIZE_NPY(INT8, npy_int8, int) + RAYLIB_DESERIALIZE_NPY(INT16, npy_int16, int) + RAYLIB_DESERIALIZE_NPY(INT32, npy_int32, int) + RAYLIB_DESERIALIZE_NPY(INT64, npy_int64, int) + RAYLIB_DESERIALIZE_NPY(UINT8, npy_uint8, uint) + RAYLIB_DESERIALIZE_NPY(UINT16, npy_uint16, uint) + RAYLIB_DESERIALIZE_NPY(UINT32, npy_uint32, uint) + RAYLIB_DESERIALIZE_NPY(UINT64, npy_uint64, uint) + case NPY_OBJECT: { + npy_intp size = array.objref_data_size(); + PyObject** buffer = (PyObject**) PyArray_DATA(pyarray); + for (npy_intp i = 0; i < size; ++i) { + buffer[i] = make_pyobjref(worker_capsule, array.objref_data(i)); + objrefs.push_back(array.objref_data(i)); } - break; - case NPY_INT64: { - npy_int64* buffer = (npy_int64*) PyArray_DATA(pyarray); - for (npy_intp i = 0; i < size; ++i) { - buffer[i] = array.int_data(i); - } - } - break; - default: - PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)"); - return NULL; - } - } else if (array.uint_data_size() > 0) { - npy_intp size = array.uint_data_size(); - switch (array.dtype()) { - case NPY_UINT8: { - npy_uint8* buffer = (npy_uint8*) PyArray_DATA(pyarray); - for (npy_intp i = 0; i < size; ++i) { - buffer[i] = array.uint_data(i); - } - } - break; - case NPY_UINT64: { - npy_uint64* buffer = (npy_uint64*) PyArray_DATA(pyarray); - for (npy_intp i = 0; i < size; ++i) { - buffer[i] = array.uint_data(i); - } - } - break; - default: - PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)"); - return NULL; - } - } else if (array.objref_data_size() > 0) { - npy_intp size = array.objref_data_size(); - PyObject** buffer = (PyObject**) PyArray_DATA(pyarray); - for (npy_intp i = 0; i < size; ++i) { - buffer[i] = make_pyobjref(worker_capsule, array.objref_data(i)); - objrefs.push_back(array.objref_data(i)); - } - } else { - PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)"); - return NULL; + } + break; + default: + PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)"); + return NULL; + } + if (array.is_scalar()) { + return PyArray_ScalarFromObject((PyObject*) pyarray); + } else { + return (PyObject*) pyarray; } - return (PyObject*) pyarray; } else { PyErr_SetString(RayError, "deserialization: internal error (type not implemented)"); return NULL; diff --git a/test/runtest.py b/test/runtest.py index e9c7542a7..5a3f82c0d 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -16,7 +16,10 @@ RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, "hello world", 42.0, (1.0, "hi"), None, (None, None), ("hello", None), True, False, (True, False), {True: "hello", False: "world"}, - {"hello" : "world", 1: 42, 1.0: 45}, {}] + {"hello" : "world", 1: 42, 1.0: 45}, {}, + np.int8(3), np.int32(4), np.int64(5), + np.uint8(3), np.uint32(4), np.uint64(5), + np.float32(1.0), np.float64(1.0)] class UserDefinedType(object): def __init__(self): @@ -41,6 +44,16 @@ class SerializationTest(unittest.TestCase): c = serialization.deserialize(worker.handle, b) self.assertTrue((a == c).all()) + a = np.array(0).astype(typ) + b, _ = serialization.serialize(worker.handle, a) + c = serialization.deserialize(worker.handle, b) + self.assertTrue((a == c).all()) + + a = np.empty((0,)).astype(typ) + b, _ = serialization.serialize(worker.handle, a) + c = serialization.deserialize(worker.handle, b) + self.assertTrue(a.dtype == c.dtype) + def testSerialize(self): [w] = services.start_singlenode_cluster(return_drivers=True) @@ -54,8 +67,10 @@ class SerializationTest(unittest.TestCase): self.numpyTypeTest(w, 'int8') self.numpyTypeTest(w, 'uint8') - # self.numpyTypeTest('int16') # TODO(pcm): implement this - # self.numpyTypeTest('int32') # TODO(pcm): implement this + self.numpyTypeTest(w, 'int16') + self.numpyTypeTest(w, 'uint16') + self.numpyTypeTest(w, 'int32') + self.numpyTypeTest(w, 'uint32') self.numpyTypeTest(w, 'float32') self.numpyTypeTest(w, 'float64') @@ -311,7 +326,7 @@ class ReferenceCountingTest(unittest.TestCase): objref_val = check_get_deallocated(val) self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], -1) - if not isinstance(val, bool) and val is not None: + if not isinstance(val, bool) and not isinstance(val, np.generic) and val is not None: x, objref_val = check_get_not_deallocated(val) self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], 1)