mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 02:45:10 +08:00
Fix segmentation fault when calling ray.put on a dictionary with object keys (#548)
* Fix segfault when serializing dict key
* Fix style
* Fix test
* Fix linting
This commit is contained in:
committed by
Philipp Moritz
parent
3c5375345f
commit
e2e9e4ce6f
+1
-1
@@ -1,7 +1,7 @@
|
||||
# The build output should clearly not be checked in
|
||||
/python/ray/core
|
||||
/src/common/thirdparty/redis
|
||||
/numbuf/thirdparty/arrow
|
||||
/src/numbuf/thirdparty/arrow
|
||||
|
||||
# Files generated by flatc should be ignored
|
||||
/src/common/format/*.py
|
||||
|
||||
@@ -5,12 +5,13 @@ using namespace arrow;
|
||||
namespace numbuf {
|
||||
|
||||
Status DictBuilder::Finish(std::shared_ptr<Array> key_tuple_data,
|
||||
std::shared_ptr<Array> val_list_data, std::shared_ptr<Array> val_tuple_data,
|
||||
std::shared_ptr<Array> val_dict_data, std::shared_ptr<arrow::Array>* out) {
|
||||
std::shared_ptr<Array> key_dict_data, std::shared_ptr<Array> val_list_data,
|
||||
std::shared_ptr<Array> val_tuple_data, std::shared_ptr<Array> val_dict_data,
|
||||
std::shared_ptr<arrow::Array>* out) {
|
||||
// lists and dicts can't be keys of dicts in Python, that is why for
|
||||
// the keys we do not need to collect sublists
|
||||
std::shared_ptr<Array> keys, vals;
|
||||
RETURN_NOT_OK(keys_.Finish(nullptr, key_tuple_data, nullptr, &keys));
|
||||
RETURN_NOT_OK(keys_.Finish(nullptr, key_tuple_data, key_dict_data, &keys));
|
||||
RETURN_NOT_OK(vals_.Finish(val_list_data, val_tuple_data, val_dict_data, &vals));
|
||||
auto keys_field = std::make_shared<Field>("keys", keys->type());
|
||||
auto vals_field = std::make_shared<Field>("vals", vals->type());
|
||||
|
||||
@@ -33,6 +33,7 @@ class DictBuilder {
|
||||
value list of the dictionary
|
||||
*/
|
||||
arrow::Status Finish(std::shared_ptr<arrow::Array> key_tuple_data,
|
||||
std::shared_ptr<arrow::Array> key_dict_data,
|
||||
std::shared_ptr<arrow::Array> val_list_data,
|
||||
std::shared_ptr<arrow::Array> val_tuple_data,
|
||||
std::shared_ptr<arrow::Array> val_dict_data, std::shared_ptr<arrow::Array>* out);
|
||||
|
||||
@@ -229,12 +229,13 @@ Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth,
|
||||
"This object exceeds the maximum recursion depth. It may contain itself "
|
||||
"recursively.");
|
||||
}
|
||||
std::vector<PyObject *> key_tuples, val_lists, val_tuples, val_dicts, dummy;
|
||||
std::vector<PyObject *> key_tuples, key_dicts, val_lists, val_tuples, val_dicts, dummy;
|
||||
for (const auto& dict : dicts) {
|
||||
PyObject *key, *value;
|
||||
Py_ssize_t pos = 0;
|
||||
while (PyDict_Next(dict, &pos, &key, &value)) {
|
||||
RETURN_NOT_OK(append(key, result.keys(), dummy, key_tuples, dummy, tensors_out));
|
||||
RETURN_NOT_OK(
|
||||
append(key, result.keys(), dummy, key_tuples, key_dicts, tensors_out));
|
||||
DCHECK(dummy.size() == 0);
|
||||
RETURN_NOT_OK(
|
||||
append(value, result.vals(), val_lists, val_tuples, val_dicts, tensors_out));
|
||||
@@ -245,6 +246,11 @@ Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth,
|
||||
RETURN_NOT_OK(SerializeSequences(
|
||||
key_tuples, recursion_depth + 1, &key_tuples_arr, tensors_out));
|
||||
}
|
||||
std::shared_ptr<Array> key_dicts_arr;
|
||||
if (key_dicts.size() > 0) {
|
||||
RETURN_NOT_OK(
|
||||
SerializeDict(key_dicts, recursion_depth + 1, &key_dicts_arr, tensors_out));
|
||||
}
|
||||
std::shared_ptr<Array> val_list_arr;
|
||||
if (val_lists.size() > 0) {
|
||||
RETURN_NOT_OK(
|
||||
@@ -260,7 +266,8 @@ Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth,
|
||||
RETURN_NOT_OK(
|
||||
SerializeDict(val_dicts, recursion_depth + 1, &val_dict_arr, tensors_out));
|
||||
}
|
||||
result.Finish(key_tuples_arr, val_list_arr, val_tuples_arr, val_dict_arr, out);
|
||||
result.Finish(
|
||||
key_tuples_arr, key_dicts_arr, val_list_arr, val_tuples_arr, val_dict_arr, out);
|
||||
|
||||
// This block is used to decrement the reference counts of the results
|
||||
// returned by the serialization callback, which is called in SerializeArray
|
||||
|
||||
+10
-3
@@ -89,8 +89,14 @@ COMPLEX_OBJECTS = [
|
||||
|
||||
|
||||
class Foo(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, value=0):
|
||||
self.value = value
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.value)
|
||||
|
||||
def __eq__(self, other):
|
||||
return other.value == self.value
|
||||
|
||||
|
||||
class Bar(object):
|
||||
@@ -139,7 +145,8 @@ TUPLE_OBJECTS = [(obj,) for obj in BASE_OBJECTS]
|
||||
DICT_OBJECTS = ([{obj: obj} for obj in PRIMITIVE_OBJECTS
|
||||
if (obj.__hash__ is not None and
|
||||
type(obj).__module__ != "numpy")] +
|
||||
[{0: obj} for obj in BASE_OBJECTS])
|
||||
[{0: obj} for obj in BASE_OBJECTS] +
|
||||
[{Foo(123): Foo(456)}])
|
||||
|
||||
RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS
|
||||
|
||||
|
||||
Reference in New Issue
Block a user