From 327d7ff689c0930da3075c7870d126aa0ea0bd55 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sun, 4 Sep 2016 13:32:55 -0700 Subject: [PATCH] Fix bug to enable calling ray.get multiple times on same ObjectID. (#409) --- lib/python/ray/worker.py | 21 ++++++++++++++++++- test/runtest.py | 45 +++++++++++++++++++++++++++++++++------- 2 files changed, 57 insertions(+), 9 deletions(-) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 95ba06fd3..5ed698c33 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -10,6 +10,7 @@ import colorama import atexit import threading import string +import weakref # Ray modules import config @@ -406,7 +407,16 @@ class Worker(object): metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size) data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:] serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset) - deserialized = libnumbuf.deserialize_list(serialized, ObjectFixture(objectid, segmentid, self.handle)) + # If there is currently no ObjectFixture for this ObjectID, then create a + # new one. The object_fixtures object is a WeakValueDictionary, so entries + # will be discarded when there are no strong references to their values. + # We create object_fixture outside of the assignment because if we created + # it inside the assignment it would immediately go out of scope. + object_fixture = None + if objectid.id not in object_fixtures: + object_fixture = ObjectFixture(objectid, segmentid, self.handle) + object_fixtures[objectid.id] = object_fixture + deserialized = libnumbuf.deserialize_list(serialized, object_fixtures[objectid.id]) # Unwrap the object from the list (it was wrapped put_object) assert len(deserialized) == 1 result = deserialized[0] @@ -476,6 +486,15 @@ made by one task do not affect other tasks. 
logger = logging.getLogger("ray") """Logger: The logging object for the Python worker code.""" +object_fixtures = weakref.WeakValueDictionary() +"""WeakValueDictionary: The mapping from ObjectID to ObjectFixture object. + +This is to ensure that we have only one ObjectFixture per ObjectID. That way, if +we call get on an object twice, we do not unmap the segment before both of the +results go out of scope. It is a WeakValueDictionary instead of a regular +dictionary so that it does not keep the ObjectFixtures in scope forever. +""" + class RayConnectionError(Exception): pass diff --git a/test/runtest.py b/test/runtest.py index 0d8447295..441648442 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -440,14 +440,43 @@ class ReferenceCountingTest(unittest.TestCase): del x self.assertEqual(ray.scheduler_info()["reference_counts"][objectid], 1) - # The following currently segfaults: The second "result = " closes the - # memory segment as soon as the assignment is done (and the first result - # goes out of scope). - # data = np.zeros([10, 20]) - # objectid = ray.put(data) - # result = worker.get(objectid) - # result = worker.get(objectid) - # assert_equal(result, data) + # Getting an object multiple times should not be a problem. And the remote + # object should not be deallocated until both of the results are out of scope. + for val in [np.zeros(10), [np.zeros(10)], (((np.zeros(10)),),), {(): np.zeros(10)}, [1, 2, 3, np.zeros(1)]]: + x = ray.put(val) + objectid = x.id + xval1 = ray.get(x) + xval2 = ray.get(x) + del xval1 + # Make sure we can still access xval2. + xval2 + del xval2 + self.assertEqual(ray.scheduler_info()["reference_counts"][objectid], 1) + xval3 = ray.get(x) + xval4 = ray.get(x) + xval5 = ray.get(x) + del x + del xval4, xval5 + # Make sure we can still access xval3. 
+ xval3 + self.assertEqual(ray.scheduler_info()["reference_counts"][objectid], 1) + del xval3 + self.assertEqual(ray.scheduler_info()["reference_counts"][objectid], -1) + + # Getting an object multiple times and assigning it to the same name should + # work. This was a problem in https://github.com/amplab/ray/issues/159. + for val in [np.zeros(10), [np.zeros(10)], (((np.zeros(10)),),), {(): np.zeros(10)}, [1, 2, 3, np.zeros(1)]]: + x = ray.put(val) + objectid = x.id + xval = ray.get(x) + xval = ray.get(x) + xval = ray.get(x) + xval = ray.get(x) + self.assertEqual(ray.scheduler_info()["reference_counts"][objectid], 1) + del x + self.assertEqual(ray.scheduler_info()["reference_counts"][objectid], 1) + del xval + self.assertEqual(ray.scheduler_info()["reference_counts"][objectid], -1) ray.worker.cleanup()