diff --git a/lib/python/ray/serialization.py b/lib/python/ray/serialization.py index 153fb89f4..92f3053ac 100644 --- a/lib/python/ray/serialization.py +++ b/lib/python/ray/serialization.py @@ -3,38 +3,6 @@ import numpy as np import libraylib as raylib -# The following definitions are required because Python doesn't allow custom -# attributes for primitive types. We need custom attributes for (a) implementing -# destructors that close the shared memory segment that the object resides in -# and (b) fixing https://github.com/amplab/ray/issues/72. - -class Int(int): - pass - -class Long(long): - pass - -class Float(float): - pass - -class List(list): - pass - -class Dict(dict): - pass - -class Tuple(tuple): - pass - -class Str(str): - pass - -class Unicode(unicode): - pass - -class NDArray(np.ndarray): - pass - def to_primitive(obj): if hasattr(obj, "serialize"): primitive_obj = ((type(obj).__module__, type(obj).__name__), obj.serialize()) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index ead96319b..95ba06fd3 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -145,36 +145,6 @@ class RayGetArgumentError(Exception): """Format a RayGetArgumentError as a string.""" return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error) -class RayDealloc(object): - """An object used internally to properly implement reference counting. - - When we call get_object with a particular object ID, we create a RayDealloc - object with the information necessary to properly handle closing the relevant - memory segment when the object is no longer needed by the worker. 
The - RayDealloc object is stored as a field in the object returned by get_object so - that its destructor is only called when the worker no longer has any - references to the object. - - Attributes - handle (worker capsule): A Python object wrapping a C++ Worker object. - segmentid (int): The id of the segment that contains the object that holds - this RayDealloc object. - """ - - def __init__(self, handle, segmentid): - """Initialize a RayDealloc object. - - Args: - handle (worker capsule): A Python object wrapping a C++ Worker object. - segmentid (int): The id of the segment that contains the object that holds - this RayDealloc object. - """ - self.handle = handle - self.segmentid = segmentid - - def __del__(self): - """Deallocate the relevant segment to avoid a memory leak.""" - raylib.unmap_object(self.handle, self.segmentid) class Reusable(object): """An Python object that can be shared between tasks. @@ -309,6 +279,28 @@ class RayReusables(object): """ raise Exception("Attempted deletion of attribute {}. Attributes of a RayReusable object may not be deleted.".format(name)) +class ObjectFixture(object): + """This is used to handle unmapping objects backed by the object store. + + The object referred to by objectid will get unmapped when the fixture is + deallocated. In addition, the ObjectFixture holds the objectid as a field, + which ensures that the corresponding object will not be deallocated from the + object store while the ObjectFixture is alive. ObjectFixture is used as the + base object for numpy arrays that are contained in the object referred to by + objectid and prevents memory that is used by them from getting unmapped by the + worker or deallocated by the object store.
+ """ + + def __init__(self, objectid, segmentid, handle): + """Initialize an ObjectFixture object.""" + self.objectid = objectid + self.segmentid = segmentid + self.handle = handle + + def __del__(self): + """Unmap the segment when the object goes out of scope.""" + raylib.unmap_object(self.handle, self.segmentid) + class Worker(object): """A class used to define the control flow of a worker process. @@ -414,7 +406,7 @@ class Worker(object): metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size) data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:] serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset) - deserialized = libnumbuf.deserialize_list(serialized) + deserialized = libnumbuf.deserialize_list(serialized, ObjectFixture(objectid, segmentid, self.handle)) # Unwrap the object from the list (it was wrapped put_object) assert len(deserialized) == 1 result = deserialized[0] @@ -424,35 +416,6 @@ class Worker(object): object_capsule, segmentid = raylib.get_object(self.handle, objectid) result = serialization.deserialize(self.handle, object_capsule) - if isinstance(result, int): - result = serialization.Int(result) - elif isinstance(result, long): - result = serialization.Long(result) - elif isinstance(result, float): - result = serialization.Float(result) - elif isinstance(result, bool): - raylib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later - return result # can't subclass bool, and don't need to because there is a global True/False - elif isinstance(result, list): - result = serialization.List(result) - elif isinstance(result, dict): - result = serialization.Dict(result) - elif isinstance(result, tuple): - result = serialization.Tuple(result) - elif isinstance(result, str): - result = serialization.Str(result) - elif isinstance(result, unicode): - result = serialization.Unicode(result) - elif 
isinstance(result, np.ndarray): - result = result.view(serialization.NDArray) - elif isinstance(result, np.generic): - return result - # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now) - elif result is None: - raylib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later - return None # can't subclass None and don't need to because there is a global None - result.ray_objectid = objectid # TODO(pcm): This could be done only for the "get" case in the future if we want to increase performance - result.ray_deallocator = RayDealloc(self.handle, segmentid) return result def alias_objectids(self, alias_objectid, target_objectid): diff --git a/test/runtest.py b/test/runtest.py index 7819e6d77..0d8447295 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -372,16 +372,6 @@ class APITest(unittest.TestCase): ray.worker.cleanup() -def check_get_deallocated(data): - x = ray.put(data) - ray.get(x) - return x.id - -def check_get_not_deallocated(data): - x = ray.put(data) - y = ray.get(x) - return y, x.id - class ReferenceCountingTest(unittest.TestCase): def testDeallocation(self): @@ -421,13 +411,34 @@ class ReferenceCountingTest(unittest.TestCase): def testGet(self): ray.init(start_ray_local=True, num_workers=3) + # Remote objects should be deallocated when the corresponding ObjectID goes + # out of scope, and all results of ray.get called on the ID go out of scope. 
for val in RAY_TEST_OBJECTS + [np.zeros((2, 2)), UserDefinedType()]: - objectid_val = check_get_deallocated(val) - self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val], -1) + x = ray.put(val) + objectid = x.id + xval = ray.get(x) + del x, xval + self.assertEqual(ray.scheduler_info()["reference_counts"][objectid], -1) - if not isinstance(val, bool) and not isinstance(val, np.generic) and val is not None: - x, objectid_val = check_get_not_deallocated(val) - self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val], 1) + # Remote objects that do not contain numpy arrays should be deallocated when + # the corresponding ObjectID goes out of scope, even if ray.get has been + # called on the ObjectID. + for val in [True, False, None, 1, 1.0, 1L, "hi", u"hi", [1, 2, 3], (1, 2, 3), [(), {(): ()}]]: + x = ray.put(val) + objectid = x.id + xval = ray.get(x) + del x + self.assertEqual(ray.scheduler_info()["reference_counts"][objectid], -1) + + # Remote objects that contain numpy arrays should not be deallocated when + # the corresponding ObjectID goes out of scope, if ray.get has been called + # on the ObjectID and the result of that call is still in scope. + for val in [np.zeros(10), [np.zeros(10)], (((np.zeros(10)),),), {(): np.zeros(10)}, [1, 2, 3, np.zeros(1)]]: + x = ray.put(val) + objectid = x.id + xval = ray.get(x) + del x + self.assertEqual(ray.scheduler_info()["reference_counts"][objectid], 1) # The following currently segfaults: The second "result = " closes the # memory segment as soon as the assignment is done (and the first result @@ -440,22 +451,6 @@ class ReferenceCountingTest(unittest.TestCase): ray.worker.cleanup() - # @unittest.expectedFailure - # def testGetFailing(self): - # ray.init(start_ray_local=True, num_workers=3) - - # # This is failing, because for bool and None, we cannot track python - # # refcounts and therefore cannot keep the refcount up - # # (see 5281bd414f6b404f61e1fe25ec5f6651defee206). 
- # # The resulting behavior is still correct however because True, False and - # # None are returned by get "by value" and therefore can be reclaimed from - # # the object store safely. - # for val in [True, False, None]: - # x, objectid_val = check_get_not_deallocated(val) - # self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val], 1) - - # ray.worker.cleanup() - class PythonModeTest(unittest.TestCase): def testPythonMode(self): diff --git a/thirdparty/numbuf b/thirdparty/numbuf index 74406e92e..c4c33bd08 160000 --- a/thirdparty/numbuf +++ b/thirdparty/numbuf @@ -1 +1 @@ -Subproject commit 74406e92e4e10c228f2850e72ea835980c779e3b +Subproject commit c4c33bd087f9b4aed47f518d4e99fdd401abccbe