From 58e8bbcb341ae0214ac91228417e3735be21a0d2 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 30 Nov 2016 23:21:53 -0800 Subject: [PATCH] Fix bug in serializing arguments of tasks that are more complex objects (#72) * Give more informative error message when we do not know how to serialize a class. * Check that passing arguments to remote functions and getting them does not change their values. * fix serialization bug * fix tests for common module * Formatting. * Bug fix in init_pickle_module signature. * Use pickle with HIGHEST_PROTOCOL. --- lib/python/ray/serialization.py | 2 +- src/common/lib/python/common_extension.c | 35 ++++++++++++++++++++---- src/common/lib/python/common_extension.h | 7 +++++ src/common/lib/python/common_module.c | 2 ++ src/photon/photon_extension.c | 2 ++ test/runtest.py | 28 +++++++++++++++++-- 6 files changed, 66 insertions(+), 10 deletions(-) diff --git a/lib/python/ray/serialization.py b/lib/python/ray/serialization.py index dd6106d95..64d708c98 100644 --- a/lib/python/ray/serialization.py +++ b/lib/python/ray/serialization.py @@ -91,7 +91,7 @@ def serialize(obj): """ class_id = class_identifier(type(obj)) if class_id not in whitelisted_classes: - raise Exception("Ray does not know how to serialize the object {}. To fix this, call 'ray.register_class' on the class of the object.".format(obj)) + raise Exception("Ray does not know how to serialize objects of type {}. 
To fix this, call 'ray.register_class' with this class.".format(type(obj))) if class_id in classes_to_pickle: serialized_obj = {"data": pickling.dumps(obj)} elif class_id in custom_serializers.keys(): diff --git a/src/common/lib/python/common_extension.c b/src/common/lib/python/common_extension.c index f1058fdbc..5aebc6283 100644 --- a/src/common/lib/python/common_extension.c +++ b/src/common/lib/python/common_extension.c @@ -1,6 +1,7 @@ #include <Python.h> #include "node.h" +#include "common.h" #include "common_extension.h" #include "task.h" #include "utarray.h" @@ -8,7 +9,21 @@ PyObject *CommonError; -#define MARSHAL_VERSION 2 +/* Initialize pickle module. */ + +PyObject *pickle_module = NULL; +PyObject *pickle_loads = NULL; +PyObject *pickle_dumps = NULL; +PyObject *pickle_protocol = NULL; + +void init_pickle_module(void) { + /* For Python 3 this needs to be "_pickle" instead of "cPickle". */ + pickle_module = PyImport_ImportModuleNoBlock("cPickle"); + CHECK(pickle_module != NULL); + pickle_loads = PyString_FromString("loads"); + pickle_dumps = PyString_FromString("dumps"); + pickle_protocol = PyObject_GetAttrString(pickle_module, "HIGHEST_PROTOCOL"); +} /* Define the PyObjectID class. 
*/ @@ -194,7 +209,10 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { for (size_t i = 0; i < size; ++i) { PyObject *arg = PyList_GetItem(arguments, i); if (!PyObject_IsInstance(arg, (PyObject *) &PyObjectIDType)) { - PyObject *data = PyMarshal_WriteObjectToString(arg, MARSHAL_VERSION); + CHECK(pickle_module != NULL); + CHECK(pickle_dumps != NULL); + PyObject *data = PyObject_CallMethodObjArgs(pickle_module, pickle_dumps, + arg, pickle_protocol, NULL); value_data_bytes += PyString_Size(data); utarray_push_back(val_repr_ptrs, &data); } @@ -248,10 +266,15 @@ static PyObject *PyTask_arguments(PyObject *self) { object_id object_id = task_arg_id(task, i); PyList_SetItem(arg_list, i, PyObjectID_make(object_id)); } else { - PyObject *s = - PyMarshal_ReadObjectFromString((char *) task_arg_val(task, i), - (Py_ssize_t) task_arg_length(task, i)); - PyList_SetItem(arg_list, i, s); + CHECK(pickle_module != NULL); + CHECK(pickle_loads != NULL); + PyObject *str = + PyString_FromStringAndSize((char *) task_arg_val(task, i), + (Py_ssize_t) task_arg_length(task, i)); + PyObject *val = + PyObject_CallMethodObjArgs(pickle_module, pickle_loads, str, NULL); + Py_XDECREF(str); + PyList_SetItem(arg_list, i, val); } } return arg_list; diff --git a/src/common/lib/python/common_extension.h b/src/common/lib/python/common_extension.h index 1fce38e42..a34fca5ff 100644 --- a/src/common/lib/python/common_extension.h +++ b/src/common/lib/python/common_extension.h @@ -26,6 +26,13 @@ extern PyTypeObject PyObjectIDType; extern PyTypeObject PyTaskType; +/* Python module for pickling. 
*/ +extern PyObject *pickle_module; +extern PyObject *pickle_dumps; +extern PyObject *pickle_loads; + +void init_pickle_module(void); + int PyObjectToUniqueID(PyObject *object, object_id *objectid); PyObject *PyObjectID_make(object_id object_id); diff --git a/src/common/lib/python/common_module.c b/src/common/lib/python/common_module.c index a32f508a6..6ae8b2eba 100644 --- a/src/common/lib/python/common_module.c +++ b/src/common/lib/python/common_module.c @@ -24,6 +24,8 @@ PyMODINIT_FUNC initcommon(void) { m = Py_InitModule3("common", common_methods, "A module for common types. This is used for testing."); + init_pickle_module(); + Py_INCREF(&PyTaskType); PyModule_AddObject(m, "Task", (PyObject *) &PyTaskType); diff --git a/src/photon/photon_extension.c b/src/photon/photon_extension.c index e7ddded15..07858ceed 100644 --- a/src/photon/photon_extension.c +++ b/src/photon/photon_extension.c @@ -125,6 +125,8 @@ PyMODINIT_FUNC initlibphoton(void) { m = Py_InitModule3("libphoton", photon_methods, "A module for the local scheduler."); + init_pickle_module(); + Py_INCREF(&PyTaskType); PyModule_AddObject(m, "Task", (PyObject *) &PyTaskType); diff --git a/test/runtest.py b/test/runtest.py index be97449c1..ad5ce7f51 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -50,11 +50,11 @@ PRIMITIVE_OBJECTS = [0, 0.0, 0.9, 0L, 1L << 62, "a", string.printable, "\u262F", np.array(["hi", 3], dtype=object), np.array([["hi", u"hi"], [1.3, 1L]])] -COMPLEX_OBJECTS = [#[[[[[[[[[[[[]]]]]]]]]]]], +COMPLEX_OBJECTS = [[[[[[[[[[[[[]]]]]]]]]]]], {"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)}, #{(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {}}}}}}}}}}}}}, - #((((((((((),),),),),),),),),), - #{"a": {"b": {"c": {"d": {}}}}} + ((((((((((),),),),),),),),),), + {"a": {"b": {"c": {"d": {}}}}} ] class Foo(object): @@ -144,6 +144,28 @@ class SerializationTest(unittest.TestCase): ray.worker.cleanup() + def testPassingArgumentsByValue(self): + 
ray.init(start_ray_local=True, num_workers=1) + + @ray.remote + def f(x): + return x + + ray.register_class(Exception) + ray.register_class(CustomError) + ray.register_class(Point) + ray.register_class(Foo) + ray.register_class(Bar) + ray.register_class(Baz) + ray.register_class(NamedTupleExample) + + # Check that we can pass arguments by value to remote functions and that + # they are uncorrupted. + for obj in RAY_TEST_OBJECTS: + assert_equal(obj, ray.get(f.remote(obj))) + + ray.worker.cleanup() + class WorkerTest(unittest.TestCase): def testPutGet(self):