From 7481a02024dcbb4ee13f3b3a3cfdad986309807b Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 25 Aug 2016 22:25:45 -0700 Subject: [PATCH] add custom callbacks for serialization --- python/src/pynumbuf/adapters/python.cc | 44 +++++++++++++++++++++++--- python/src/pynumbuf/numbuf.cc | 29 +++++++++++++++++ python/test/runtest.py | 31 ++++++++++++++++++ 3 files changed, 100 insertions(+), 4 deletions(-) diff --git a/python/src/pynumbuf/adapters/python.cc b/python/src/pynumbuf/adapters/python.cc index ba2c8b0e3..89177e425 100644 --- a/python/src/pynumbuf/adapters/python.cc +++ b/python/src/pynumbuf/adapters/python.cc @@ -6,6 +6,9 @@ using namespace arrow; +extern PyObject* numbuf_serialize_callback; +extern PyObject* numbuf_deserialize_callback; + namespace numbuf { PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) { @@ -49,6 +52,17 @@ PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) { return NULL; } +Status python_error_to_status() { + PyObject *type, *value, *traceback; + PyErr_Fetch(&type, &value, &traceback); + char *err_message = PyString_AsString(value); + std::stringstream ss; + if (err_message) { + ss << "Python error in callback: " << err_message; + } + return Status::NotImplemented(ss.str()); +} + Status append(PyObject* elem, SequenceBuilder& builder, std::vector& sublists, std::vector& subtuples, @@ -99,10 +113,22 @@ Status append(PyObject* elem, SequenceBuilder& builder, } else if (elem == Py_None) { RETURN_NOT_OK(builder.AppendNone()); } else { - std::stringstream ss; - ss << "data type of " << PyString_AS_STRING(PyObject_Repr(elem)) - << " not recognized"; - return Status::NotImplemented(ss.str()); + if (!numbuf_serialize_callback) { + std::stringstream ss; + ss << "data type of " << PyString_AS_STRING(PyObject_Repr(elem)) + << " not recognized and custom serialization handler not registered"; + return Status::NotImplemented(ss.str()); + } else { + PyObject* arglist = Py_BuildValue("(O)", elem); + PyObject* result = PyObject_CallObject(numbuf_serialize_callback, arglist); + if (!result) { + Py_XDECREF(arglist); + return python_error_to_status(); + } + builder.AppendDict(PyDict_Size(result)); + subdicts.push_back(result); + Py_XDECREF(arglist); + } } return Status::OK(); } @@ -213,6 +239,16 @@ Status DeserializeDict(std::shared_ptr array, int32_t start_idx, int32_t } Py_XDECREF(keys); // PyList_GetItem(keys, ...) incremented the reference count Py_XDECREF(vals); // PyList_GetItem(vals, ...) incremented the reference count + static PyObject* py_type = PyString_FromString("_pytype_"); + if (PyDict_Contains(result, py_type) && numbuf_deserialize_callback) { + PyObject* arglist = Py_BuildValue("(O)", result); + result = PyObject_CallObject(numbuf_deserialize_callback, arglist); + if (!result) { + Py_XDECREF(arglist); + return python_error_to_status(); + } + Py_XDECREF(arglist); + } *out = result; return Status::OK(); } diff --git a/python/src/pynumbuf/numbuf.cc b/python/src/pynumbuf/numbuf.cc index 33b4ac078..104747b4d 100644 --- a/python/src/pynumbuf/numbuf.cc +++ b/python/src/pynumbuf/numbuf.cc @@ -26,6 +26,9 @@ extern "C" { static PyObject *NumbufError; +PyObject *numbuf_serialize_callback = NULL; +PyObject *numbuf_deserialize_callback = NULL; + int PyObjectToArrow(PyObject* object, std::shared_ptr **result) { if (PyCapsule_IsValid(object, "arrow")) { *result = reinterpret_cast*>(PyCapsule_GetPointer(object, "arrow")); @@ -131,11 +134,37 @@ static PyObject* deserialize_list(PyObject* self, PyObject* args) { return result; } +static PyObject* register_callbacks(PyObject* self, PyObject* args) { + PyObject* result = NULL; + PyObject* serialize_callback; + PyObject* deserialize_callback; + if (PyArg_ParseTuple(args, "OO:register_callbacks", &serialize_callback, &deserialize_callback)) { + if (!PyCallable_Check(serialize_callback)) { + PyErr_SetString(PyExc_TypeError, "serialize_callback must be callable"); + return NULL; + } + if (!PyCallable_Check(deserialize_callback)) { + PyErr_SetString(PyExc_TypeError, "deserialize_callback must be callable"); + return NULL; + } + Py_XINCREF(serialize_callback); // Add a reference to new serialization callback + Py_XINCREF(deserialize_callback); // Add a reference to new deserialization callback + Py_XDECREF(numbuf_serialize_callback); // Dispose of old serialization callback + Py_XDECREF(numbuf_deserialize_callback); // Dispose of old deserialization callback + numbuf_serialize_callback = serialize_callback; + numbuf_deserialize_callback = deserialize_callback; + Py_INCREF(Py_None); + result = Py_None; + } + return result; +} + static PyMethodDef NumbufMethods[] = { { "serialize_list", serialize_list, METH_VARARGS, "serialize a Python list" }, { "deserialize_list", deserialize_list, METH_VARARGS, "deserialize a Python list" }, { "write_to_buffer", write_to_buffer, METH_VARARGS, "write serialized data to buffer"}, { "read_from_buffer", read_from_buffer, METH_VARARGS, "read serialized data from buffer"}, + { "register_callbacks", register_callbacks, METH_VARARGS, "set serialization and deserialization callbacks"}, { NULL, NULL, 0, NULL } }; diff --git a/python/test/runtest.py b/python/test/runtest.py index f8966b5ef..f4e161647 100644 --- a/python/test/runtest.py +++ b/python/test/runtest.py @@ -59,6 +59,37 @@ class SerializationTests(unittest.TestCase): for obj in TEST_OBJECTS: self.roundTripTest([obj]) + def testCallback(self): + + class Foo(object): + def __init__(self): + self.x = 1 + + class Bar(object): + def __init__(self): + self.foo = Foo() + + def serialize(obj): + return dict(obj.__dict__, **{"_pytype_": type(obj).__name__}) + + def deserialize(obj): + if obj["_pytype_"] == "Foo": + result = Foo() + elif obj["_pytype_"] == "Bar": + result = Bar() + + obj.pop("_pytype_", None) + result.__dict__ = obj + return result + + bar = Bar() + bar.foo.x = 42 + + libnumbuf.register_callbacks(serialize, deserialize) + + metadata, size, serialized = libnumbuf.serialize_list([bar]) + self.assertEqual(libnumbuf.deserialize_list(serialized)[0].foo.x, 42) + def testBuffer(self): for (i, obj) in enumerate(TEST_OBJECTS): schema, size, batch = libnumbuf.serialize_list([obj])