mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 09:38:00 +08:00
Refcount without modifying objects (#407)
* refcount without modifying objects * add documentation * Update tests and documentation. * Remove extraneous code. * Update numbuf version.
This commit is contained in:
committed by
Robert Nishihara
parent
81f40774a7
commit
68cec55a98
@@ -3,38 +3,6 @@ import numpy as np
|
||||
|
||||
import libraylib as raylib
|
||||
|
||||
# The following definitions are required because Python doesn't allow custom
|
||||
# attributes for primitive types. We need custom attributes for (a) implementing
|
||||
# destructors that close the shared memory segment that the object resides in
|
||||
# and (b) fixing https://github.com/amplab/ray/issues/72.
|
||||
|
||||
class Int(int):
|
||||
pass
|
||||
|
||||
class Long(long):
|
||||
pass
|
||||
|
||||
class Float(float):
|
||||
pass
|
||||
|
||||
class List(list):
|
||||
pass
|
||||
|
||||
class Dict(dict):
|
||||
pass
|
||||
|
||||
class Tuple(tuple):
|
||||
pass
|
||||
|
||||
class Str(str):
|
||||
pass
|
||||
|
||||
class Unicode(unicode):
|
||||
pass
|
||||
|
||||
class NDArray(np.ndarray):
|
||||
pass
|
||||
|
||||
def to_primitive(obj):
|
||||
if hasattr(obj, "serialize"):
|
||||
primitive_obj = ((type(obj).__module__, type(obj).__name__), obj.serialize())
|
||||
|
||||
+23
-60
@@ -145,36 +145,6 @@ class RayGetArgumentError(Exception):
|
||||
"""Format a RayGetArgumentError as a string."""
|
||||
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
|
||||
|
||||
class RayDealloc(object):
|
||||
"""An object used internally to properly implement reference counting.
|
||||
|
||||
When we call get_object with a particular object ID, we create a RayDealloc
|
||||
object with the information necessary to properly handle closing the relevant
|
||||
memory segment when the object is no longer needed by the worker. The
|
||||
RayDealloc object is stored as a field in the object returned by get_object so
|
||||
that its destructor is only called when the worker no longer has any
|
||||
references to the object.
|
||||
|
||||
Attributes
|
||||
handle (worker capsule): A Python object wrapping a C++ Worker object.
|
||||
segmentid (int): The id of the segment that contains the object that holds
|
||||
this RayDealloc object.
|
||||
"""
|
||||
|
||||
def __init__(self, handle, segmentid):
|
||||
"""Initialize a RayDealloc object.
|
||||
|
||||
Args:
|
||||
handle (worker capsule): A Python object wrapping a C++ Worker object.
|
||||
segmentid (int): The id of the segment that contains the object that holds
|
||||
this RayDealloc object.
|
||||
"""
|
||||
self.handle = handle
|
||||
self.segmentid = segmentid
|
||||
|
||||
def __del__(self):
|
||||
"""Deallocate the relevant segment to avoid a memory leak."""
|
||||
raylib.unmap_object(self.handle, self.segmentid)
|
||||
|
||||
class Reusable(object):
|
||||
"""An Python object that can be shared between tasks.
|
||||
@@ -309,6 +279,28 @@ class RayReusables(object):
|
||||
"""
|
||||
raise Exception("Attempted deletion of attribute {}. Attributes of a RayReusable object may not be deleted.".format(name))
|
||||
|
||||
class ObjectFixture(object):
|
||||
"""This is used to handle unmapping objects backed by the object store.
|
||||
|
||||
The object referred to by objectid will get unmaped when the fixture is
|
||||
deallocated. In addition, the ObjectFixture holds the objectid as a field,
|
||||
which ensures that the corresponding object will not be deallocated from the
|
||||
object store while the ObjectFixture is alive. ObjectFixture is used as the
|
||||
base object for numpy arrays that are contained in the object referred to by
|
||||
objectid and prevents memory that is used by them from getting unmapped by the
|
||||
worker or deallocated by the object store.
|
||||
"""
|
||||
|
||||
def __init__(self, objectid, segmentid, handle):
|
||||
"""Initialize an ObjectFixture object."""
|
||||
self.objectid = objectid
|
||||
self.segmentid = segmentid
|
||||
self.handle = handle
|
||||
|
||||
def __del__(self):
|
||||
"""Unmap the segment when the object goes out of scope."""
|
||||
raylib.unmap_object(self.handle, self.segmentid)
|
||||
|
||||
class Worker(object):
|
||||
"""A class used to define the control flow of a worker process.
|
||||
|
||||
@@ -414,7 +406,7 @@ class Worker(object):
|
||||
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size)
|
||||
data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:]
|
||||
serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset)
|
||||
deserialized = libnumbuf.deserialize_list(serialized)
|
||||
deserialized = libnumbuf.deserialize_list(serialized, ObjectFixture(objectid, segmentid, self.handle))
|
||||
# Unwrap the object from the list (it was wrapped put_object)
|
||||
assert len(deserialized) == 1
|
||||
result = deserialized[0]
|
||||
@@ -424,35 +416,6 @@ class Worker(object):
|
||||
object_capsule, segmentid = raylib.get_object(self.handle, objectid)
|
||||
result = serialization.deserialize(self.handle, object_capsule)
|
||||
|
||||
if isinstance(result, int):
|
||||
result = serialization.Int(result)
|
||||
elif isinstance(result, long):
|
||||
result = serialization.Long(result)
|
||||
elif isinstance(result, float):
|
||||
result = serialization.Float(result)
|
||||
elif isinstance(result, bool):
|
||||
raylib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later
|
||||
return result # can't subclass bool, and don't need to because there is a global True/False
|
||||
elif isinstance(result, list):
|
||||
result = serialization.List(result)
|
||||
elif isinstance(result, dict):
|
||||
result = serialization.Dict(result)
|
||||
elif isinstance(result, tuple):
|
||||
result = serialization.Tuple(result)
|
||||
elif isinstance(result, str):
|
||||
result = serialization.Str(result)
|
||||
elif isinstance(result, unicode):
|
||||
result = serialization.Unicode(result)
|
||||
elif isinstance(result, np.ndarray):
|
||||
result = result.view(serialization.NDArray)
|
||||
elif isinstance(result, np.generic):
|
||||
return result
|
||||
# TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)
|
||||
elif result is None:
|
||||
raylib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later
|
||||
return None # can't subclass None and don't need to because there is a global None
|
||||
result.ray_objectid = objectid # TODO(pcm): This could be done only for the "get" case in the future if we want to increase performance
|
||||
result.ray_deallocator = RayDealloc(self.handle, segmentid)
|
||||
return result
|
||||
|
||||
def alias_objectids(self, alias_objectid, target_objectid):
|
||||
|
||||
Reference in New Issue
Block a user