Refcount without modifying objects (#407)

* refcount without modifying objects

* add documentation

* Update tests and documentation.

* Remove extraneous code.

* Update numbuf version.
This commit is contained in:
Philipp Moritz
2016-09-04 12:07:52 -07:00
committed by Robert Nishihara
parent 81f40774a7
commit 68cec55a98
4 changed files with 50 additions and 124 deletions
-32
View File
@@ -3,38 +3,6 @@ import numpy as np
import libraylib as raylib
# The following definitions are required because Python doesn't allow custom
# attributes for primitive types. We need custom attributes for (a) implementing
# destructors that close the shared memory segment that the object resides in
# and (b) fixing https://github.com/amplab/ray/issues/72.
class Int(int):
  """An int subclass whose instances can carry custom attributes.

  Plain int objects reject attribute assignment; this subclass exists so that
  extra bookkeeping fields can be attached to an integer value.
  """
class Long(long):
  """A long subclass whose instances can carry custom attributes.

  Plain long objects reject attribute assignment; this subclass exists so that
  extra bookkeeping fields can be attached to a long value.
  """
class Float(float):
  """A float subclass whose instances can carry custom attributes.

  Plain float objects reject attribute assignment; this subclass exists so
  that extra bookkeeping fields can be attached to a float value.
  """
class List(list):
  """A list subclass whose instances can carry custom attributes.

  Plain list objects reject attribute assignment; this subclass exists so that
  extra bookkeeping fields can be attached to a list.
  """
class Dict(dict):
  """A dict subclass whose instances can carry custom attributes.

  Plain dict objects reject attribute assignment; this subclass exists so that
  extra bookkeeping fields can be attached to a dictionary.
  """
class Tuple(tuple):
  """A tuple subclass whose instances can carry custom attributes.

  Plain tuple objects reject attribute assignment; this subclass exists so
  that extra bookkeeping fields can be attached to a tuple.
  """
class Str(str):
  """A str subclass whose instances can carry custom attributes.

  Plain str objects reject attribute assignment; this subclass exists so that
  extra bookkeeping fields can be attached to a string.
  """
class Unicode(unicode):
  """A unicode subclass whose instances can carry custom attributes.

  Plain unicode objects reject attribute assignment; this subclass exists so
  that extra bookkeeping fields can be attached to a unicode string.
  """
class NDArray(np.ndarray):
  """A numpy.ndarray subclass whose instances can carry custom attributes.

  Plain ndarray objects reject attribute assignment; viewing an array as this
  type (``arr.view(NDArray)``) yields an object that extra bookkeeping fields
  can be attached to.
  """
def to_primitive(obj):
if hasattr(obj, "serialize"):
primitive_obj = ((type(obj).__module__, type(obj).__name__), obj.serialize())
+23 -60
View File
@@ -145,36 +145,6 @@ class RayGetArgumentError(Exception):
"""Format a RayGetArgumentError as a string."""
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
class RayDealloc(object):
  """Helper used internally to implement reference counting.

  A RayDealloc is created when an object is fetched with get_object and is
  stored as a field on the returned object. Its destructor therefore runs only
  once the worker holds no more references to that object, at which point the
  memory segment backing the object is unmapped.

  Attributes:
    handle (worker capsule): A Python object wrapping a C++ Worker object.
    segmentid (int): The id of the segment containing the object that holds
      this RayDealloc object.
  """

  def __init__(self, handle, segmentid):
    """Record the worker handle and the segment to unmap later.

    Args:
      handle (worker capsule): A Python object wrapping a C++ Worker object.
      segmentid (int): The id of the segment containing the object that holds
        this RayDealloc object.
    """
    self.handle, self.segmentid = handle, segmentid

  def __del__(self):
    """Unmap the recorded segment so that it is not leaked."""
    unmap = raylib.unmap_object
    unmap(self.handle, self.segmentid)
class Reusable(object):
"""A Python object that can be shared between tasks.
@@ -309,6 +279,28 @@ class RayReusables(object):
"""
raise Exception("Attempted deletion of attribute {}. Attributes of a RayReusable object may not be deleted.".format(name))
class ObjectFixture(object):
  """Keep an object-store-backed object mapped for as long as it is in use.

  The segment identified by segmentid is unmapped when the fixture is
  deallocated. The fixture also holds objectid as a field, which ensures that
  the corresponding object is not deallocated from the object store while the
  fixture is alive. An ObjectFixture serves as the base object of numpy arrays
  contained in the object referred to by objectid, so the memory they use is
  neither unmapped by the worker nor deallocated by the object store while
  those arrays are still referenced.

  Attributes:
    objectid: The id of the object whose lifetime this fixture extends.
    segmentid: The id of the segment passed to raylib.unmap_object on
      destruction.
    handle: The worker handle passed to raylib.unmap_object on destruction.
  """

  def __init__(self, objectid, segmentid, handle):
    """Store the object id, segment id and worker handle on the fixture."""
    self.objectid, self.segmentid, self.handle = objectid, segmentid, handle

  def __del__(self):
    """Unmap the segment once the fixture itself is being destroyed."""
    unmap = raylib.unmap_object
    unmap(self.handle, self.segmentid)
class Worker(object):
"""A class used to define the control flow of a worker process.
@@ -414,7 +406,7 @@ class Worker(object):
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size)
data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:]
serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset)
deserialized = libnumbuf.deserialize_list(serialized)
deserialized = libnumbuf.deserialize_list(serialized, ObjectFixture(objectid, segmentid, self.handle))
# Unwrap the object from the list (it was wrapped in put_object)
assert len(deserialized) == 1
result = deserialized[0]
@@ -424,35 +416,6 @@ class Worker(object):
object_capsule, segmentid = raylib.get_object(self.handle, objectid)
result = serialization.deserialize(self.handle, object_capsule)
if isinstance(result, int):
result = serialization.Int(result)
elif isinstance(result, long):
result = serialization.Long(result)
elif isinstance(result, float):
result = serialization.Float(result)
elif isinstance(result, bool):
raylib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later
return result # can't subclass bool, and don't need to because there is a global True/False
elif isinstance(result, list):
result = serialization.List(result)
elif isinstance(result, dict):
result = serialization.Dict(result)
elif isinstance(result, tuple):
result = serialization.Tuple(result)
elif isinstance(result, str):
result = serialization.Str(result)
elif isinstance(result, unicode):
result = serialization.Unicode(result)
elif isinstance(result, np.ndarray):
result = result.view(serialization.NDArray)
elif isinstance(result, np.generic):
return result
# TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)
elif result is None:
raylib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later
return None # can't subclass None and don't need to because there is a global None
result.ray_objectid = objectid # TODO(pcm): This could be done only for the "get" case in the future if we want to increase performance
result.ray_deallocator = RayDealloc(self.handle, segmentid)
return result
def alias_objectids(self, alias_objectid, target_objectid):