mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 07:44:45 +08:00
Serialize numpy arrays with custom objects
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#include "numpy.h"
|
||||
#include "python.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
@@ -6,6 +7,11 @@
|
||||
|
||||
using namespace arrow;
|
||||
|
||||
extern "C" {
|
||||
extern PyObject *numbuf_serialize_callback;
|
||||
extern PyObject *numbuf_deserialize_callback;
|
||||
}
|
||||
|
||||
namespace numbuf {
|
||||
|
||||
#define ARROW_TYPE_TO_NUMPY_CASE(TYPE) \
|
||||
@@ -52,7 +58,8 @@ Status DeserializeArray(std::shared_ptr<Array> array, int32_t offset, PyObject**
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder) {
|
||||
Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder,
|
||||
std::vector<PyObject*>& subdicts) {
|
||||
size_t ndim = PyArray_NDIM(array);
|
||||
int dtype = PyArray_TYPE(array);
|
||||
std::vector<int64_t> dims(ndim);
|
||||
@@ -96,6 +103,23 @@ Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder) {
|
||||
case NPY_DOUBLE:
|
||||
RETURN_NOT_OK(builder.AppendTensor(dims, reinterpret_cast<double*>(data)));
|
||||
break;
|
||||
case NPY_OBJECT:
|
||||
if (!numbuf_serialize_callback) {
|
||||
std::stringstream stream;
|
||||
stream << "numpy data type not recognized: " << dtype;
|
||||
return Status::NotImplemented(stream.str());
|
||||
} else {
|
||||
PyObject* arglist = Py_BuildValue("(O)", array);
|
||||
PyObject* result = PyObject_CallObject(numbuf_serialize_callback, arglist);
|
||||
if (!result) {
|
||||
Py_XDECREF(arglist);
|
||||
return python_error_to_status();
|
||||
}
|
||||
builder.AppendDict(PyDict_Size(result));
|
||||
subdicts.push_back(result);
|
||||
Py_XDECREF(arglist);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
std::stringstream stream;
|
||||
stream << "numpy data type not recognized: " << dtype;
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
namespace numbuf {
|
||||
|
||||
arrow::Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder);
|
||||
arrow::Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder, std::vector<PyObject*>& subdicts);
|
||||
arrow::Status DeserializeArray(std::shared_ptr<arrow::Array> array, int32_t offset, PyObject** out);
|
||||
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ Status append(PyObject* elem, SequenceBuilder& builder,
|
||||
} else if (PyArray_IsScalar(elem, Generic)) {
|
||||
RETURN_NOT_OK(AppendScalar(elem, builder));
|
||||
} else if (PyArray_Check(elem)) {
|
||||
RETURN_NOT_OK(SerializeArray((PyArrayObject*) elem, builder));
|
||||
RETURN_NOT_OK(SerializeArray((PyArrayObject*) elem, builder, subdicts));
|
||||
} else if (elem == Py_None) {
|
||||
RETURN_NOT_OK(builder.AppendNone());
|
||||
} else {
|
||||
|
||||
@@ -17,6 +17,8 @@ arrow::Status DeserializeList(std::shared_ptr<arrow::Array> array, int32_t start
|
||||
arrow::Status DeserializeTuple(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
|
||||
arrow::Status DeserializeDict(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
|
||||
|
||||
arrow::Status python_error_to_status();
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user