implement object reference serialization and debugging for object stores, some fixes

This commit is contained in:
Philipp Moritz
2016-03-15 13:06:51 -07:00
parent e46f500c91
commit bcc59e898d
6 changed files with 164 additions and 33 deletions
-1
View File
@@ -2,7 +2,6 @@ import sys
from setuptools import setup, Extension, find_packages
import setuptools
from Cython.Build import cythonize
# because of relative paths, this must be run from inside orch/lib/orchpy/
+5 -2
View File
@@ -132,10 +132,13 @@ message GetObjReply {
uint64 size = 3;
}
message ObjStoreDebugInfoRequest {}
message ObjStoreDebugInfoRequest {
repeated uint64 objref = 1; // get protocol buffer objects corresponding to objref
}
message ObjStoreDebugInfoReply {
repeated uint64 objref = 1;
repeated uint64 objref = 1; // list of object references in the store
repeated Obj obj = 2; // protocol buffer objects that were requested
}
service ObjStore {
+6 -9
View File
@@ -41,15 +41,12 @@ message Call {
repeated uint64 result = 3; // object references for result
}
enum DataType {
INT32 = 0;
INT64 = 1;
FLOAT32 = 2;
FLOAT64 = 3;
}
message Array {
repeated uint64 shape = 1;
DataType dtype = 3;
repeated double double_data = 2;
sint64 dtype = 2;
repeated double double_data = 3;
repeated float float_data = 4;
repeated sint64 int_data = 5;
repeated uint64 uint_data = 6;
repeated uint64 objref_data = 7;
}
+7
View File
@@ -79,6 +79,13 @@ Status ObjStoreService::ObjStoreDebugInfo(ServerContext* context, const ObjStore
for (const auto& entry : memory_) {
reply->add_objref(entry.first);
}
for (int i = 0; i < request->objref_size(); ++i) {
ObjRef objref = request->objref(i);
Obj* obj = new Obj();
std::string data(memory_[objref].ptr.data, memory_[objref].ptr.len); // copies, but for debugging should be ok
obj->ParseFromString(data);
reply->mutable_obj()->AddAllocated(obj);
}
return Status::OK;
}
+121 -18
View File
@@ -38,6 +38,18 @@ static int PyObjRef_init(PyObjRef *self, PyObject *args, PyObject *kwds) {
return 0;
};
static int PyObjRef_compare(PyObject* a, PyObject* b) {
PyObjRef* A = (PyObjRef*) a;
PyObjRef* B = (PyObjRef*) b;
if (A->val < B->val) {
return -1;
}
if (A->val > B->val) {
return 1;
}
return 0;
}
static PyMemberDef PyObjRef_members[] = {
{"val", T_INT, offsetof(PyObjRef, val), 0, "object reference"},
{NULL}
@@ -53,7 +65,7 @@ static PyTypeObject PyObjRefType = {
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
PyObjRef_compare, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
@@ -132,7 +144,7 @@ int PyObjectToObjRef(PyObject* object, ObjRef *objref) {
*objref = ((PyObjRef*) object)->val;
return 1;
} else {
PyErr_SetString(PyExc_TypeError, "must be a 'worker' capsule");
PyErr_SetString(PyExc_TypeError, "must be an object reference");
return 0;
}
}
@@ -141,6 +153,7 @@ int PyObjectToObjRef(PyObject* object, ObjRef *objref) {
// serialize will serialize the python object val into the protocol buffer
// object obj, returns 0 if successful and something else if not
// FIXME(pcm): This currently only works for contiguous arrays
int serialize(PyObject* val, Obj* obj) {
if (PyInt_Check(val)) {
Int* data = obj->mutable_int_data();
@@ -164,19 +177,66 @@ int serialize(PyObject* val, Obj* obj) {
PyString_AsStringAndSize(val, &buffer, &length); // creates pointer to internal buffer
obj->mutable_string_data()->set_data(buffer, length);
} else if (PyArray_Check(val)) {
PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*)val);
PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) val);
Array* data = obj->mutable_array_data();
npy_intp size = PyArray_SIZE(array);
for (int i = 0; i < PyArray_NDIM(array); ++i) {
data->add_shape(PyArray_DIM(array, i));
}
if (PyArray_ISFLOAT(array)) {
double* buffer = (double*) PyArray_DATA(array);
for (npy_intp i = 0; i < size; ++i) {
data->add_double_data(buffer[i]);
}
int typ = PyArray_TYPE(array);
data->set_dtype(typ);
switch (typ) {
case NPY_FLOAT: {
npy_float* buffer = (npy_float*) PyArray_DATA(array);
for (npy_intp i = 0; i < size; ++i) {
data->add_float_data(buffer[i]);
}
}
break;
case NPY_DOUBLE: {
npy_double* buffer = (npy_double*) PyArray_DATA(array);
for (npy_intp i = 0; i < size; ++i) {
data->add_double_data(buffer[i]);
}
}
break;
case NPY_INT8: {
npy_int8* buffer = (npy_int8*) PyArray_DATA(array);
for (npy_intp i = 0; i < size; ++i) {
data->add_int_data(buffer[i]);
}
}
break;
case NPY_UINT8: {
npy_uint8* buffer = (npy_uint8*) PyArray_DATA(array);
for (npy_intp i = 0; i < size; ++i) {
data->add_uint_data(buffer[i]);
}
}
break;
case NPY_OBJECT: { // FIXME(pcm): Support arbitrary python objects, not only objrefs
PyArrayIterObject* iter = (PyArrayIterObject*) PyArray_IterNew((PyObject*)array);
while (PyArray_ITER_NOTDONE(iter)) {
PyObject** item = (PyObject**) PyArray_ITER_DATA(iter);
ObjRef objref;
if (PyObject_IsInstance(*item, (PyObject*) &PyObjRefType)) {
objref = ((PyObjRef*) (*item))->val;
} else {
PyErr_SetString(PyExc_TypeError, "must be an object reference"); // TODO: improve error message
return -1;
}
data->add_objref_data(objref);
PyArray_ITER_NEXT(iter);
}
Py_XDECREF(iter);
}
break;
default:
PyErr_SetString(OrchPyError, "serialization: numpy datatype not know");
return -1;
}
} else {
PyErr_SetString(OrchPyError, "serialization: type not know");
return -1;
}
return 0;
@@ -201,21 +261,65 @@ PyObject* deserialize(const Obj& obj) {
return PyString_FromStringAndSize(buffer, length);
} else if (obj.has_array_data()) {
const Array& array = obj.array_data();
if (array.double_data_size() > 0) { // TODO: this is not quite right
std::vector<npy_intp> dims;
for (int i = 0; i < array.shape_size(); ++i) {
dims.push_back(array.shape(i));
}
PyArrayObject* pyarray = (PyArrayObject*) PyArray_SimpleNew(array.shape_size(), &dims[0], array.dtype());
if (array.double_data_size() > 0) { // TODO: handle empty array
npy_intp size = array.double_data_size();
std::vector<npy_intp> dims;
for (int i = 0; i < array.shape_size(); ++i) {
dims.push_back(array.shape(i));
}
PyArrayObject* pyarray = (PyArrayObject*)PyArray_SimpleNew(array.shape_size(), &dims[0], NPY_DOUBLE);
double* buffer = (double*) PyArray_DATA(pyarray);
npy_double* buffer = (npy_double*) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = array.double_data(i);
}
return (PyObject*)pyarray;
} else if (array.float_data_size() > 0) {
npy_intp size = array.float_data_size();
npy_float* buffer = (npy_float*) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = array.float_data(i);
}
} else if (array.int_data_size() > 0) {
npy_intp size = array.int_data_size();
switch (array.dtype()) {
case NPY_INT8: {
npy_int8* buffer = (npy_int8*) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = array.int_data(i);
}
}
break;
default:
PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)");
return NULL;
}
} else if (array.uint_data_size() > 0) {
npy_intp size = array.uint_data_size();
switch (array.dtype()) {
case NPY_UINT8: {
npy_uint8* buffer = (npy_uint8*) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = array.uint_data(i);
}
}
break;
default:
PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)");
return NULL;
}
} else if (array.objref_data_size() > 0) {
npy_intp size = array.objref_data_size();
PyObject** buffer = (PyObject**) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = make_pyobjref(array.objref_data(i));
}
} else {
PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)");
return NULL;
}
return (PyObject*) pyarray;
} else {
std::cout << "don't have object" << std::endl;
PyErr_SetString(OrchPyError, "deserialization: internal error (type not implemented)");
return NULL;
}
}
@@ -226,7 +330,6 @@ PyObject* serialize_object(PyObject* self, PyObject* args) {
return NULL;
}
if (serialize(pyval, obj) != 0) {
PyErr_SetString(OrchPyError, "serialization: type not know"); // TODO: put a more expressive error message here
return NULL;
}
return PyCapsule_New(static_cast<void*>(obj), "obj", NULL);
+25 -3
View File
@@ -54,15 +54,36 @@ class SerializationTest(unittest.TestCase):
result = orchpy.lib.deserialize_object(serialized)
self.assertEqual(data, result)
def numpyTypeTest(self, typ):
a = np.random.randint(0, 10, size=(100, 100)).astype(typ)
b = orchpy.lib.serialize_object(a)
c = orchpy.lib.deserialize_object(b)
self.assertTrue((a == c).all())
def testSerialize(self):
data = [1, "hello", 3.0]
self.roundTripTest(data)
self.roundTripTest([1, "hello", 3.0])
self.roundTripTest(42)
self.roundTripTest("hello world")
self.roundTripTest(42.0)
a = np.zeros((100, 100))
res = orchpy.lib.serialize_object(a)
b = orchpy.lib.deserialize_object(res)
self.assertTrue((a == b).all())
self.numpyTypeTest('int8')
self.numpyTypeTest('uint8')
# self.numpyTypeTest('int16') # TODO(pcm): implement this
# self.numpyTypeTest('int32') # TODO(pcm): implement this
self.numpyTypeTest('float32')
self.numpyTypeTest('float64')
a = np.array([[orchpy.lib.ObjRef(0), orchpy.lib.ObjRef(1)], [orchpy.lib.ObjRef(41), orchpy.lib.ObjRef(42)]])
capsule = orchpy.lib.serialize_object(a)
result = orchpy.lib.deserialize_object(capsule)
self.assertTrue((a == result).all())
"""
class OrchPyLibTest(unittest.TestCase):
def testOrchPyLib(self):
@@ -88,10 +109,11 @@ class OrchPyLibTest(unittest.TestCase):
self.assertEqual(result, 'hello world')
services.cleanup()
"""
class ObjStoreTest(unittest.TestCase):
"""Test setting up object stores, transfering data between them and retrieving data to a client"""
# Test setting up object stores, transfering data between them and retrieving data to a client
def testObjStore(self):
scheduler_port = new_scheduler_port()
objstore1_port = new_objstore_port()