mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 18:04:09 +08:00
Merge pull request #152 from amplab/fixnumpyscalar
Fix serialization of numpy scalars, implement more numpy types, empty arrays
This commit is contained in:
+66
-108
@@ -199,6 +199,15 @@ void set_dict_item_and_transfer_ownership(PyObject* dict, PyObject* key, PyObjec
|
||||
|
||||
// Serialization
|
||||
|
||||
#define RAYLIB_SERIALIZE_NPY(TYPE, npy_type, proto_type) \
|
||||
case NPY_##TYPE: { \
|
||||
npy_type* buffer = (npy_type*) PyArray_DATA(array); \
|
||||
for (npy_intp i = 0; i < size; ++i) { \
|
||||
data->add_##proto_type##_data(buffer[i]); \
|
||||
} \
|
||||
} \
|
||||
break;
|
||||
|
||||
// serialize will serialize the python object val into the protocol buffer
|
||||
// object obj, returns 0 if successful and something else if not
|
||||
// FIXME(pcm): This currently only works for contiguous arrays
|
||||
@@ -263,9 +272,17 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
|
||||
Ref* data = obj->mutable_objref_data();
|
||||
data->set_data(objref);
|
||||
objrefs.push_back(objref);
|
||||
} else if (PyArray_Check(val)) {
|
||||
PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) val);
|
||||
} else if (PyArray_Check(val) || PyArray_CheckScalar(val)) { // Python int and float already handled
|
||||
Array* data = obj->mutable_array_data();
|
||||
PyArrayObject* array; // will be deallocated at the end
|
||||
if (PyArray_IsScalar(val, Generic)) {
|
||||
data->set_is_scalar(true);
|
||||
PyArray_Descr* descr = PyArray_DescrFromScalar(val); // new reference
|
||||
array = (PyArrayObject*) PyArray_FromScalar(val, descr); // steals the new reference
|
||||
} else { // val is a numpy array
|
||||
array = PyArray_GETCONTIGUOUS((PyArrayObject*) val);
|
||||
}
|
||||
|
||||
npy_intp size = PyArray_SIZE(array);
|
||||
for (int i = 0; i < PyArray_NDIM(array); ++i) {
|
||||
data->add_shape(PyArray_DIM(array, i));
|
||||
@@ -273,48 +290,16 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
|
||||
int typ = PyArray_TYPE(array);
|
||||
data->set_dtype(typ);
|
||||
switch (typ) {
|
||||
case NPY_FLOAT: {
|
||||
npy_float* buffer = (npy_float*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_float_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_DOUBLE: {
|
||||
npy_double* buffer = (npy_double*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_double_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_INT8: {
|
||||
npy_int8* buffer = (npy_int8*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_int_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_INT64: {
|
||||
npy_int64* buffer = (npy_int64*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_int_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_UINT8: {
|
||||
npy_uint8* buffer = (npy_uint8*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_uint_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_UINT64: {
|
||||
npy_uint64* buffer = (npy_uint64*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_uint_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
RAYLIB_SERIALIZE_NPY(FLOAT, npy_float, float)
|
||||
RAYLIB_SERIALIZE_NPY(DOUBLE, npy_double, double)
|
||||
RAYLIB_SERIALIZE_NPY(INT8, npy_int8, int)
|
||||
RAYLIB_SERIALIZE_NPY(INT16, npy_int16, int)
|
||||
RAYLIB_SERIALIZE_NPY(INT32, npy_int32, int)
|
||||
RAYLIB_SERIALIZE_NPY(INT64, npy_int64, int)
|
||||
RAYLIB_SERIALIZE_NPY(UINT8, npy_uint8, uint)
|
||||
RAYLIB_SERIALIZE_NPY(UINT16, npy_uint16, uint)
|
||||
RAYLIB_SERIALIZE_NPY(UINT32, npy_uint32, uint)
|
||||
RAYLIB_SERIALIZE_NPY(UINT64, npy_uint64, uint)
|
||||
case NPY_OBJECT: { // FIXME(pcm): Support arbitrary python objects, not only objrefs
|
||||
PyArrayIterObject* iter = (PyArrayIterObject*) PyArray_IterNew((PyObject*)array);
|
||||
while (PyArray_ITER_NOTDONE(iter)) {
|
||||
@@ -345,6 +330,16 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
|
||||
return 0;
|
||||
}
|
||||
|
||||
#define RAYLIB_DESERIALIZE_NPY(TYPE, npy_type, proto_type) \
|
||||
case NPY_##TYPE: { \
|
||||
npy_intp size = array.proto_type##_data_size(); \
|
||||
npy_type* buffer = (npy_type*) PyArray_DATA(pyarray); \
|
||||
for (npy_intp i = 0; i < size; ++i) { \
|
||||
buffer[i] = array.proto_type##_data(i); \
|
||||
} \
|
||||
} \
|
||||
break;
|
||||
|
||||
// This method will push all of the object references contained in `obj` to the `objrefs` vector.
|
||||
PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjRef> &objrefs) {
|
||||
if (obj.has_int_data()) {
|
||||
@@ -399,72 +394,35 @@ PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjR
|
||||
dims.push_back(array.shape(i));
|
||||
}
|
||||
PyArrayObject* pyarray = (PyArrayObject*) PyArray_SimpleNew(array.shape_size(), &dims[0], array.dtype());
|
||||
if (array.double_data_size() > 0) { // TODO: handle empty array
|
||||
npy_intp size = array.double_data_size();
|
||||
npy_double* buffer = (npy_double*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.double_data(i);
|
||||
}
|
||||
} else if (array.float_data_size() > 0) {
|
||||
npy_intp size = array.float_data_size();
|
||||
npy_float* buffer = (npy_float*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.float_data(i);
|
||||
}
|
||||
} else if (array.int_data_size() > 0) {
|
||||
npy_intp size = array.int_data_size();
|
||||
switch (array.dtype()) {
|
||||
case NPY_INT8: {
|
||||
npy_int8* buffer = (npy_int8*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.int_data(i);
|
||||
}
|
||||
switch (array.dtype()) {
|
||||
RAYLIB_DESERIALIZE_NPY(FLOAT, npy_float, float)
|
||||
RAYLIB_DESERIALIZE_NPY(DOUBLE, npy_double, double)
|
||||
RAYLIB_DESERIALIZE_NPY(INT8, npy_int8, int)
|
||||
RAYLIB_DESERIALIZE_NPY(INT16, npy_int16, int)
|
||||
RAYLIB_DESERIALIZE_NPY(INT32, npy_int32, int)
|
||||
RAYLIB_DESERIALIZE_NPY(INT64, npy_int64, int)
|
||||
RAYLIB_DESERIALIZE_NPY(UINT8, npy_uint8, uint)
|
||||
RAYLIB_DESERIALIZE_NPY(UINT16, npy_uint16, uint)
|
||||
RAYLIB_DESERIALIZE_NPY(UINT32, npy_uint32, uint)
|
||||
RAYLIB_DESERIALIZE_NPY(UINT64, npy_uint64, uint)
|
||||
case NPY_OBJECT: {
|
||||
npy_intp size = array.objref_data_size();
|
||||
PyObject** buffer = (PyObject**) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = make_pyobjref(worker_capsule, array.objref_data(i));
|
||||
objrefs.push_back(array.objref_data(i));
|
||||
}
|
||||
break;
|
||||
case NPY_INT64: {
|
||||
npy_int64* buffer = (npy_int64*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.int_data(i);
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
|
||||
return NULL;
|
||||
}
|
||||
} else if (array.uint_data_size() > 0) {
|
||||
npy_intp size = array.uint_data_size();
|
||||
switch (array.dtype()) {
|
||||
case NPY_UINT8: {
|
||||
npy_uint8* buffer = (npy_uint8*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.uint_data(i);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_UINT64: {
|
||||
npy_uint64* buffer = (npy_uint64*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.uint_data(i);
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
|
||||
return NULL;
|
||||
}
|
||||
} else if (array.objref_data_size() > 0) {
|
||||
npy_intp size = array.objref_data_size();
|
||||
PyObject** buffer = (PyObject**) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = make_pyobjref(worker_capsule, array.objref_data(i));
|
||||
objrefs.push_back(array.objref_data(i));
|
||||
}
|
||||
} else {
|
||||
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
|
||||
return NULL;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
|
||||
return NULL;
|
||||
}
|
||||
if (array.is_scalar()) {
|
||||
return PyArray_ScalarFromObject((PyObject*) pyarray);
|
||||
} else {
|
||||
return (PyObject*) pyarray;
|
||||
}
|
||||
return (PyObject*) pyarray;
|
||||
} else {
|
||||
PyErr_SetString(RayError, "deserialization: internal error (type not implemented)");
|
||||
return NULL;
|
||||
|
||||
Reference in New Issue
Block a user