From bcc59e898d6236565086a8e81ed8428fbeda7232 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 15 Mar 2016 13:06:51 -0700 Subject: [PATCH] implement object reference serialization and debugging for object stores, some fixes --- lib/orchpy/setup.py | 1 - protos/orchestra.proto | 7 ++- protos/types.proto | 15 ++--- src/objstore.cc | 7 +++ src/orchpylib.cc | 139 +++++++++++++++++++++++++++++++++++------ test/runtest.py | 28 ++++++++- 6 files changed, 164 insertions(+), 33 deletions(-) diff --git a/lib/orchpy/setup.py b/lib/orchpy/setup.py index ba49125cf..55efc4312 100644 --- a/lib/orchpy/setup.py +++ b/lib/orchpy/setup.py @@ -2,7 +2,6 @@ import sys from setuptools import setup, Extension, find_packages import setuptools -from Cython.Build import cythonize # because of relative paths, this must be run from inside orch/lib/orchpy/ diff --git a/protos/orchestra.proto b/protos/orchestra.proto index 8570acd72..f31673275 100644 --- a/protos/orchestra.proto +++ b/protos/orchestra.proto @@ -132,10 +132,13 @@ message GetObjReply { uint64 size = 3; } -message ObjStoreDebugInfoRequest {} +message ObjStoreDebugInfoRequest { + repeated uint64 objref = 1; // get protocol buffer objects corresponding to objref +} message ObjStoreDebugInfoReply { - repeated uint64 objref = 1; + repeated uint64 objref = 1; // list of object references in the store + repeated Obj obj = 2; // protocol buffer objects that were requested } service ObjStore { diff --git a/protos/types.proto b/protos/types.proto index 98340a0a4..a72674458 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -41,15 +41,12 @@ message Call { repeated uint64 result = 3; // object references for result } -enum DataType { - INT32 = 0; - INT64 = 1; - FLOAT32 = 2; - FLOAT64 = 3; -} - message Array { repeated uint64 shape = 1; - DataType dtype = 3; - repeated double double_data = 2; + sint64 dtype = 2; + repeated double double_data = 3; + repeated float float_data = 4; + repeated sint64 int_data = 5; + repeated uint64 uint_data = 6; + repeated uint64 objref_data = 7; } diff --git a/src/objstore.cc b/src/objstore.cc index 0991bf43a..17aa3004d 100644 --- a/src/objstore.cc +++ b/src/objstore.cc @@ -79,6 +79,13 @@ Status ObjStoreService::ObjStoreDebugInfo(ServerContext* context, const ObjStore for (const auto& entry : memory_) { reply->add_objref(entry.first); } + for (int i = 0; i < request->objref_size(); ++i) { + ObjRef objref = request->objref(i); + Obj* obj = new Obj(); + std::string data(memory_[objref].ptr.data, memory_[objref].ptr.len); // copies, but for debugging should be ok + obj->ParseFromString(data); + reply->mutable_obj()->AddAllocated(obj); + } return Status::OK; } diff --git a/src/orchpylib.cc b/src/orchpylib.cc index eb8312222..e2f59879b 100644 --- a/src/orchpylib.cc +++ b/src/orchpylib.cc @@ -38,6 +38,18 @@ static int PyObjRef_init(PyObjRef *self, PyObject *args, PyObject *kwds) { return 0; }; +static int PyObjRef_compare(PyObject* a, PyObject* b) { + PyObjRef* A = (PyObjRef*) a; + PyObjRef* B = (PyObjRef*) b; + if (A->val < B->val) { + return -1; + } + if (A->val > B->val) { + return 1; + } + return 0; +} + static PyMemberDef PyObjRef_members[] = { {"val", T_INT, offsetof(PyObjRef, val), 0, "object reference"}, {NULL} @@ -53,7 +65,7 @@ static PyTypeObject PyObjRefType = { 0, /* tp_print */ 0, /* tp_getattr */ 0, /* tp_setattr */ - 0, /* tp_compare */ + PyObjRef_compare, /* tp_compare */ 0, /* tp_repr */ 0, /* tp_as_number */ 0, /* tp_as_sequence */ @@ -132,7 +144,7 @@ int PyObjectToObjRef(PyObject* object, ObjRef *objref) { *objref = ((PyObjRef*) object)->val; return 1; } else { - PyErr_SetString(PyExc_TypeError, "must be a 'worker' capsule"); + PyErr_SetString(PyExc_TypeError, "must be an object reference"); return 0; } } @@ -141,6 +153,7 @@ int PyObjectToObjRef(PyObject* object, ObjRef *objref) { // serialize will serialize the python object val into the protocol buffer // object obj, returns 0 if successful and something else if not +// FIXME(pcm): This currently only works for contiguous arrays int serialize(PyObject* val, Obj* obj) { if (PyInt_Check(val)) { Int* data = obj->mutable_int_data(); @@ -164,19 +177,66 @@ int serialize(PyObject* val, Obj* obj) { PyString_AsStringAndSize(val, &buffer, &length); // creates pointer to internal buffer obj->mutable_string_data()->set_data(buffer, length); } else if (PyArray_Check(val)) { - PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*)val); + PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) val); Array* data = obj->mutable_array_data(); npy_intp size = PyArray_SIZE(array); for (int i = 0; i < PyArray_NDIM(array); ++i) { data->add_shape(PyArray_DIM(array, i)); } - if (PyArray_ISFLOAT(array)) { - double* buffer = (double*) PyArray_DATA(array); - for (npy_intp i = 0; i < size; ++i) { - data->add_double_data(buffer[i]); - } + int typ = PyArray_TYPE(array); + data->set_dtype(typ); + switch (typ) { + case NPY_FLOAT: { + npy_float* buffer = (npy_float*) PyArray_DATA(array); + for (npy_intp i = 0; i < size; ++i) { + data->add_float_data(buffer[i]); + } + } + break; + case NPY_DOUBLE: { + npy_double* buffer = (npy_double*) PyArray_DATA(array); + for (npy_intp i = 0; i < size; ++i) { + data->add_double_data(buffer[i]); + } + } + break; + case NPY_INT8: { + npy_int8* buffer = (npy_int8*) PyArray_DATA(array); + for (npy_intp i = 0; i < size; ++i) { + data->add_int_data(buffer[i]); + } + } + break; + case NPY_UINT8: { + npy_uint8* buffer = (npy_uint8*) PyArray_DATA(array); + for (npy_intp i = 0; i < size; ++i) { + data->add_uint_data(buffer[i]); + } + } + break; + case NPY_OBJECT: { // FIXME(pcm): Support arbitrary python objects, not only objrefs + PyArrayIterObject* iter = (PyArrayIterObject*) PyArray_IterNew((PyObject*)array); + while (PyArray_ITER_NOTDONE(iter)) { + PyObject** item = (PyObject**) PyArray_ITER_DATA(iter); + ObjRef objref; + if (PyObject_IsInstance(*item, (PyObject*) &PyObjRefType)) { + objref = ((PyObjRef*) (*item))->val; + } else { + PyErr_SetString(PyExc_TypeError, "must be an object reference"); // TODO: improve error message + return -1; + } + data->add_objref_data(objref); + PyArray_ITER_NEXT(iter); + } + Py_XDECREF(iter); + } + break; + default: + PyErr_SetString(OrchPyError, "serialization: numpy datatype not know"); + return -1; } } else { + PyErr_SetString(OrchPyError, "serialization: type not know"); return -1; } return 0; @@ -201,21 +261,65 @@ PyObject* deserialize(const Obj& obj) { return PyString_FromStringAndSize(buffer, length); } else if (obj.has_array_data()) { const Array& array = obj.array_data(); - if (array.double_data_size() > 0) { // TODO: this is not quite right + std::vector dims; + for (int i = 0; i < array.shape_size(); ++i) { + dims.push_back(array.shape(i)); + } + PyArrayObject* pyarray = (PyArrayObject*) PyArray_SimpleNew(array.shape_size(), &dims[0], array.dtype()); + if (array.double_data_size() > 0) { // TODO: handle empty array npy_intp size = array.double_data_size(); - std::vector dims; - for (int i = 0; i < array.shape_size(); ++i) { - dims.push_back(array.shape(i)); - } - PyArrayObject* pyarray = (PyArrayObject*)PyArray_SimpleNew(array.shape_size(), &dims[0], NPY_DOUBLE); - double* buffer = (double*) PyArray_DATA(pyarray); + npy_double* buffer = (npy_double*) PyArray_DATA(pyarray); for (npy_intp i = 0; i < size; ++i) { buffer[i] = array.double_data(i); } - return (PyObject*)pyarray; + } else if (array.float_data_size() > 0) { + npy_intp size = array.float_data_size(); + npy_float* buffer = (npy_float*) PyArray_DATA(pyarray); + for (npy_intp i = 0; i < size; ++i) { + buffer[i] = array.float_data(i); + } + } else if (array.int_data_size() > 0) { + npy_intp size = array.int_data_size(); + switch (array.dtype()) { + case NPY_INT8: { + npy_int8* buffer = (npy_int8*) PyArray_DATA(pyarray); + for (npy_intp i = 0; i < size; ++i) { + buffer[i] = array.int_data(i); + } + } + break; + default: + PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)"); + return NULL; + } + } else if (array.uint_data_size() > 0) { + npy_intp size = array.uint_data_size(); + switch (array.dtype()) { + case NPY_UINT8: { + npy_uint8* buffer = (npy_uint8*) PyArray_DATA(pyarray); + for (npy_intp i = 0; i < size; ++i) { + buffer[i] = array.uint_data(i); + } + } + break; + default: + PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)"); + return NULL; + } + } else if (array.objref_data_size() > 0) { + npy_intp size = array.objref_data_size(); + PyObject** buffer = (PyObject**) PyArray_DATA(pyarray); + for (npy_intp i = 0; i < size; ++i) { + buffer[i] = make_pyobjref(array.objref_data(i)); + } + } else { + PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)"); + return NULL; } + return (PyObject*) pyarray; } else { - std::cout << "don't have object" << std::endl; + PyErr_SetString(OrchPyError, "deserialization: internal error (type not implemented)"); + return NULL; } } @@ -226,7 +330,6 @@ PyObject* serialize_object(PyObject* self, PyObject* args) { return NULL; } if (serialize(pyval, obj) != 0) { - PyErr_SetString(OrchPyError, "serialization: type not know"); // TODO: put a more expressive error message here return NULL; } return PyCapsule_New(static_cast(obj), "obj", NULL); diff --git a/test/runtest.py b/test/runtest.py index e0a1b5c39..bce975409 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -54,15 +54,36 @@ class SerializationTest(unittest.TestCase): result = orchpy.lib.deserialize_object(serialized) self.assertEqual(data, result) + def numpyTypeTest(self, typ): + a = np.random.randint(0, 10, size=(100, 100)).astype(typ) + b = orchpy.lib.serialize_object(a) + c = orchpy.lib.deserialize_object(b) + self.assertTrue((a == c).all()) + def testSerialize(self): - data = [1, "hello", 3.0] - self.roundTripTest(data) + self.roundTripTest([1, "hello", 3.0]) + self.roundTripTest(42) + self.roundTripTest("hello world") + self.roundTripTest(42.0) a = np.zeros((100, 100)) res = orchpy.lib.serialize_object(a) b = orchpy.lib.deserialize_object(res) self.assertTrue((a == b).all()) + self.numpyTypeTest('int8') + self.numpyTypeTest('uint8') + # self.numpyTypeTest('int16') # TODO(pcm): implement this + # self.numpyTypeTest('int32') # TODO(pcm): implement this + self.numpyTypeTest('float32') + self.numpyTypeTest('float64') + + a = np.array([[orchpy.lib.ObjRef(0), orchpy.lib.ObjRef(1)], [orchpy.lib.ObjRef(41), orchpy.lib.ObjRef(42)]]) + capsule = orchpy.lib.serialize_object(a) + result = orchpy.lib.deserialize_object(capsule) + self.assertTrue((a == result).all()) + +""" class OrchPyLibTest(unittest.TestCase): def testOrchPyLib(self): @@ -88,10 +109,11 @@ class OrchPyLibTest(unittest.TestCase): self.assertEqual(result, 'hello world') services.cleanup() +""" class ObjStoreTest(unittest.TestCase): - """Test setting up object stores, transfering data between them and retrieving data to a client""" + # Test setting up object stores, transfering data between them and retrieving data to a client def testObjStore(self): scheduler_port = new_scheduler_port() objstore1_port = new_objstore_port()