mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 20:18:33 +08:00
Limit recursion depth for serializing objects to avoid infinite loops. (#17)
This commit is contained in:
committed by
Philipp Moritz
parent
0527736490
commit
7055c6f793
@@ -6,6 +6,8 @@
|
||||
|
||||
using namespace arrow;
|
||||
|
||||
// Maximum nesting depth accepted by SerializeSequences and SerializeDict:
// each of them returns an error Status (instead of recursing further) once
// its recursion_depth argument reaches this value, so that an object which
// contains itself cannot cause unbounded recursion.
int32_t MAX_RECURSION_DEPTH = 100;
|
||||
|
||||
extern PyObject* numbuf_serialize_callback;
|
||||
extern PyObject* numbuf_deserialize_callback;
|
||||
|
||||
@@ -127,8 +129,11 @@ Status append(PyObject* elem, SequenceBuilder& builder,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SerializeSequences(std::vector<PyObject*> sequences, std::shared_ptr<Array>* out) {
|
||||
Status SerializeSequences(std::vector<PyObject*> sequences, int32_t recursion_depth, std::shared_ptr<Array>* out) {
|
||||
DCHECK(out);
|
||||
if (recursion_depth >= MAX_RECURSION_DEPTH) {
|
||||
return Status::NotImplemented("This object exceeds the maximum recursion depth. It may contain itself recursively.");
|
||||
}
|
||||
SequenceBuilder builder(nullptr);
|
||||
std::vector<PyObject*> sublists, subtuples, subdicts;
|
||||
for (const auto& sequence : sequences) {
|
||||
@@ -147,15 +152,15 @@ Status SerializeSequences(std::vector<PyObject*> sequences, std::shared_ptr<Arra
|
||||
}
|
||||
std::shared_ptr<Array> list;
|
||||
if (sublists.size() > 0) {
|
||||
RETURN_NOT_OK(SerializeSequences(sublists, &list));
|
||||
RETURN_NOT_OK(SerializeSequences(sublists, recursion_depth + 1, &list));
|
||||
}
|
||||
std::shared_ptr<Array> tuple;
|
||||
if (subtuples.size() > 0) {
|
||||
RETURN_NOT_OK(SerializeSequences(subtuples, &tuple));
|
||||
RETURN_NOT_OK(SerializeSequences(subtuples, recursion_depth + 1, &tuple));
|
||||
}
|
||||
std::shared_ptr<Array> dict;
|
||||
if (subdicts.size() > 0) {
|
||||
RETURN_NOT_OK(SerializeDict(subdicts, &dict));
|
||||
RETURN_NOT_OK(SerializeDict(subdicts, recursion_depth + 1, &dict));
|
||||
}
|
||||
*out = builder.Finish(list, tuple, dict);
|
||||
return Status::OK();
|
||||
@@ -191,8 +196,11 @@ Status DeserializeTuple(std::shared_ptr<Array> array, int32_t start_idx, int32_t
|
||||
DESERIALIZE_SEQUENCE(PyTuple_New, PyTuple_SetItem)
|
||||
}
|
||||
|
||||
Status SerializeDict(std::vector<PyObject*> dicts, std::shared_ptr<Array>* out) {
|
||||
Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth, std::shared_ptr<Array>* out) {
|
||||
DictBuilder result;
|
||||
if (recursion_depth >= MAX_RECURSION_DEPTH) {
|
||||
return Status::NotImplemented("This object exceeds the maximum recursion depth. It may contain itself recursively.");
|
||||
}
|
||||
std::vector<PyObject*> key_tuples, val_lists, val_tuples, val_dicts, dummy;
|
||||
for (const auto& dict : dicts) {
|
||||
PyObject *key, *value;
|
||||
@@ -205,19 +213,19 @@ Status SerializeDict(std::vector<PyObject*> dicts, std::shared_ptr<Array>* out)
|
||||
}
|
||||
std::shared_ptr<Array> key_tuples_arr;
|
||||
if (key_tuples.size() > 0) {
|
||||
RETURN_NOT_OK(SerializeSequences(key_tuples, &key_tuples_arr));
|
||||
RETURN_NOT_OK(SerializeSequences(key_tuples, recursion_depth + 1, &key_tuples_arr));
|
||||
}
|
||||
std::shared_ptr<Array> val_list_arr;
|
||||
if (val_lists.size() > 0) {
|
||||
RETURN_NOT_OK(SerializeSequences(val_lists, &val_list_arr));
|
||||
RETURN_NOT_OK(SerializeSequences(val_lists, recursion_depth + 1, &val_list_arr));
|
||||
}
|
||||
std::shared_ptr<Array> val_tuples_arr;
|
||||
if (val_tuples.size() > 0) {
|
||||
RETURN_NOT_OK(SerializeSequences(val_tuples, &val_tuples_arr));
|
||||
RETURN_NOT_OK(SerializeSequences(val_tuples, recursion_depth + 1, &val_tuples_arr));
|
||||
}
|
||||
std::shared_ptr<Array> val_dict_arr;
|
||||
if (val_dicts.size() > 0) {
|
||||
RETURN_NOT_OK(SerializeDict(val_dicts, &val_dict_arr));
|
||||
RETURN_NOT_OK(SerializeDict(val_dicts, recursion_depth + 1, &val_dict_arr));
|
||||
}
|
||||
*out = result.Finish(key_tuples_arr, val_list_arr, val_tuples_arr, val_dict_arr);
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@
|
||||
|
||||
namespace numbuf {
|
||||
|
||||
arrow::Status SerializeSequences(std::vector<PyObject*> sequences, std::shared_ptr<arrow::Array>* out);
|
||||
arrow::Status SerializeDict(std::vector<PyObject*> dicts, std::shared_ptr<arrow::Array>* out);
|
||||
// Serializes `sequences` and stores the resulting array in `*out`.
// `recursion_depth` is the nesting level of this call: the top-level caller
// passes 0 and every nested SerializeSequences/SerializeDict call passes
// recursion_depth + 1. Once it reaches MAX_RECURSION_DEPTH the function
// returns a NotImplemented status rather than descending further.
arrow::Status SerializeSequences(std::vector<PyObject*> sequences, int32_t recursion_depth, std::shared_ptr<arrow::Array>* out);
|
||||
// Serializes `dicts` and stores the resulting array in `*out`.
// `recursion_depth` is the nesting level of this call; nested
// SerializeSequences/SerializeDict calls receive recursion_depth + 1. Once it
// reaches MAX_RECURSION_DEPTH the function returns a NotImplemented status
// rather than descending further.
arrow::Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth, std::shared_ptr<arrow::Array>* out);
|
||||
arrow::Status DeserializeList(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
|
||||
arrow::Status DeserializeTuple(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
|
||||
arrow::Status DeserializeDict(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
|
||||
|
||||
@@ -51,7 +51,8 @@ static PyObject* serialize_list(PyObject* self, PyObject* args) {
|
||||
}
|
||||
std::shared_ptr<Array> array;
|
||||
if (PyList_Check(value)) {
|
||||
Status s = SerializeSequences(std::vector<PyObject*>({value}), &array);
|
||||
int32_t recursion_depth = 0;
|
||||
Status s = SerializeSequences(std::vector<PyObject*>({value}), recursion_depth, &array);
|
||||
if (!s.ok()) {
|
||||
// If this condition is true, there was an error in the callback that
|
||||
// needs to be passed through
|
||||
|
||||
Reference in New Issue
Block a user