diff --git a/include/ray/ray.h b/include/ray/ray.h index 72d82ee66..242541537 100644 --- a/include/ray/ray.h +++ b/include/ray/ray.h @@ -10,6 +10,7 @@ typedef size_t ObjRef; typedef size_t WorkerId; typedef size_t ObjStoreId; typedef size_t OperationId; +typedef size_t SegmentId; // index into a memory segment table class FnInfo { size_t num_return_vals_; @@ -45,6 +46,7 @@ public: struct slice { uint8_t* data; size_t len; + SegmentId segmentid; }; #endif diff --git a/lib/python/ray/serialization.py b/lib/python/ray/serialization.py index f15a0f9d9..e22bf6604 100644 --- a/lib/python/ray/serialization.py +++ b/lib/python/ray/serialization.py @@ -1,7 +1,34 @@ import importlib +import numpy as np import ray +# The following definitions are required because Python doesn't allow custom +# attributes for primitive types. We need custom attributes for (a) implementing +# destructors that close the shared memory segment that the object resides in +# and (b) fixing https://github.com/amplab/ray/issues/72. + +class Int(int): + pass + +class Float(float): + pass + +class List(list): + pass + +class Dict(dict): + pass + +class Tuple(tuple): + pass + +class Str(str): + pass + +class NDArray(np.ndarray): + pass + def to_primitive(obj): if hasattr(obj, "serialize"): primitive_obj = ((type(obj).__module__, type(obj).__name__), obj.serialize()) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 2801146ed..f804b5274 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -12,6 +12,15 @@ import ray from ray.config import LOG_DIRECTORY, LOG_TIMESTAMP import serialization +class RayDealloc(object): + def __init__(self, handle, segmentid): + self.handle = handle + self.segmentid = segmentid + + def __del__(self): + # TODO(pcm): This will be used to free the segment + pass + class Worker(object): """The methods in this class are considered unexposed to the user. The functions outside of this class are considered exposed.""" @@ -39,10 +48,33 @@ class Worker(object): WARNING: get_object can only be called on a canonical objref. """ if ray.lib.is_arrow(self.handle, objref): - return ray.lib.get_arrow(self.handle, objref) + result, segmentid = ray.lib.get_arrow(self.handle, objref) else: - object_capsule = ray.lib.get_object(self.handle, objref) - return serialization.deserialize(self.handle, object_capsule) + object_capsule, segmentid = ray.lib.get_object(self.handle, objref) + result = serialization.deserialize(self.handle, object_capsule) + if isinstance(result, int): + result = serialization.Int(result) + elif isinstance(result, float): + result = serialization.Float(result) + elif isinstance(result, bool): + return result # can't subclass bool, and don't need to because there is a global True/False + # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now) + elif isinstance(result, list): + result = serialization.List(result) + elif isinstance(result, dict): + result = serialization.Dict(result) + elif isinstance(result, tuple): + result = serialization.Tuple(result) + elif isinstance(result, str): + result = serialization.Str(result) + elif isinstance(result, np.ndarray): + result = result.view(serialization.NDArray) + elif result == None: + return None # can't subclass None and don't need to because there is a global None + # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now) + # TODO(pcm): Here, we can add the object reference to fix https://github.com/amplab/ray/issues/72 + result.ray_deallocator = RayDealloc(self.handle, segmentid) + return result def alias_objrefs(self, alias_objref, target_objref): """Make `alias_objref` refer to the same object that `target_objref` refers to.""" diff --git a/src/raylib.cc b/src/raylib.cc index 732ad34eb..4c77b5703 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -512,7 +512,12 @@ PyObject* get_arrow(PyObject* self, PyObject* args) { if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToObjRef, &objref)) { return NULL; } - return (PyObject*) worker->get_arrow(objref); + SegmentId segmentid; + PyObject* value = worker->get_arrow(objref, segmentid); + PyObject* val_and_segmentid = PyList_New(2); + PyList_SetItem(val_and_segmentid, 0, value); + PyList_SetItem(val_and_segmentid, 1, PyInt_FromLong(segmentid)); + return val_and_segmentid; } PyObject* is_arrow(PyObject* self, PyObject* args) { @@ -748,7 +753,10 @@ PyObject* get_object(PyObject* self, PyObject* args) { slice s = worker->get_object(objref); Obj* obj = new Obj(); // TODO: Make sure this will get deleted obj->ParseFromString(std::string(reinterpret_cast(s.data), s.len)); - return PyCapsule_New(static_cast(obj), "obj", &ObjCapsule_Destructor); + PyObject* result = PyList_New(2); + PyList_SetItem(result, 0, PyCapsule_New(static_cast(obj), "obj", &ObjCapsule_Destructor)); + PyList_SetItem(result, 1, PyInt_FromLong(s.segmentid)); + return result; } PyObject* request_object(PyObject* self, PyObject* args) { diff --git a/src/worker.cc b/src/worker.cc index cbfde6b77..861c0b1d7 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -93,6 +93,7 @@ slice Worker::get_object(ObjRef objref) { slice slice; slice.data = segmentpool_->get_address(result); slice.len = result.size(); + slice.segmentid = result.segmentid(); return slice; } @@ -165,7 +166,9 @@ PyObject* Worker::put_arrow(ObjRef objref, PyObject* value) { Py_RETURN_NONE; } -PyObject* Worker::get_arrow(ObjRef objref) { +// returns python list containing the value represented by objref and the +// segmentid in which the object is stored +PyObject* Worker::get_arrow(ObjRef objref, SegmentId& segmentid) { RAY_CHECK(connected_, "Attempted to perform get_arrow but failed."); ObjRequest request; request.workerid = workerid_; @@ -176,6 +179,7 @@ PyObject* Worker::get_arrow(ObjRef objref) { receive_obj_queue_.receive(&result); uint8_t* address = segmentpool_->get_address(result); auto source = std::make_shared(address, result.size()); + segmentid = result.segmentid(); PyObject* value; CHECK_ARROW_STATUS(pynumbuf::ReadPythonObjectFrom(source.get(), result.metadata_offset(), &value), "error during ReadPythonObjectFrom: "); return value; diff --git a/src/worker.h b/src/worker.h index cd260786e..0ba11ca70 100644 --- a/src/worker.h +++ b/src/worker.h @@ -57,7 +57,7 @@ class Worker { // stores an arrow object to the local object store PyObject* put_arrow(ObjRef objref, PyObject* array); // gets an arrow object from the local object store - PyObject* get_arrow(ObjRef objref); + PyObject* get_arrow(ObjRef objref, SegmentId& segmentid); // determine if the object stored in objref is an arrow object // TODO(pcm): more general mechanism for this? bool is_arrow(ObjRef objref); // make `alias_objref` refer to the same object that `target_objref` refers to