mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:18:59 +08:00
introduce base object
This commit is contained in:
@@ -25,11 +25,15 @@ namespace numbuf {
|
||||
type* data = const_cast<type*>(values->raw_data()) \
|
||||
+ content->offset(offset); \
|
||||
*out = PyArray_SimpleNewFromData(num_dims, dim.data(), NPY_##TYPE, \
|
||||
reinterpret_cast<void*>(data)); \
|
||||
reinterpret_cast<void*>(data)); \
|
||||
if (base != Py_None) { \
|
||||
PyArray_SetBaseObject((PyArrayObject*) *out, base); \
|
||||
} \
|
||||
Py_XINCREF(base); \
|
||||
} \
|
||||
return Status::OK();
|
||||
|
||||
Status DeserializeArray(std::shared_ptr<Array> array, int32_t offset, PyObject** out) {
|
||||
Status DeserializeArray(std::shared_ptr<Array> array, int32_t offset, PyObject* base, PyObject** out) {
|
||||
DCHECK(array);
|
||||
auto tensor = std::dynamic_pointer_cast<StructArray>(array);
|
||||
DCHECK(tensor);
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
namespace numbuf {
|
||||
|
||||
arrow::Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder, std::vector<PyObject*>& subdicts);
|
||||
arrow::Status DeserializeArray(std::shared_ptr<arrow::Array> array, int32_t offset, PyObject** out);
|
||||
arrow::Status DeserializeArray(std::shared_ptr<arrow::Array> array, int32_t offset, PyObject* base, PyObject** out);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ extern PyObject* numbuf_deserialize_callback;
|
||||
|
||||
namespace numbuf {
|
||||
|
||||
PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) {
|
||||
PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type, PyObject* base) {
|
||||
PyObject* result;
|
||||
switch (arr->type()->type) {
|
||||
case Type::BOOL:
|
||||
@@ -36,13 +36,13 @@ PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) {
|
||||
auto s = std::static_pointer_cast<StructArray>(arr);
|
||||
auto l = std::static_pointer_cast<ListArray>(s->field(0));
|
||||
if (s->type()->child(0)->name == "list") {
|
||||
ARROW_CHECK_OK(DeserializeList(l->values(), l->value_offset(index), l->value_offset(index+1), &result));
|
||||
ARROW_CHECK_OK(DeserializeList(l->values(), l->value_offset(index), l->value_offset(index+1), base, &result));
|
||||
} else if (s->type()->child(0)->name == "tuple") {
|
||||
ARROW_CHECK_OK(DeserializeTuple(l->values(), l->value_offset(index), l->value_offset(index+1), &result));
|
||||
ARROW_CHECK_OK(DeserializeTuple(l->values(), l->value_offset(index), l->value_offset(index+1), base, &result));
|
||||
} else if (s->type()->child(0)->name == "dict") {
|
||||
ARROW_CHECK_OK(DeserializeDict(l->values(), l->value_offset(index), l->value_offset(index+1), &result));
|
||||
ARROW_CHECK_OK(DeserializeDict(l->values(), l->value_offset(index), l->value_offset(index+1), base, &result));
|
||||
} else {
|
||||
ARROW_CHECK_OK(DeserializeArray(arr, index, &result));
|
||||
ARROW_CHECK_OK(DeserializeArray(arr, index, base, &result));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -181,17 +181,17 @@ Status SerializeSequences(std::vector<PyObject*> sequences, std::shared_ptr<Arra
|
||||
int32_t offset = offsets->Value(i); \
|
||||
int8_t type = types->Value(i); \
|
||||
ArrayPtr arr = data->child(type); \
|
||||
SET_ITEM(result, i-start_idx, get_value(arr, offset, type)); \
|
||||
SET_ITEM(result, i-start_idx, get_value(arr, offset, type, base)); \
|
||||
} \
|
||||
} \
|
||||
*out = result; \
|
||||
return Status::OK();
|
||||
|
||||
Status DeserializeList(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out) {
|
||||
Status DeserializeList(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out) {
|
||||
DESERIALIZE_SEQUENCE(PyList_New, PyList_SetItem)
|
||||
}
|
||||
|
||||
Status DeserializeTuple(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out) {
|
||||
Status DeserializeTuple(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out) {
|
||||
DESERIALIZE_SEQUENCE(PyTuple_New, PyTuple_SetItem)
|
||||
}
|
||||
|
||||
@@ -227,13 +227,13 @@ Status SerializeDict(std::vector<PyObject*> dicts, std::shared_ptr<Array>* out)
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeserializeDict(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out) {
|
||||
Status DeserializeDict(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out) {
|
||||
auto data = std::dynamic_pointer_cast<StructArray>(array);
|
||||
// TODO(pcm): error handling, get rid of the temporary copy of the list
|
||||
PyObject *keys, *vals;
|
||||
PyObject* result = PyDict_New();
|
||||
ARROW_RETURN_NOT_OK(DeserializeList(data->field(0), start_idx, stop_idx, &keys));
|
||||
ARROW_RETURN_NOT_OK(DeserializeList(data->field(1), start_idx, stop_idx, &vals));
|
||||
ARROW_RETURN_NOT_OK(DeserializeList(data->field(0), start_idx, stop_idx, base, &keys));
|
||||
ARROW_RETURN_NOT_OK(DeserializeList(data->field(1), start_idx, stop_idx, base, &vals));
|
||||
for (size_t i = start_idx; i < stop_idx; ++i) {
|
||||
PyDict_SetItem(result, PyList_GetItem(keys, i - start_idx), PyList_GetItem(vals, i - start_idx));
|
||||
}
|
||||
|
||||
@@ -13,9 +13,9 @@ namespace numbuf {
|
||||
|
||||
arrow::Status SerializeSequences(std::vector<PyObject*> sequences, std::shared_ptr<arrow::Array>* out);
|
||||
arrow::Status SerializeDict(std::vector<PyObject*> dicts, std::shared_ptr<arrow::Array>* out);
|
||||
arrow::Status DeserializeList(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
|
||||
arrow::Status DeserializeTuple(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
|
||||
arrow::Status DeserializeDict(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
|
||||
arrow::Status DeserializeList(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
|
||||
arrow::Status DeserializeTuple(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
|
||||
arrow::Status DeserializeDict(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
|
||||
|
||||
arrow::Status python_error_to_status();
|
||||
|
||||
|
||||
@@ -126,11 +126,12 @@ static PyObject* read_from_buffer(PyObject* self, PyObject* args) {
|
||||
/* Documented in doc/numbuf.rst in ray-core */
|
||||
static PyObject* deserialize_list(PyObject* self, PyObject* args) {
|
||||
std::shared_ptr<RowBatch>* data;
|
||||
if (!PyArg_ParseTuple(args, "O&", &PyObjectToArrow, &data)) {
|
||||
PyObject* base = Py_None;
|
||||
if (!PyArg_ParseTuple(args, "O&|O", &PyObjectToArrow, &data, &base)) {
|
||||
return NULL;
|
||||
}
|
||||
PyObject* result;
|
||||
ARROW_CHECK_OK(DeserializeList((*data)->column(0), 0, (*data)->num_rows(), &result));
|
||||
ARROW_CHECK_OK(DeserializeList((*data)->column(0), 0, (*data)->num_rows(), base, &result));
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user