diff --git a/python/src/pynumbuf/adapters/python.cc b/python/src/pynumbuf/adapters/python.cc index deafd9d44..87cfef8d7 100644 --- a/python/src/pynumbuf/adapters/python.cc +++ b/python/src/pynumbuf/adapters/python.cc @@ -6,6 +6,8 @@ using namespace arrow; +int32_t MAX_RECURSION_DEPTH = 100; + extern PyObject* numbuf_serialize_callback; extern PyObject* numbuf_deserialize_callback; @@ -127,8 +129,11 @@ Status append(PyObject* elem, SequenceBuilder& builder, return Status::OK(); } -Status SerializeSequences(std::vector sequences, std::shared_ptr* out) { +Status SerializeSequences(std::vector sequences, int32_t recursion_depth, std::shared_ptr* out) { DCHECK(out); + if (recursion_depth >= MAX_RECURSION_DEPTH) { + return Status::NotImplemented("This object exceeds the maximum recursion depth. It may contain itself recursively."); + } SequenceBuilder builder(nullptr); std::vector sublists, subtuples, subdicts; for (const auto& sequence : sequences) { @@ -147,15 +152,15 @@ Status SerializeSequences(std::vector sequences, std::shared_ptr list; if (sublists.size() > 0) { - RETURN_NOT_OK(SerializeSequences(sublists, &list)); + RETURN_NOT_OK(SerializeSequences(sublists, recursion_depth + 1, &list)); } std::shared_ptr tuple; if (subtuples.size() > 0) { - RETURN_NOT_OK(SerializeSequences(subtuples, &tuple)); + RETURN_NOT_OK(SerializeSequences(subtuples, recursion_depth + 1, &tuple)); } std::shared_ptr dict; if (subdicts.size() > 0) { - RETURN_NOT_OK(SerializeDict(subdicts, &dict)); + RETURN_NOT_OK(SerializeDict(subdicts, recursion_depth + 1, &dict)); } *out = builder.Finish(list, tuple, dict); return Status::OK(); @@ -191,8 +196,11 @@ Status DeserializeTuple(std::shared_ptr array, int32_t start_idx, int32_t DESERIALIZE_SEQUENCE(PyTuple_New, PyTuple_SetItem) } -Status SerializeDict(std::vector dicts, std::shared_ptr* out) { +Status SerializeDict(std::vector dicts, int32_t recursion_depth, std::shared_ptr* out) { DictBuilder result; + if (recursion_depth >= MAX_RECURSION_DEPTH) { + return Status::NotImplemented("This object exceeds the maximum recursion depth. It may contain itself recursively."); + } std::vector key_tuples, val_lists, val_tuples, val_dicts, dummy; for (const auto& dict : dicts) { PyObject *key, *value; @@ -205,19 +213,19 @@ Status SerializeDict(std::vector dicts, std::shared_ptr* out) } std::shared_ptr key_tuples_arr; if (key_tuples.size() > 0) { - RETURN_NOT_OK(SerializeSequences(key_tuples, &key_tuples_arr)); + RETURN_NOT_OK(SerializeSequences(key_tuples, recursion_depth + 1, &key_tuples_arr)); } std::shared_ptr val_list_arr; if (val_lists.size() > 0) { - RETURN_NOT_OK(SerializeSequences(val_lists, &val_list_arr)); + RETURN_NOT_OK(SerializeSequences(val_lists, recursion_depth + 1, &val_list_arr)); } std::shared_ptr val_tuples_arr; if (val_tuples.size() > 0) { - RETURN_NOT_OK(SerializeSequences(val_tuples, &val_tuples_arr)); + RETURN_NOT_OK(SerializeSequences(val_tuples, recursion_depth + 1, &val_tuples_arr)); } std::shared_ptr val_dict_arr; if (val_dicts.size() > 0) { - RETURN_NOT_OK(SerializeDict(val_dicts, &val_dict_arr)); + RETURN_NOT_OK(SerializeDict(val_dicts, recursion_depth + 1, &val_dict_arr)); } *out = result.Finish(key_tuples_arr, val_list_arr, val_tuples_arr, val_dict_arr); diff --git a/python/src/pynumbuf/adapters/python.h b/python/src/pynumbuf/adapters/python.h index 1f05e923f..024ebebbf 100644 --- a/python/src/pynumbuf/adapters/python.h +++ b/python/src/pynumbuf/adapters/python.h @@ -11,8 +11,8 @@ namespace numbuf { -arrow::Status SerializeSequences(std::vector sequences, std::shared_ptr* out); -arrow::Status SerializeDict(std::vector dicts, std::shared_ptr* out); +arrow::Status SerializeSequences(std::vector sequences, int32_t recursion_depth, std::shared_ptr* out); +arrow::Status SerializeDict(std::vector dicts, int32_t recursion_depth, std::shared_ptr* out); arrow::Status DeserializeList(std::shared_ptr array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out); arrow::Status DeserializeTuple(std::shared_ptr array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out); arrow::Status DeserializeDict(std::shared_ptr array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out); diff --git a/python/src/pynumbuf/numbuf.cc b/python/src/pynumbuf/numbuf.cc index de8722a48..8951b07be 100644 --- a/python/src/pynumbuf/numbuf.cc +++ b/python/src/pynumbuf/numbuf.cc @@ -51,7 +51,8 @@ static PyObject* serialize_list(PyObject* self, PyObject* args) { } std::shared_ptr array; if (PyList_Check(value)) { - Status s = SerializeSequences(std::vector({value}), &array); + int32_t recursion_depth = 0; + Status s = SerializeSequences(std::vector({value}), recursion_depth, &array); if (!s.ok()) { // If this condition is true, there was an error in the callback that // needs to be passed through