diff --git a/lib/python/common_extension.c b/lib/python/common_extension.c index efc4b9e01..14f3292d9 100644 --- a/lib/python/common_extension.c +++ b/lib/python/common_extension.c @@ -51,9 +51,17 @@ static PyObject *PyObjectID_id(PyObject *self) { UNIQUE_ID_SIZE); } +static PyObject *PyObjectID___reduce__(PyObjectID *self) { + PyErr_SetString(CommonError, "ObjectID objects cannot be serialized."); + return NULL; +} + static PyMethodDef PyObjectID_methods[] = { {"id", (PyCFunction) PyObjectID_id, METH_NOARGS, "Return the hash associated with this ObjectID"}, + {"__reduce__", (PyCFunction) PyObjectID___reduce__, METH_NOARGS, + "Say how to pickle this ObjectID. This raises an exception to prevent" + "object IDs from being serialized."}, {NULL} /* Sentinel */ }; diff --git a/test/test.py b/test/test.py index a40d2045c..359e8c030 100644 --- a/test/test.py +++ b/test/test.py @@ -1,5 +1,6 @@ from __future__ import print_function +import pickle import unittest import common @@ -52,6 +53,23 @@ class TestObjectID(unittest.TestCase): def test_create_object_id(self): object_id = common.ObjectID(20 * "a") + def test_cannot_pickle_object_ids(self): + object_ids = [common.ObjectID(20 * chr(i)) for i in range(256)] + def f(): + return object_ids + def g(val=object_ids): + return 1 + def h(): + x = object_ids[0] + return 1 + # Make sure that object IDs cannot be pickled (including functions that + # close over object IDs). + self.assertRaises(Exception, lambda : pickling.dumps(object_ids[0])) + self.assertRaises(Exception, lambda : pickling.dumps(object_ids)) + self.assertRaises(Exception, lambda : pickling.dumps(f)) + self.assertRaises(Exception, lambda : pickling.dumps(g)) + self.assertRaises(Exception, lambda : pickling.dumps(h)) + class TestTask(unittest.TestCase): def test_create_task(self):