mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 03:30:12 +08:00
add custom callbacks for serialization
This commit is contained in:
@@ -6,6 +6,9 @@
|
||||
|
||||
using namespace arrow;
|
||||
|
||||
extern PyObject* numbuf_serialize_callback;
|
||||
extern PyObject* numbuf_deserialize_callback;
|
||||
|
||||
namespace numbuf {
|
||||
|
||||
PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) {
|
||||
@@ -49,6 +52,17 @@ PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Status python_error_to_status() {
|
||||
PyObject *type, *value, *traceback;
|
||||
PyErr_Fetch(&type, &value, &traceback);
|
||||
char *err_message = PyString_AsString(value);
|
||||
std::stringstream ss;
|
||||
if (err_message) {
|
||||
ss << "Python error in callback: " << err_message;
|
||||
}
|
||||
return Status::NotImplemented(ss.str());
|
||||
}
|
||||
|
||||
Status append(PyObject* elem, SequenceBuilder& builder,
|
||||
std::vector<PyObject*>& sublists,
|
||||
std::vector<PyObject*>& subtuples,
|
||||
@@ -99,10 +113,22 @@ Status append(PyObject* elem, SequenceBuilder& builder,
|
||||
} else if (elem == Py_None) {
|
||||
RETURN_NOT_OK(builder.AppendNone());
|
||||
} else {
|
||||
std::stringstream ss;
|
||||
ss << "data type of " << PyString_AS_STRING(PyObject_Repr(elem))
|
||||
<< " not recognized";
|
||||
return Status::NotImplemented(ss.str());
|
||||
if (!numbuf_serialize_callback) {
|
||||
std::stringstream ss;
|
||||
ss << "data type of " << PyString_AS_STRING(PyObject_Repr(elem))
|
||||
<< " not recognized and custom serialization handler not registered";
|
||||
return Status::NotImplemented(ss.str());
|
||||
} else {
|
||||
PyObject* arglist = Py_BuildValue("(O)", elem);
|
||||
PyObject* result = PyObject_CallObject(numbuf_serialize_callback, arglist);
|
||||
if (!result) {
|
||||
Py_XDECREF(arglist);
|
||||
return python_error_to_status();
|
||||
}
|
||||
builder.AppendDict(PyDict_Size(result));
|
||||
subdicts.push_back(result);
|
||||
Py_XDECREF(arglist);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@@ -213,6 +239,16 @@ Status DeserializeDict(std::shared_ptr<Array> array, int32_t start_idx, int32_t
|
||||
}
|
||||
Py_XDECREF(keys); // PyList_GetItem(keys, ...) incremented the reference count
|
||||
Py_XDECREF(vals); // PyList_GetItem(vals, ...) incremented the reference count
|
||||
static PyObject* py_type = PyString_FromString("_pytype_");
|
||||
if (PyDict_Contains(result, py_type) && numbuf_deserialize_callback) {
|
||||
PyObject* arglist = Py_BuildValue("(O)", result);
|
||||
result = PyObject_CallObject(numbuf_deserialize_callback, arglist);
|
||||
if (!result) {
|
||||
Py_XDECREF(arglist);
|
||||
return python_error_to_status();
|
||||
}
|
||||
Py_XDECREF(arglist);
|
||||
}
|
||||
*out = result;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user