From 82a7ee575257c44ccc610ffa866702d1a28573ed Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 21 Jun 2016 17:39:48 -0700 Subject: [PATCH] Fix https://github.com/amplab/ray/issues/72 --- lib/python/ray/worker.py | 2 +- test/runtest.py | 74 ++++++++++++++++++++++++++++++++-------- 2 files changed, 60 insertions(+), 16 deletions(-) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index f804b5274..3ecd68986 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -72,7 +72,7 @@ class Worker(object): elif result == None: return None # can't subclass None and don't need to because there is a global None # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now) - # TODO(pcm): Here, we can add the object reference to fix https://github.com/amplab/ray/issues/72 + result.ray_objref = objref # TODO(pcm): This could be done only for the "pull" case in the future if we want to increase performance result.ray_deallocator = RayDealloc(self.handle, segmentid) return result diff --git a/test/runtest.py b/test/runtest.py index 1f9a189a1..1caa7c0e6 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -12,6 +12,22 @@ import test_functions import ray.arrays.remote as ra import ray.arrays.distributed as da +RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, "hello world", 42.0, + (1.0, "hi"), None, (None, None), ("hello", None), + True, False, (True, False), + {True: "hello", False: "world"}, + {"hello" : "world", 1: 42, 1.0: 45}, {}] + +class UserDefinedType(object): + def __init__(self): + pass + + def deserialize(self, primitives): + return "user defined type" + + def serialize(self): + return "user defined type" + class SerializationTest(unittest.TestCase): def roundTripTest(self, worker, data): @@ -28,21 +44,8 @@ class SerializationTest(unittest.TestCase): def testSerialize(self): [w] = services.start_singlenode_cluster(return_drivers=True) - self.roundTripTest(w, [1, "hello", 3.0]) - self.roundTripTest(w, 42) - self.roundTripTest(w, "hello world") - self.roundTripTest(w, 42.0) - self.roundTripTest(w, (1.0, "hi")) - self.roundTripTest(w, None) - self.roundTripTest(w, (None, None)) - self.roundTripTest(w, ("hello", None)) - self.roundTripTest(w, True) - self.roundTripTest(w, False) - self.roundTripTest(w, (True, False)) - self.roundTripTest(w, {True: "hello", False: "world"}) - - self.roundTripTest(w, {"hello" : "world", 1: 42, 1.0: 45}) - self.roundTripTest(w, {}) + for val in RAY_TEST_OBJECTS: + self.roundTripTest(w, val) a = np.zeros((100, 100)) res, _ = serialization.serialize(w.handle, a) @@ -238,6 +241,16 @@ class TaskStatusTest(unittest.TestCase): self.assertTrue(task['operationid'] not in task_ids) task_ids.add(task['operationid']) +def check_pull_deallocated(data): + x = ray.push(data) + ray.pull(x) + return x.val + +def check_pull_not_deallocated(data): + x = ray.push(data) + y = ray.pull(x) + return y, x.val + class ReferenceCountingTest(unittest.TestCase): def testDeallocation(self): @@ -290,5 +303,36 @@ class ReferenceCountingTest(unittest.TestCase): services.cleanup() + def testPull(self): + worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py") + services.start_singlenode_cluster(return_drivers=False, num_workers_per_objstore=3, worker_path=worker_path) + + for val in RAY_TEST_OBJECTS + [np.zeros((2, 2)), UserDefinedType()]: + objref_val = check_pull_deallocated(val) + self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], -1) + + if not isinstance(val, bool) and val is not None: + x, objref_val = check_pull_not_deallocated(val) + self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], 1) + + services.cleanup() + + @unittest.expectedFailure + def testPullFailing(self): + worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py") + services.start_singlenode_cluster(return_drivers=False, num_workers_per_objstore=3, worker_path=worker_path) + + # This is failing, because for bool and None, we cannot track python + # refcounts and therefore cannot keep the refcount up + # (see 5281bd414f6b404f61e1fe25ec5f6651defee206). + # The resulting behavior is still correct however because True, False and + # None are returned by pull "by value" and therefore can be reclaimed from + # the object store safely. + for val in [True, False, None]: + x, objref_val = check_pull_not_deallocated(val) + self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], 1) + + services.cleanup() + if __name__ == '__main__': unittest.main()