mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:32:11 +08:00
implement object reference serialization and debugging for object stores, some fixes
This commit is contained in:
@@ -2,7 +2,6 @@ import sys
|
||||
|
||||
from setuptools import setup, Extension, find_packages
|
||||
import setuptools
|
||||
from Cython.Build import cythonize
|
||||
|
||||
# because of relative paths, this must be run from inside orch/lib/orchpy/
|
||||
|
||||
|
||||
@@ -132,10 +132,13 @@ message GetObjReply {
|
||||
uint64 size = 3;
|
||||
}
|
||||
|
||||
message ObjStoreDebugInfoRequest {}
|
||||
message ObjStoreDebugInfoRequest {
|
||||
repeated uint64 objref = 1; // get protocol buffer objects corresponding to objref
|
||||
}
|
||||
|
||||
message ObjStoreDebugInfoReply {
|
||||
repeated uint64 objref = 1;
|
||||
repeated uint64 objref = 1; // list of object references in the store
|
||||
repeated Obj obj = 2; // protocol buffer objects that were requested
|
||||
}
|
||||
|
||||
service ObjStore {
|
||||
|
||||
+6
-9
@@ -41,15 +41,12 @@ message Call {
|
||||
repeated uint64 result = 3; // object references for result
|
||||
}
|
||||
|
||||
enum DataType {
|
||||
INT32 = 0;
|
||||
INT64 = 1;
|
||||
FLOAT32 = 2;
|
||||
FLOAT64 = 3;
|
||||
}
|
||||
|
||||
message Array {
|
||||
repeated uint64 shape = 1;
|
||||
DataType dtype = 3;
|
||||
repeated double double_data = 2;
|
||||
sint64 dtype = 2;
|
||||
repeated double double_data = 3;
|
||||
repeated float float_data = 4;
|
||||
repeated sint64 int_data = 5;
|
||||
repeated uint64 uint_data = 6;
|
||||
repeated uint64 objref_data = 7;
|
||||
}
|
||||
|
||||
@@ -79,6 +79,13 @@ Status ObjStoreService::ObjStoreDebugInfo(ServerContext* context, const ObjStore
|
||||
for (const auto& entry : memory_) {
|
||||
reply->add_objref(entry.first);
|
||||
}
|
||||
for (int i = 0; i < request->objref_size(); ++i) {
|
||||
ObjRef objref = request->objref(i);
|
||||
Obj* obj = new Obj();
|
||||
std::string data(memory_[objref].ptr.data, memory_[objref].ptr.len); // copies, but for debugging should be ok
|
||||
obj->ParseFromString(data);
|
||||
reply->mutable_obj()->AddAllocated(obj);
|
||||
}
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
|
||||
+121
-18
@@ -38,6 +38,18 @@ static int PyObjRef_init(PyObjRef *self, PyObject *args, PyObject *kwds) {
|
||||
return 0;
|
||||
};
|
||||
|
||||
static int PyObjRef_compare(PyObject* a, PyObject* b) {
|
||||
PyObjRef* A = (PyObjRef*) a;
|
||||
PyObjRef* B = (PyObjRef*) b;
|
||||
if (A->val < B->val) {
|
||||
return -1;
|
||||
}
|
||||
if (A->val > B->val) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static PyMemberDef PyObjRef_members[] = {
|
||||
{"val", T_INT, offsetof(PyObjRef, val), 0, "object reference"},
|
||||
{NULL}
|
||||
@@ -53,7 +65,7 @@ static PyTypeObject PyObjRefType = {
|
||||
0, /* tp_print */
|
||||
0, /* tp_getattr */
|
||||
0, /* tp_setattr */
|
||||
0, /* tp_compare */
|
||||
PyObjRef_compare, /* tp_compare */
|
||||
0, /* tp_repr */
|
||||
0, /* tp_as_number */
|
||||
0, /* tp_as_sequence */
|
||||
@@ -132,7 +144,7 @@ int PyObjectToObjRef(PyObject* object, ObjRef *objref) {
|
||||
*objref = ((PyObjRef*) object)->val;
|
||||
return 1;
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError, "must be a 'worker' capsule");
|
||||
PyErr_SetString(PyExc_TypeError, "must be an object reference");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
@@ -141,6 +153,7 @@ int PyObjectToObjRef(PyObject* object, ObjRef *objref) {
|
||||
|
||||
// serialize will serialize the python object val into the protocol buffer
|
||||
// object obj, returns 0 if successful and something else if not
|
||||
// FIXME(pcm): This currently only works for contiguous arrays
|
||||
int serialize(PyObject* val, Obj* obj) {
|
||||
if (PyInt_Check(val)) {
|
||||
Int* data = obj->mutable_int_data();
|
||||
@@ -164,19 +177,66 @@ int serialize(PyObject* val, Obj* obj) {
|
||||
PyString_AsStringAndSize(val, &buffer, &length); // creates pointer to internal buffer
|
||||
obj->mutable_string_data()->set_data(buffer, length);
|
||||
} else if (PyArray_Check(val)) {
|
||||
PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*)val);
|
||||
PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) val);
|
||||
Array* data = obj->mutable_array_data();
|
||||
npy_intp size = PyArray_SIZE(array);
|
||||
for (int i = 0; i < PyArray_NDIM(array); ++i) {
|
||||
data->add_shape(PyArray_DIM(array, i));
|
||||
}
|
||||
if (PyArray_ISFLOAT(array)) {
|
||||
double* buffer = (double*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_double_data(buffer[i]);
|
||||
}
|
||||
int typ = PyArray_TYPE(array);
|
||||
data->set_dtype(typ);
|
||||
switch (typ) {
|
||||
case NPY_FLOAT: {
|
||||
npy_float* buffer = (npy_float*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_float_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_DOUBLE: {
|
||||
npy_double* buffer = (npy_double*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_double_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_INT8: {
|
||||
npy_int8* buffer = (npy_int8*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_int_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_UINT8: {
|
||||
npy_uint8* buffer = (npy_uint8*) PyArray_DATA(array);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
data->add_uint_data(buffer[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NPY_OBJECT: { // FIXME(pcm): Support arbitrary python objects, not only objrefs
|
||||
PyArrayIterObject* iter = (PyArrayIterObject*) PyArray_IterNew((PyObject*)array);
|
||||
while (PyArray_ITER_NOTDONE(iter)) {
|
||||
PyObject** item = (PyObject**) PyArray_ITER_DATA(iter);
|
||||
ObjRef objref;
|
||||
if (PyObject_IsInstance(*item, (PyObject*) &PyObjRefType)) {
|
||||
objref = ((PyObjRef*) (*item))->val;
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError, "must be an object reference"); // TODO: improve error message
|
||||
return -1;
|
||||
}
|
||||
data->add_objref_data(objref);
|
||||
PyArray_ITER_NEXT(iter);
|
||||
}
|
||||
Py_XDECREF(iter);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
PyErr_SetString(OrchPyError, "serialization: numpy datatype not know");
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
PyErr_SetString(OrchPyError, "serialization: type not know");
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
@@ -201,21 +261,65 @@ PyObject* deserialize(const Obj& obj) {
|
||||
return PyString_FromStringAndSize(buffer, length);
|
||||
} else if (obj.has_array_data()) {
|
||||
const Array& array = obj.array_data();
|
||||
if (array.double_data_size() > 0) { // TODO: this is not quite right
|
||||
std::vector<npy_intp> dims;
|
||||
for (int i = 0; i < array.shape_size(); ++i) {
|
||||
dims.push_back(array.shape(i));
|
||||
}
|
||||
PyArrayObject* pyarray = (PyArrayObject*) PyArray_SimpleNew(array.shape_size(), &dims[0], array.dtype());
|
||||
if (array.double_data_size() > 0) { // TODO: handle empty array
|
||||
npy_intp size = array.double_data_size();
|
||||
std::vector<npy_intp> dims;
|
||||
for (int i = 0; i < array.shape_size(); ++i) {
|
||||
dims.push_back(array.shape(i));
|
||||
}
|
||||
PyArrayObject* pyarray = (PyArrayObject*)PyArray_SimpleNew(array.shape_size(), &dims[0], NPY_DOUBLE);
|
||||
double* buffer = (double*) PyArray_DATA(pyarray);
|
||||
npy_double* buffer = (npy_double*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.double_data(i);
|
||||
}
|
||||
return (PyObject*)pyarray;
|
||||
} else if (array.float_data_size() > 0) {
|
||||
npy_intp size = array.float_data_size();
|
||||
npy_float* buffer = (npy_float*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.float_data(i);
|
||||
}
|
||||
} else if (array.int_data_size() > 0) {
|
||||
npy_intp size = array.int_data_size();
|
||||
switch (array.dtype()) {
|
||||
case NPY_INT8: {
|
||||
npy_int8* buffer = (npy_int8*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.int_data(i);
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)");
|
||||
return NULL;
|
||||
}
|
||||
} else if (array.uint_data_size() > 0) {
|
||||
npy_intp size = array.uint_data_size();
|
||||
switch (array.dtype()) {
|
||||
case NPY_UINT8: {
|
||||
npy_uint8* buffer = (npy_uint8*) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = array.uint_data(i);
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)");
|
||||
return NULL;
|
||||
}
|
||||
} else if (array.objref_data_size() > 0) {
|
||||
npy_intp size = array.objref_data_size();
|
||||
PyObject** buffer = (PyObject**) PyArray_DATA(pyarray);
|
||||
for (npy_intp i = 0; i < size; ++i) {
|
||||
buffer[i] = make_pyobjref(array.objref_data(i));
|
||||
}
|
||||
} else {
|
||||
PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)");
|
||||
return NULL;
|
||||
}
|
||||
return (PyObject*) pyarray;
|
||||
} else {
|
||||
std::cout << "don't have object" << std::endl;
|
||||
PyErr_SetString(OrchPyError, "deserialization: internal error (type not implemented)");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -226,7 +330,6 @@ PyObject* serialize_object(PyObject* self, PyObject* args) {
|
||||
return NULL;
|
||||
}
|
||||
if (serialize(pyval, obj) != 0) {
|
||||
PyErr_SetString(OrchPyError, "serialization: type not know"); // TODO: put a more expressive error message here
|
||||
return NULL;
|
||||
}
|
||||
return PyCapsule_New(static_cast<void*>(obj), "obj", NULL);
|
||||
|
||||
+25
-3
@@ -54,15 +54,36 @@ class SerializationTest(unittest.TestCase):
|
||||
result = orchpy.lib.deserialize_object(serialized)
|
||||
self.assertEqual(data, result)
|
||||
|
||||
def numpyTypeTest(self, typ):
|
||||
a = np.random.randint(0, 10, size=(100, 100)).astype(typ)
|
||||
b = orchpy.lib.serialize_object(a)
|
||||
c = orchpy.lib.deserialize_object(b)
|
||||
self.assertTrue((a == c).all())
|
||||
|
||||
def testSerialize(self):
|
||||
data = [1, "hello", 3.0]
|
||||
self.roundTripTest(data)
|
||||
self.roundTripTest([1, "hello", 3.0])
|
||||
self.roundTripTest(42)
|
||||
self.roundTripTest("hello world")
|
||||
self.roundTripTest(42.0)
|
||||
|
||||
a = np.zeros((100, 100))
|
||||
res = orchpy.lib.serialize_object(a)
|
||||
b = orchpy.lib.deserialize_object(res)
|
||||
self.assertTrue((a == b).all())
|
||||
|
||||
self.numpyTypeTest('int8')
|
||||
self.numpyTypeTest('uint8')
|
||||
# self.numpyTypeTest('int16') # TODO(pcm): implement this
|
||||
# self.numpyTypeTest('int32') # TODO(pcm): implement this
|
||||
self.numpyTypeTest('float32')
|
||||
self.numpyTypeTest('float64')
|
||||
|
||||
a = np.array([[orchpy.lib.ObjRef(0), orchpy.lib.ObjRef(1)], [orchpy.lib.ObjRef(41), orchpy.lib.ObjRef(42)]])
|
||||
capsule = orchpy.lib.serialize_object(a)
|
||||
result = orchpy.lib.deserialize_object(capsule)
|
||||
self.assertTrue((a == result).all())
|
||||
|
||||
"""
|
||||
class OrchPyLibTest(unittest.TestCase):
|
||||
|
||||
def testOrchPyLib(self):
|
||||
@@ -88,10 +109,11 @@ class OrchPyLibTest(unittest.TestCase):
|
||||
self.assertEqual(result, 'hello world')
|
||||
|
||||
services.cleanup()
|
||||
"""
|
||||
|
||||
class ObjStoreTest(unittest.TestCase):
|
||||
|
||||
"""Test setting up object stores, transfering data between them and retrieving data to a client"""
|
||||
# Test setting up object stores, transfering data between them and retrieving data to a client
|
||||
def testObjStore(self):
|
||||
scheduler_port = new_scheduler_port()
|
||||
objstore1_port = new_objstore_port()
|
||||
|
||||
Reference in New Issue
Block a user