From 7d33e9949b942acde92db6698abdb6b409c0648c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 15 Nov 2019 10:52:19 -0800 Subject: [PATCH] Integrate ref count module into local memory store (#6122) --- python/ray/_raylet.pyx | 2 +- python/ray/includes/libcoreworker.pxd | 3 +- python/ray/tests/test_basic.py | 25 ++++++- src/ray/core_worker/core_worker.cc | 50 ++++++++------ src/ray/core_worker/core_worker.h | 27 +++++--- src/ray/core_worker/reference_count.cc | 16 +++-- src/ray/core_worker/reference_count.h | 14 +++- src/ray/core_worker/reference_count_test.cc | 68 +++++++++++++++---- .../memory_store/memory_store.cc | 23 +++++-- .../memory_store/memory_store.h | 23 ++++++- 10 files changed, 190 insertions(+), 61 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 979921cd9..db6def77e 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -720,7 +720,7 @@ cdef class CoreWorker: raylet_socket.encode("ascii"), job_id.native(), gcs_options.native()[0], log_dir.encode("utf-8"), node_ip_address.encode("utf-8"), node_manager_port, - task_execution_handler, check_signals, exit_handler)) + task_execution_handler, check_signals, exit_handler, True)) def disconnect(self): with nogil: diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 3d55fd98a..9edcdcad9 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -65,7 +65,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const c_vector[CObjectID] &return_ids, c_vector[shared_ptr[CRayObject]] *returns) nogil, CRayStatus() nogil, - void () nogil) + void () nogil, + c_bool ref_counting_enabled) void Disconnect() CWorkerType &GetWorkerType() CLanguage &GetLanguage() diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 2766aaf66..6e8cadd5f 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1211,13 +1211,32 @@ def test_direct_call_simple(ray_start_regular): return x + 1 f_direct = f.options(is_direct_call=True) - print("a") assert ray.get(f_direct.remote(2)) == 3 - print("b") assert ray.get([f_direct.remote(i) for i in range(100)]) == list( range(1, 101)) +def test_direct_call_refcount(ray_start_regular): + @ray.remote + def f(x): + return x + 1 + + @ray.remote + def sleep(): + time.sleep(.1) + return 1 + + # Multiple gets should not hang with ref counting enabled. + f_direct = f.options(is_direct_call=True) + x = f_direct.remote(2) + ray.get(x) + ray.get(x) + + # Temporary objects should be retained for chained callers. + y = f_direct.remote(sleep.options(is_direct_call=True).remote()) + assert ray.get(y) == 2 + + def test_direct_call_matrix(shutdown_only): ray.init(object_store_memory=1000 * 1024 * 1024) @@ -1407,7 +1426,7 @@ def test_direct_actor_recursive(ray_start_regular): return x * 2 a = Actor._remote(is_direct_call=True) - b = Actor._remote(args=[a], is_direct_call=False) + b = Actor._remote(args=[a], is_direct_call=True) c = Actor._remote(args=[b], is_direct_call=True) result = ray.get([c.f.remote(i) for i in range(100)]) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 07f18b461..dc16e1f77 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -100,16 +100,19 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, int node_manager_port, const TaskExecutionCallback &task_execution_callback, std::function check_signals, - const std::function exit_handler) + const std::function exit_handler, + bool ref_counting_enabled) : worker_type_(worker_type), language_(language), log_dir_(log_dir), + ref_counting_enabled_(ref_counting_enabled), check_signals_(check_signals), worker_context_(worker_type, job_id), io_work_(io_service_), client_call_manager_(new rpc::ClientCallManager(io_service_)), heartbeat_timer_(io_service_), core_worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */), + reference_counter_(std::make_shared()), task_execution_service_work_(task_execution_service_), task_execution_callback_(task_execution_callback), grpc_service_(io_service_, *this) { @@ -189,10 +192,11 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, plasma_store_provider_.reset( new CoreWorkerPlasmaStoreProvider(store_socket, raylet_client_, check_signals_)); - memory_store_.reset( - new CoreWorkerMemoryStore([this](const RayObject &obj, const ObjectID &obj_id) { + memory_store_.reset(new CoreWorkerMemoryStore( + [this](const RayObject &obj, const ObjectID &obj_id) { RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id)); - })); + }, + ref_counting_enabled ? reference_counter_ : nullptr)); memory_store_provider_.reset(new CoreWorkerMemoryStoreProvider(memory_store_)); // Create an entry for the driver task in the task table. This task is @@ -282,7 +286,7 @@ void CoreWorker::SetCurrentTaskId(const TaskID &task_id) { void CoreWorker::ReportActiveObjectIDs() { std::unordered_set active_object_ids = - reference_counter_.GetAllInScopeObjectIDs(); + reference_counter_->GetAllInScopeObjectIDs(); RAY_LOG(DEBUG) << "Sending " << active_object_ids.size() << " object IDs to raylet."; if (active_object_ids.size() > RayConfig::instance().raylet_max_active_object_ids()) { RAY_LOG(WARNING) << active_object_ids.size() << "object IDs are currently in scope. " @@ -542,9 +546,8 @@ TaskID CoreWorker::GetCallerId() const { return caller_id; } -Status CoreWorker::SubmitTaskToRaylet(const TaskSpecification &task_spec) { - RAY_RETURN_NOT_OK(raylet_client_->SubmitTask(task_spec)); - +void CoreWorker::PinObjectReferences(const TaskSpecification &task_spec, + const TaskTransportType transport_type) { size_t num_returns = task_spec.NumReturns(); if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) { num_returns--; @@ -560,13 +563,10 @@ Status CoreWorker::SubmitTaskToRaylet(const TaskSpecification &task_spec) { } } - if (task_deps->size() > 0) { - for (size_t i = 0; i < num_returns; i++) { - reference_counter_.SetDependencies(task_spec.ReturnIdForPlasma(i), task_deps); - } + // Note that we call this even if task_deps.size() == 0, in order to pin the return id. + for (size_t i = 0; i < num_returns; i++) { + reference_counter_->SetDependencies(task_spec.ReturnId(i, transport_type), task_deps); } - - return Status::OK(); } Status CoreWorker::SubmitTask(const RayFunction &function, @@ -586,10 +586,13 @@ Status CoreWorker::SubmitTask(const RayFunction &function, function, args, task_options.num_returns, task_options.resources, {}, task_options.is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET, return_ids); + TaskSpecification task_spec = builder.Build(); if (task_options.is_direct_call) { - return direct_task_submitter_->SubmitTask(builder.Build()); + PinObjectReferences(task_spec, TaskTransportType::DIRECT); + return direct_task_submitter_->SubmitTask(task_spec); } else { - return raylet_client_->SubmitTask(builder.Build()); + PinObjectReferences(task_spec, TaskTransportType::RAYLET); + return raylet_client_->SubmitTask(task_spec); } } @@ -623,7 +626,9 @@ Status CoreWorker::CreateActor(const RayFunction &function, << "Actor " << actor_id << " already exists"; *return_actor_id = actor_id; - return SubmitTaskToRaylet(builder.Build()); + TaskSpecification task_spec = builder.Build(); + PinObjectReferences(task_spec, TaskTransportType::RAYLET); + return raylet_client_->SubmitTask(task_spec); } Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &function, @@ -659,10 +664,13 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f // Submit task. Status status; + TaskSpecification task_spec = builder.Build(); if (is_direct_call) { - status = direct_actor_submitter_->SubmitTask(builder.Build()); + PinObjectReferences(task_spec, TaskTransportType::DIRECT); + status = direct_actor_submitter_->SubmitTask(task_spec); } else { - status = SubmitTaskToRaylet(builder.Build()); + PinObjectReferences(task_spec, TaskTransportType::RAYLET); + RAY_CHECK_OK(raylet_client_->SubmitTask(task_spec)); } return status; } @@ -830,9 +838,9 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, } } - if (task_spec.IsNormalTask() && reference_counter_.NumObjectIDsInScope() != 0) { + if (task_spec.IsNormalTask() && reference_counter_->NumObjectIDsInScope() != 0) { RAY_LOG(DEBUG) - << "There were " << reference_counter_.NumObjectIDsInScope() + << "There were " << reference_counter_->NumObjectIDsInScope() << " ObjectIDs left in scope after executing task " << task_spec.TaskId() << ". This is either caused by keeping references to ObjectIDs in Python between " "tasks (e.g., in global variables) or indicates a problem with Ray's " diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 8dc8270d9..bd9800b6b 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -61,11 +61,12 @@ class CoreWorker { /// \param[in] node_ip_address IP address of the node. /// \param[in] node_manager_port Port of the local raylet. /// \param[in] task_execution_callback Language worker callback to execute tasks. - /// \parma[in] check_signals Language worker function to check for signals and handle + /// \param[in] check_signals Language worker function to check for signals and handle /// them. If the function returns anything but StatusOK, any long-running /// operations in the core worker will short circuit and return that status. - /// \parma[in] exit_handler Language worker function to orderly shutdown the worker. + /// \param[in] exit_handler Language worker function to orderly shutdown the worker. /// We guarantee this will be run on the main thread of the worker. + /// \param[in] ref_counting_enabled Whether to enable object ref counting. /// /// NOTE(zhijunfu): the constructor would throw if a failure happens. CoreWorker(const WorkerType worker_type, const Language language, @@ -74,7 +75,8 @@ class CoreWorker { const std::string &log_dir, const std::string &node_ip_address, int node_manager_port, const TaskExecutionCallback &task_execution_callback, std::function check_signals = nullptr, - std::function exit_handler = nullptr); + std::function exit_handler = nullptr, + bool ref_counting_enabled = false); ~CoreWorker(); @@ -103,14 +105,18 @@ class CoreWorker { /// /// \param[in] object_id The object ID to increase the reference count for. void AddObjectIDReference(const ObjectID &object_id) { - reference_counter_.AddReference(object_id); + reference_counter_->AddReference(object_id); } /// Decrease the reference count for this object ID. /// /// \param[in] object_id The object ID to decrease the reference count for. void RemoveObjectIDReference(const ObjectID &object_id) { - reference_counter_.RemoveReference(object_id); + std::vector deleted; + reference_counter_->RemoveReference(object_id, &deleted); + if (ref_counting_enabled_) { + memory_store_->Delete(deleted); + } } /// Promote an object to plasma. If it already exists locally, it will be @@ -369,8 +375,10 @@ class CoreWorker { /// Private methods related to task submission. /// - /// Submit the task to the raylet and add its dependencies to the reference counter. - Status SubmitTaskToRaylet(const TaskSpecification &task_spec); + /// Add task dependencies to the reference counter. This prevents the argument + /// objects from early eviction, and also adds the return object. + void PinObjectReferences(const TaskSpecification &task_spec, + const TaskTransportType transport_type); /// Give this worker a handle to an actor. /// @@ -434,6 +442,9 @@ class CoreWorker { /// Directory where log files are written. const std::string log_dir_; + /// Whether local reference counting is enabled. + const bool ref_counting_enabled_; + /// Application-language callback to check for signals that have been received /// since calling into C++. This will be called periodically (at least every /// 1s) during long-running operations. @@ -481,7 +492,7 @@ class CoreWorker { std::thread io_thread_; // Keeps track of object ID reference counts. - ReferenceCounter reference_counter_; + std::shared_ptr reference_counter_; /// /// Fields related to storing and retrieving objects. diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index 9e1bd178f..6e8b4fcd0 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -36,12 +36,14 @@ void ReferenceCounter::SetDependencies( } } -void ReferenceCounter::RemoveReference(const ObjectID &object_id) { +void ReferenceCounter::RemoveReference(const ObjectID &object_id, + std::vector *deleted) { absl::MutexLock lock(&mutex_); - RemoveReferenceRecursive(object_id); + RemoveReferenceRecursive(object_id, deleted); } -void ReferenceCounter::RemoveReferenceRecursive(const ObjectID &object_id) { +void ReferenceCounter::RemoveReferenceRecursive(const ObjectID &object_id, + std::vector *deleted) { auto entry = object_id_refs_.find(object_id); if (entry == object_id_refs_.end()) { RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: " @@ -52,13 +54,19 @@ void ReferenceCounter::RemoveReferenceRecursive(const ObjectID &object_id) { // If the reference count reached 0, decrease the reference count for each dependency. if (entry->second.second) { for (const ObjectID &pending_task_object_id : *entry->second.second) { - RemoveReferenceRecursive(pending_task_object_id); + RemoveReferenceRecursive(pending_task_object_id, deleted); } } object_id_refs_.erase(object_id); + deleted->push_back(object_id); } } +bool ReferenceCounter::HasReference(const ObjectID &object_id) { + absl::MutexLock lock(&mutex_); + return object_id_refs_.find(object_id) != object_id_refs_.end(); +} + size_t ReferenceCounter::NumObjectIDsInScope() const { absl::MutexLock lock(&mutex_); return object_id_refs_.size(); diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index bde2f2ecd..96e5ff759 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -25,7 +25,11 @@ class ReferenceCounter { /// Decrease the reference count for the ObjectID by one. If the reference count reaches /// zero, it will be erased from the map and the reference count for all of its /// dependencies will be decreased be one. - void RemoveReference(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_); + /// + /// \param[in] object_id The object to to decrement the count for. + /// \param[in] deleted List to store objects that hit zero ref count. + void RemoveReference(const ObjectID &object_id, std::vector *deleted) + LOCKS_EXCLUDED(mutex_); /// Set the dependencies for the ObjectID. Dependencies for each ObjectID must be /// set at most once. The direct reference count for the ObjectID is set to zero and the @@ -37,6 +41,9 @@ class ReferenceCounter { /// Returns the total number of ObjectIDs currently in scope. size_t NumObjectIDsInScope() const LOCKS_EXCLUDED(mutex_); + /// Returns whether this object has an active reference. + bool HasReference(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_); + /// Returns a set of all ObjectIDs currently in scope (i.e., nonzero reference count). std::unordered_set GetAllInScopeObjectIDs() const LOCKS_EXCLUDED(mutex_); @@ -51,7 +58,10 @@ class ReferenceCounter { /// Recursive helper function for decreasing reference counts. Will recursively call /// itself on any dependencies whose reference count reaches zero as a result of /// removing the reference. - void RemoveReferenceRecursive(const ObjectID &object_id) + /// + /// \param[in] object_id The object to to decrement the count for. + /// \param[in] deleted List to store objects that hit zero ref count. + void RemoveReferenceRecursive(const ObjectID &object_id, std::vector *deleted) EXCLUSIVE_LOCKS_REQUIRED(mutex_); /// Protects access to the reference counting state. diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index 7a0f640f2..60d6c6191 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -1,7 +1,9 @@ #include #include "gtest/gtest.h" +#include "ray/common/ray_object.h" #include "ray/core_worker/reference_count.h" +#include "ray/core_worker/store_provider/memory_store/memory_store.h" namespace ray { @@ -16,6 +18,7 @@ class ReferenceCountTest : public ::testing::Test { // Tests basic incrementing/decrementing of direct reference counts. An entry should only // be removed once its reference count reaches zero. TEST_F(ReferenceCountTest, TestBasic) { + std::vector out; ObjectID id = ObjectID::FromRandom(); rc->AddReference(id); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); @@ -23,18 +26,22 @@ TEST_F(ReferenceCountTest, TestBasic) { ASSERT_EQ(rc->NumObjectIDsInScope(), 1); rc->AddReference(id); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - rc->RemoveReference(id); + rc->RemoveReference(id, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - rc->RemoveReference(id); + ASSERT_EQ(out.size(), 0); + rc->RemoveReference(id, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - rc->RemoveReference(id); + ASSERT_EQ(out.size(), 0); + rc->RemoveReference(id, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 0); + ASSERT_EQ(out.size(), 1); } // Tests the basic logic for dependencies - when an ObjectID with dependencies // goes out of scope (i.e., reference count reaches zero), all of its dependencies // should have their reference count decremented and be removed if it reaches zero. TEST_F(ReferenceCountTest, TestDependencies) { + std::vector out; ObjectID id1 = ObjectID::FromRandom(); ObjectID id2 = ObjectID::FromRandom(); ObjectID id3 = ObjectID::FromRandom(); @@ -49,13 +56,16 @@ TEST_F(ReferenceCountTest, TestDependencies) { rc->AddReference(id3); ASSERT_EQ(rc->NumObjectIDsInScope(), 3); - rc->RemoveReference(id1); + rc->RemoveReference(id1, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 3); - rc->RemoveReference(id1); + ASSERT_EQ(out.size(), 0); + rc->RemoveReference(id1, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + ASSERT_EQ(out.size(), 2); - rc->RemoveReference(id3); + rc->RemoveReference(id3, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 0); + ASSERT_EQ(out.size(), 3); } // Tests the case where two entries share the same set of dependencies. When one @@ -63,6 +73,7 @@ TEST_F(ReferenceCountTest, TestDependencies) { // but they should still be nonzero until the second entry goes out of scope and all // direct dependencies to the dependencies are removed. TEST_F(ReferenceCountTest, TestSharedDependencies) { + std::vector out; ObjectID id1 = ObjectID::FromRandom(); ObjectID id2 = ObjectID::FromRandom(); ObjectID id3 = ObjectID::FromRandom(); @@ -79,13 +90,16 @@ TEST_F(ReferenceCountTest, TestSharedDependencies) { rc->AddReference(id4); ASSERT_EQ(rc->NumObjectIDsInScope(), 4); - rc->RemoveReference(id1); + rc->RemoveReference(id1, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 3); - rc->RemoveReference(id2); + ASSERT_EQ(out.size(), 1); + rc->RemoveReference(id2, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + ASSERT_EQ(out.size(), 3); - rc->RemoveReference(id4); + rc->RemoveReference(id4, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 0); + ASSERT_EQ(out.size(), 4); } // Tests the case when an entry has a dependency that itself has a @@ -93,6 +107,7 @@ TEST_F(ReferenceCountTest, TestSharedDependencies) { // it should decrease the reference count for its dependency, causing // that entry to go out of scope and decrease its dependencies' reference counts. TEST_F(ReferenceCountTest, TestRecursiveDependencies) { + std::vector out; ObjectID id1 = ObjectID::FromRandom(); ObjectID id2 = ObjectID::FromRandom(); ObjectID id3 = ObjectID::FromRandom(); @@ -114,13 +129,42 @@ TEST_F(ReferenceCountTest, TestRecursiveDependencies) { rc->AddReference(id4); ASSERT_EQ(rc->NumObjectIDsInScope(), 4); - rc->RemoveReference(id2); + rc->RemoveReference(id2, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 4); - rc->RemoveReference(id1); + ASSERT_EQ(out.size(), 0); + rc->RemoveReference(id1, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + ASSERT_EQ(out.size(), 3); - rc->RemoveReference(id4); + rc->RemoveReference(id4, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 0); + ASSERT_EQ(out.size(), 4); +} + +// Tests that the ref counts are properly integrated into the local +// object memory store. +TEST(MemoryStoreIntegrationTest, TestSimple) { + ObjectID id1 = ObjectID::FromRandom().WithDirectTransportType(); + ObjectID id2 = ObjectID::FromRandom().WithDirectTransportType(); + uint8_t data[] = {1, 2, 3, 4, 5, 6, 7, 8}; + RayObject buffer(std::make_shared(data, sizeof(data)), nullptr); + + auto rc = std::shared_ptr(new ReferenceCounter()); + CoreWorkerMemoryStore store(nullptr, rc); + + // Tests putting an object with no references is ignored. + RAY_CHECK_OK(store.Put(id2, buffer)); + ASSERT_EQ(store.Size(), 0); + + // Tests ref counting overrides remove after get option. + rc->AddReference(id1); + RAY_CHECK_OK(store.Put(id1, buffer)); + ASSERT_EQ(store.Size(), 1); + std::vector> results; + RAY_CHECK_OK(store.Get({id1}, /*num_objects*/ 1, /*timeout_ms*/ -1, + /*remove_after_get*/ true, &results)); + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(store.Size(), 1); } } // namespace ray diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 368b8dee2..1fa82e495 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -108,8 +108,9 @@ std::shared_ptr GetRequest::Get(const ObjectID &object_id) const { } CoreWorkerMemoryStore::CoreWorkerMemoryStore( - std::function store_in_plasma) - : store_in_plasma_(store_in_plasma) {} + std::function store_in_plasma, + std::shared_ptr counter) + : store_in_plasma_(store_in_plasma), ref_counter_(counter) {} void CoreWorkerMemoryStore::GetAsync( const ObjectID &object_id, std::function)> callback) { @@ -154,6 +155,7 @@ Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &ob { absl::MutexLock lock(&mu_); + auto iter = objects_.find(object_id); if (iter != objects_.end()) { return Status::ObjectExists("object already exists in the memory store"); @@ -179,11 +181,16 @@ Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &ob auto &get_requests = object_request_iter->second; for (auto &get_request : get_requests) { get_request->Set(object_id, object_entry); - if (get_request->ShouldRemoveObjects()) { + // If ref counting is enabled, override the removal behaviour. + if (get_request->ShouldRemoveObjects() && ref_counter_ == nullptr) { should_add_entry = false; } } } + // Don't put it in the store, since we won't get a callback for deletion. + if (ref_counter_ != nullptr && !ref_counter_->HasReference(object_id)) { + should_add_entry = false; + } if (should_add_entry) { // If there is no existing get request, then add the `RayObject` to map. @@ -231,8 +238,11 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, } RAY_CHECK(count <= num_objects); - for (const auto &object_id : ids_to_remove) { - objects_.erase(object_id); + // Clean up the objects if ref counting is off. + if (ref_counter_ == nullptr) { + for (const auto &object_id : ids_to_remove) { + objects_.erase(object_id); + } } // Return if all the objects are obtained. @@ -298,8 +308,7 @@ void CoreWorkerMemoryStore::Delete(const std::vector &object_ids) { bool CoreWorkerMemoryStore::Contains(const ObjectID &object_id) { absl::MutexLock lock(&mu_); auto it = objects_.find(object_id); - // If obj is in plasma, we defer to the plasma store for the Contains() call. - return it != objects_.end() && !it->second->IsInPlasmaError(); + return it != objects_.end(); } } // namespace ray diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index 431e1e9ca..504404a35 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -7,6 +7,7 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/core_worker/common.h" +#include "ray/core_worker/reference_count.h" namespace ray { @@ -18,8 +19,14 @@ class CoreWorkerMemoryStore; /// actor call (see direct_actor_transport.cc). class CoreWorkerMemoryStore { public: + /// Create a memory store. + /// + /// \param[in] store_in_plasma If not null, this is used to spill to plasma. + /// \param[in] counter If not null, this enables ref counting for local objects, + /// and the `remove_after_get` flag for Get() will be ignored. CoreWorkerMemoryStore( - std::function store_in_plasma = nullptr); + std::function store_in_plasma = nullptr, + std::shared_ptr counter = nullptr); ~CoreWorkerMemoryStore(){}; /// Put an object with specified ID into object store. @@ -35,7 +42,7 @@ class CoreWorkerMemoryStore { /// \param[in] num_objects Number of objects that should appear. /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. /// \param[in] remove_after_get When to remove the objects from store after `Get` - /// finishes. + /// finishes. This has no effect if ref counting is enabled. /// \param[out] results Result list of objects data. /// \return Status. Status Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, @@ -70,10 +77,22 @@ class CoreWorkerMemoryStore { /// \return Whether the store has the object. bool Contains(const ObjectID &object_id); + /// Returns the number of objects in this store. + /// + /// \return Count of objects in the store. + int Size() { + absl::MutexLock lock(&mu_); + return objects_.size(); + } + private: /// Optional callback for putting objects into the plasma store. std::function store_in_plasma_; + /// If enabled, holds a reference to local worker ref counter. TODO(ekl) make this + /// mandatory once Java is supported. + std::shared_ptr ref_counter_ = nullptr; + /// Protects the data structures below. absl::Mutex mu_;