diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index db384e48e..70e199467 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -19,8 +19,7 @@ class RayDealloc(object): self.segmentid = segmentid def __del__(self): - # TODO(pcm): This will be used to free the segment - pass + ray.lib.unmap_object(self.handle, self.segmentid) class Worker(object): """The methods in this class are considered unexposed to the user. The functions outside of this class are considered exposed.""" @@ -58,8 +57,8 @@ class Worker(object): elif isinstance(result, float): result = serialization.Float(result) elif isinstance(result, bool): + ray.lib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later return result # can't subclass bool, and don't need to because there is a global True/False - # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now) elif isinstance(result, list): result = serialization.List(result) elif isinstance(result, dict): @@ -74,8 +73,8 @@ class Worker(object): return result # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now) elif result == None: + ray.lib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later return None # can't subclass None and don't need to because there is a global None - # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now) result.ray_objref = objref # TODO(pcm): This could be done only for the "get" case in the future if we want to increase performance result.ray_deallocator = RayDealloc(self.handle, segmentid) return result diff --git a/src/ipc.cc b/src/ipc.cc index 2739441c4..0e3eb1923 100644 --- a/src/ipc.cc +++ b/src/ipc.cc @@ -56,6 +56,11 @@ void MemorySegmentPool::open_segment(SegmentId segmentid, size_t size) { } } +void MemorySegmentPool::unmap_segment(SegmentId segmentid) { + segments_[segmentid].first.reset(); + segments_[segmentid].second = SegmentStatusType::UNOPENED; +} + void MemorySegmentPool::close_segment(SegmentId segmentid) { RAY_LOG(RAY_DEBUG, "closing segmentid " << segmentid); std::string segment_name = get_segment_name(segmentid); diff --git a/src/ipc.h b/src/ipc.h index 6bcbe80d3..fd2e7e7a4 100644 --- a/src/ipc.h +++ b/src/ipc.h @@ -152,6 +152,7 @@ public: void deallocate(ObjHandle pointer); // deallocate object, potentially deallocating a new segment (only run on object store) uint8_t* get_address(ObjHandle pointer); // get address of shared object std::string get_segment_name(SegmentId segmentid); // get the name of a segment + void unmap_segment(SegmentId segmentid); // unmap a memory segment from a client (only to be called by clients) private: void open_segment(SegmentId segmentid, size_t size = 0); // create a segment or map an existing one into memory void close_segment(SegmentId segmentid); // close a segment diff --git a/src/raylib.cc b/src/raylib.cc index 3ac9b971a..9c9f760a6 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -460,7 +460,11 @@ PyObject* put_arrow(PyObject* self, PyObject* args) { if (!PyArg_ParseTuple(args, "O&O&O", &PyObjectToWorker, &worker, &PyObjectToObjRef, &objref, &value)) { return NULL; } - worker->put_arrow(objref, value); + // The following is reqired, because numbuf expects contiguous arrays at the moment. + // This is to make sure that we do not have to do reference counting inside numbuf, it is expected to change. + PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) value); // TODO(pcm): put that into numbuf + worker->put_arrow(objref, (PyObject*) array); + Py_XDECREF(array); // GETCONTIGUOUS from above returned a new reference Py_RETURN_NONE; } @@ -490,6 +494,16 @@ PyObject* is_arrow(PyObject* self, PyObject* args) { Py_RETURN_FALSE; } +PyObject* unmap_object(PyObject* self, PyObject* args) { + Worker* worker; + int segmentid; + if (!PyArg_ParseTuple(args, "O&i", &PyObjectToWorker, &worker, &segmentid)) { + return NULL; + } + worker->unmap_object(segmentid); + Py_RETURN_NONE; +} + PyObject* deserialize_object(PyObject* self, PyObject* args) { PyObject* worker_capsule; Obj* obj; @@ -827,6 +841,7 @@ static PyMethodDef RayLibMethods[] = { { "put_arrow", put_arrow, METH_VARARGS, "put an arrow array on the local object store"}, { "get_arrow", get_arrow, METH_VARARGS, "get an arrow array from the local object store"}, { "is_arrow", is_arrow, METH_VARARGS, "is the object in the local object store an arrow object?"}, + { "unmap_object", unmap_object, METH_VARARGS, "unmap the object from the client's shared memory pool"}, { "serialize_task", serialize_task, METH_VARARGS, "serialize a task to protocol buffers" }, { "deserialize_task", deserialize_task, METH_VARARGS, "deserialize a task from protocol buffers" }, { "create_worker", create_worker, METH_VARARGS, "connect to the scheduler and the object store" }, diff --git a/src/worker.cc b/src/worker.cc index 58eb31506..2d6b3964f 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -117,6 +117,9 @@ void Worker::put_object(ObjRef objref, const Obj* obj, std::vector &cont receive_obj_queue_.receive(&result); uint8_t* target = segmentpool_->get_address(result); std::memcpy(target, &data[0], data.size()); + // We immediately unmap here; if the object is going to be accessed again, it will be mapped again; + // This is reqired because we do not have a mechanism to unmap the object later. + segmentpool_->unmap_segment(result.segmentid()); request.type = ObjRequestType::WORKER_DONE; request.metadata_offset = 0; request_obj_queue_.send(&request); @@ -160,6 +163,9 @@ PyObject* Worker::put_arrow(ObjRef objref, PyObject* value) { uint8_t* address = segmentpool_->get_address(result); auto source = std::make_shared(address, size); CHECK_ARROW_STATUS(writer.Write(source.get(), &metadata_offset), "error during Write: "); + // We immediately unmap here; if the object is going to be accessed again, it will be mapped again; + // This is reqired because we do not have a mechanism to unmap the object later. + segmentpool_->unmap_segment(result.segmentid()); request.type = ObjRequestType::WORKER_DONE; request.metadata_offset = metadata_offset; request_obj_queue_.send(&request); @@ -197,6 +203,14 @@ bool Worker::is_arrow(ObjRef objref) { return result.metadata_offset() != 0; } +void Worker::unmap_object(ObjRef objref) { + if (!connected_) { + RAY_LOG(RAY_DEBUG, "Attempted to perform unmap_object but failed."); + return; + } + segmentpool_->unmap_segment(objref); +} + void Worker::alias_objrefs(ObjRef alias_objref, ObjRef target_objref) { RAY_CHECK(connected_, "Attempted to perform alias_objrefs but failed."); ClientContext context; diff --git a/src/worker.h b/src/worker.h index 0ba11ca70..6c303691e 100644 --- a/src/worker.h +++ b/src/worker.h @@ -60,6 +60,8 @@ class Worker { PyObject* get_arrow(ObjRef objref, SegmentId& segmentid); // determine if the object stored in objref is an arrow object // TODO(pcm): more general mechanism for this? bool is_arrow(ObjRef objref); + // unmap the segment containing an object from the local address space + void unmap_object(ObjRef objref); // make `alias_objref` refer to the same object that `target_objref` refers to void alias_objrefs(ObjRef alias_objref, ObjRef target_objref); // increment the reference count for objref diff --git a/test/runtest.py b/test/runtest.py index c0c1a0c5a..da176e70d 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -330,6 +330,17 @@ class ReferenceCountingTest(unittest.TestCase): x, objref_val = check_get_not_deallocated(val) self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], 1) + # The following currently segfaults: The second "result = " closes the + # memory segment as soon as the assignment is done (and the first result + # goes out of scope). + """ + data = np.zeros([10, 20]) + objref = ray.put(data) + result = worker.get(objref) + result = worker.get(objref) + self.assertTrue(np.alltrue(result == data)) + """ + services.cleanup() @unittest.expectedFailure diff --git a/thirdparty/numbuf b/thirdparty/numbuf index 9f71ee375..508698698 160000 --- a/thirdparty/numbuf +++ b/thirdparty/numbuf @@ -1 +1 @@ -Subproject commit 9f71ee37557a740beb5547cf51b5aad9a7c92fe3 +Subproject commit 508698698217b956286ff696fbdde2604c00c101