From e2e9e4ce6fa803502e47b2969b42f0c51a239efd Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 May 2017 03:09:13 -0500 Subject: [PATCH] Fix segmentation fault when calling ray.put on a dictionary with object keys (#548) * fix segfault when serializing dict key * fix style * fix test * Fix linting. --- .gitignore | 2 +- src/numbuf/cpp/src/numbuf/dict.cc | 7 ++++--- src/numbuf/cpp/src/numbuf/dict.h | 1 + src/numbuf/python/src/pynumbuf/adapters/python.cc | 13 ++++++++++--- test/runtest.py | 13 ++++++++++--- 5 files changed, 26 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index bfc38ada3..136e3d3e0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ # The build output should clearly not be checked in /python/ray/core /src/common/thirdparty/redis -/numbuf/thirdparty/arrow +/src/numbuf/thirdparty/arrow # Files generated by flatc should be ignored /src/common/format/*.py diff --git a/src/numbuf/cpp/src/numbuf/dict.cc b/src/numbuf/cpp/src/numbuf/dict.cc index 14daefa47..832e4bc86 100644 --- a/src/numbuf/cpp/src/numbuf/dict.cc +++ b/src/numbuf/cpp/src/numbuf/dict.cc @@ -5,12 +5,13 @@ using namespace arrow; namespace numbuf { Status DictBuilder::Finish(std::shared_ptr key_tuple_data, - std::shared_ptr val_list_data, std::shared_ptr val_tuple_data, - std::shared_ptr val_dict_data, std::shared_ptr* out) { + std::shared_ptr key_dict_data, std::shared_ptr val_list_data, + std::shared_ptr val_tuple_data, std::shared_ptr val_dict_data, + std::shared_ptr* out) { // lists and dicts can't be keys of dicts in Python, that is why for // the keys we do not need to collect sublists std::shared_ptr keys, vals; - RETURN_NOT_OK(keys_.Finish(nullptr, key_tuple_data, nullptr, &keys)); + RETURN_NOT_OK(keys_.Finish(nullptr, key_tuple_data, key_dict_data, &keys)); RETURN_NOT_OK(vals_.Finish(val_list_data, val_tuple_data, val_dict_data, &vals)); auto keys_field = std::make_shared("keys", keys->type()); auto vals_field = std::make_shared("vals", vals->type()); diff --git a/src/numbuf/cpp/src/numbuf/dict.h b/src/numbuf/cpp/src/numbuf/dict.h index c8f5925a7..708d36747 100644 --- a/src/numbuf/cpp/src/numbuf/dict.h +++ b/src/numbuf/cpp/src/numbuf/dict.h @@ -33,6 +33,7 @@ class DictBuilder { value list of the dictionary */ arrow::Status Finish(std::shared_ptr key_tuple_data, + std::shared_ptr key_dict_data, std::shared_ptr val_list_data, std::shared_ptr val_tuple_data, std::shared_ptr val_dict_data, std::shared_ptr* out); diff --git a/src/numbuf/python/src/pynumbuf/adapters/python.cc b/src/numbuf/python/src/pynumbuf/adapters/python.cc index 79289c284..6d2ae17e3 100644 --- a/src/numbuf/python/src/pynumbuf/adapters/python.cc +++ b/src/numbuf/python/src/pynumbuf/adapters/python.cc @@ -229,12 +229,13 @@ Status SerializeDict(std::vector dicts, int32_t recursion_depth, "This object exceeds the maximum recursion depth. It may contain itself " "recursively."); } - std::vector key_tuples, val_lists, val_tuples, val_dicts, dummy; + std::vector key_tuples, key_dicts, val_lists, val_tuples, val_dicts, dummy; for (const auto& dict : dicts) { PyObject *key, *value; Py_ssize_t pos = 0; while (PyDict_Next(dict, &pos, &key, &value)) { - RETURN_NOT_OK(append(key, result.keys(), dummy, key_tuples, dummy, tensors_out)); + RETURN_NOT_OK( + append(key, result.keys(), dummy, key_tuples, key_dicts, tensors_out)); DCHECK(dummy.size() == 0); RETURN_NOT_OK( append(value, result.vals(), val_lists, val_tuples, val_dicts, tensors_out)); @@ -245,6 +246,11 @@ Status SerializeDict(std::vector dicts, int32_t recursion_depth, RETURN_NOT_OK(SerializeSequences( key_tuples, recursion_depth + 1, &key_tuples_arr, tensors_out)); } + std::shared_ptr key_dicts_arr; + if (key_dicts.size() > 0) { + RETURN_NOT_OK( + SerializeDict(key_dicts, recursion_depth + 1, &key_dicts_arr, tensors_out)); + } std::shared_ptr val_list_arr; if (val_lists.size() > 0) { RETURN_NOT_OK( @@ -260,7 +266,8 @@ Status SerializeDict(std::vector dicts, int32_t recursion_depth, RETURN_NOT_OK( SerializeDict(val_dicts, recursion_depth + 1, &val_dict_arr, tensors_out)); } - result.Finish(key_tuples_arr, val_list_arr, val_tuples_arr, val_dict_arr, out); + result.Finish( + key_tuples_arr, key_dicts_arr, val_list_arr, val_tuples_arr, val_dict_arr, out); // This block is used to decrement the reference counts of the results // returned by the serialization callback, which is called in SerializeArray diff --git a/test/runtest.py b/test/runtest.py index fd760c57a..538633fec 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -89,8 +89,14 @@ COMPLEX_OBJECTS = [ class Foo(object): - def __init__(self): - pass + def __init__(self, value=0): + self.value = value + + def __hash__(self): + return hash(self.value) + + def __eq__(self, other): + return other.value == self.value class Bar(object): @@ -139,7 +145,8 @@ TUPLE_OBJECTS = [(obj,) for obj in BASE_OBJECTS] DICT_OBJECTS = ([{obj: obj} for obj in PRIMITIVE_OBJECTS if (obj.__hash__ is not None and type(obj).__module__ != "numpy")] + - [{0: obj} for obj in BASE_OBJECTS]) + [{0: obj} for obj in BASE_OBJECTS] + + [{Foo(123): Foo(456)}]) RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS