Limit recursion depth for serializing objects to avoid infinite recursion. (#17)

This commit is contained in:
Robert Nishihara
2016-09-19 14:54:56 -07:00
committed by Philipp Moritz
parent 0527736490
commit 7055c6f793
3 changed files with 21 additions and 12 deletions
+17 -9
View File
@@ -6,6 +6,8 @@
using namespace arrow;
int32_t MAX_RECURSION_DEPTH = 100;
extern PyObject* numbuf_serialize_callback;
extern PyObject* numbuf_deserialize_callback;
@@ -127,8 +129,11 @@ Status append(PyObject* elem, SequenceBuilder& builder,
return Status::OK();
}
Status SerializeSequences(std::vector<PyObject*> sequences, std::shared_ptr<Array>* out) {
Status SerializeSequences(std::vector<PyObject*> sequences, int32_t recursion_depth, std::shared_ptr<Array>* out) {
DCHECK(out);
if (recursion_depth >= MAX_RECURSION_DEPTH) {
return Status::NotImplemented("This object exceeds the maximum recursion depth. It may contain itself recursively.");
}
SequenceBuilder builder(nullptr);
std::vector<PyObject*> sublists, subtuples, subdicts;
for (const auto& sequence : sequences) {
@@ -147,15 +152,15 @@ Status SerializeSequences(std::vector<PyObject*> sequences, std::shared_ptr<Arra
}
std::shared_ptr<Array> list;
if (sublists.size() > 0) {
RETURN_NOT_OK(SerializeSequences(sublists, &list));
RETURN_NOT_OK(SerializeSequences(sublists, recursion_depth + 1, &list));
}
std::shared_ptr<Array> tuple;
if (subtuples.size() > 0) {
RETURN_NOT_OK(SerializeSequences(subtuples, &tuple));
RETURN_NOT_OK(SerializeSequences(subtuples, recursion_depth + 1, &tuple));
}
std::shared_ptr<Array> dict;
if (subdicts.size() > 0) {
RETURN_NOT_OK(SerializeDict(subdicts, &dict));
RETURN_NOT_OK(SerializeDict(subdicts, recursion_depth + 1, &dict));
}
*out = builder.Finish(list, tuple, dict);
return Status::OK();
@@ -191,8 +196,11 @@ Status DeserializeTuple(std::shared_ptr<Array> array, int32_t start_idx, int32_t
DESERIALIZE_SEQUENCE(PyTuple_New, PyTuple_SetItem)
}
Status SerializeDict(std::vector<PyObject*> dicts, std::shared_ptr<Array>* out) {
Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth, std::shared_ptr<Array>* out) {
DictBuilder result;
if (recursion_depth >= MAX_RECURSION_DEPTH) {
return Status::NotImplemented("This object exceeds the maximum recursion depth. It may contain itself recursively.");
}
std::vector<PyObject*> key_tuples, val_lists, val_tuples, val_dicts, dummy;
for (const auto& dict : dicts) {
PyObject *key, *value;
@@ -205,19 +213,19 @@ Status SerializeDict(std::vector<PyObject*> dicts, std::shared_ptr<Array>* out)
}
std::shared_ptr<Array> key_tuples_arr;
if (key_tuples.size() > 0) {
RETURN_NOT_OK(SerializeSequences(key_tuples, &key_tuples_arr));
RETURN_NOT_OK(SerializeSequences(key_tuples, recursion_depth + 1, &key_tuples_arr));
}
std::shared_ptr<Array> val_list_arr;
if (val_lists.size() > 0) {
RETURN_NOT_OK(SerializeSequences(val_lists, &val_list_arr));
RETURN_NOT_OK(SerializeSequences(val_lists, recursion_depth + 1, &val_list_arr));
}
std::shared_ptr<Array> val_tuples_arr;
if (val_tuples.size() > 0) {
RETURN_NOT_OK(SerializeSequences(val_tuples, &val_tuples_arr));
RETURN_NOT_OK(SerializeSequences(val_tuples, recursion_depth + 1, &val_tuples_arr));
}
std::shared_ptr<Array> val_dict_arr;
if (val_dicts.size() > 0) {
RETURN_NOT_OK(SerializeDict(val_dicts, &val_dict_arr));
RETURN_NOT_OK(SerializeDict(val_dicts, recursion_depth + 1, &val_dict_arr));
}
*out = result.Finish(key_tuples_arr, val_list_arr, val_tuples_arr, val_dict_arr);
+2 -2
View File
@@ -11,8 +11,8 @@
namespace numbuf {
arrow::Status SerializeSequences(std::vector<PyObject*> sequences, std::shared_ptr<arrow::Array>* out);
arrow::Status SerializeDict(std::vector<PyObject*> dicts, std::shared_ptr<arrow::Array>* out);
arrow::Status SerializeSequences(std::vector<PyObject*> sequences, int32_t recursion_depth, std::shared_ptr<arrow::Array>* out);
arrow::Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth, std::shared_ptr<arrow::Array>* out);
arrow::Status DeserializeList(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
arrow::Status DeserializeTuple(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
arrow::Status DeserializeDict(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
+2 -1
View File
@@ -51,7 +51,8 @@ static PyObject* serialize_list(PyObject* self, PyObject* args) {
}
std::shared_ptr<Array> array;
if (PyList_Check(value)) {
Status s = SerializeSequences(std::vector<PyObject*>({value}), &array);
int32_t recursion_depth = 0;
Status s = SerializeSequences(std::vector<PyObject*>({value}), recursion_depth, &array);
if (!s.ok()) {
// If this condition is true, there was an error in the callback that
// needs to be passed through