diff --git a/BUILD.bazel b/BUILD.bazel index 0ea928129..7e63f4416 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -404,6 +404,16 @@ cc_binary( ], ) +cc_test( + name = "reference_count_test", + srcs = ["src/ray/core_worker/reference_count_test.cc"], + copts = COPTS, + deps = [ + ":core_worker_lib", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "scheduling_queue_test", srcs = ["src/ray/core_worker/test/scheduling_queue_test.cc"], diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 2b31761e4..e72115ecd 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1000,17 +1000,17 @@ cdef class CoreWorker: c_actor_id, &output)) return output - def add_active_object_id(self, ObjectID object_id): + def add_object_id_reference(self, ObjectID object_id): cdef: CObjectID c_object_id = object_id.native() # Note: faster to not release GIL for short-running op. - self.core_worker.get().AddActiveObjectID(c_object_id) + self.core_worker.get().AddObjectIDReference(c_object_id) - def remove_active_object_id(self, ObjectID object_id): + def remove_object_id_reference(self, ObjectID object_id): cdef: CObjectID c_object_id = object_id.native() # Note: faster to not release GIL for short-running op. - self.core_worker.get().RemoveActiveObjectID(c_object_id) + self.core_worker.get().RemoveObjectIDReference(c_object_id) # TODO: handle noreturn better cdef store_task_outputs( diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 19c8dda53..073356385 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -101,8 +101,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CActorID DeserializeAndRegisterActorHandle(const c_string &bytes) CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string *bytes) - void AddActiveObjectID(const CObjectID &object_id) - void RemoveActiveObjectID(const CObjectID &object_id) + void AddObjectIDReference(const CObjectID &object_id) + void RemoveObjectIDReference(const CObjectID &object_id) CRayStatus SetClientOptions(c_string client_name, int64_t limit) CRayStatus Put(const CRayObject &object, CObjectID *object_id) diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 571ecedc3..65ce4755f 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -146,14 +146,14 @@ cdef class ObjectID(BaseID): # TODO(edoakes): there are dummy object IDs being created in # includes/task.pxi before the core worker is initialized. if hasattr(worker, "core_worker"): - worker.core_worker.add_active_object_id(self) + worker.core_worker.add_object_id_reference(self) self.in_core_worker = True def __dealloc__(self): if self.in_core_worker: try: worker = ray.worker.global_worker - worker.core_worker.remove_active_object_id(self) + worker.core_worker.remove_object_id_reference(self) except Exception as e: # There is a strange error in rllib that causes the above to # fail. Somehow the global 'ray' variable corresponding to the diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index d93ef6242..47971ef4c 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1,7 +1,9 @@ -#include "ray/core_worker/core_worker.h" +#include + #include "ray/common/ray_config.h" #include "ray/common/task/task_util.h" #include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" #include "ray/core_worker/transport/raylet_transport.h" namespace { @@ -141,8 +143,14 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, // Set timer to periodically send heartbeats containing active object IDs to the raylet. // If the heartbeat timeout is < 0, the heartbeats are disabled. if (RayConfig::instance().worker_heartbeat_timeout_milliseconds() >= 0) { - heartbeat_timer_.expires_from_now(boost::asio::chrono::milliseconds( - RayConfig::instance().worker_heartbeat_timeout_milliseconds())); + // Seed using current time. + std::srand(std::time(nullptr)); + // Randomly choose a time from [0, timeout]) to send the first heartbeat to avoid all + // workers sending heartbeats at the same time. + int64_t heartbeat_timeout = + std::rand() % RayConfig::instance().worker_heartbeat_timeout_milliseconds(); + heartbeat_timer_.expires_from_now( + boost::asio::chrono::milliseconds(heartbeat_timeout)); heartbeat_timer_.async_wait(boost::bind(&CoreWorker::ReportActiveObjectIDs, this)); } @@ -230,40 +238,19 @@ void CoreWorker::SetCurrentTaskId(const TaskID &task_id) { } } -void CoreWorker::AddActiveObjectID(const ObjectID &object_id) { - absl::MutexLock lock(&object_ref_mu_); - active_object_ids_.insert(object_id); - active_object_ids_updated_ = true; -} - -void CoreWorker::RemoveActiveObjectID(const ObjectID &object_id) { - absl::MutexLock lock(&object_ref_mu_); - if (active_object_ids_.erase(object_id)) { - active_object_ids_updated_ = true; - } else { - RAY_LOG(WARNING) << "Tried to erase non-existent object ID" << object_id; - } -} - void CoreWorker::ReportActiveObjectIDs() { - absl::MutexLock lock(&object_ref_mu_); - // Only send a heartbeat when the set of active object IDs has changed because the - // raylet only modifies the set of IDs when it receives a heartbeat. - // TODO(edoakes): this is currently commented out because this heartbeat causes the - // workers to die when the raylet crashes unexpectedly. Without this, they could - // hang idle forever because they wait for the raylet to push tasks via gRPC. - // if (active_object_ids_updated_) { - RAY_LOG(DEBUG) << "Sending " << active_object_ids_.size() << " object IDs to raylet."; - if (active_object_ids_.size() > RayConfig::instance().raylet_max_active_object_ids()) { - RAY_LOG(WARNING) << active_object_ids_.size() << "object IDs are currently in scope. " + std::unordered_set active_object_ids = + reference_counter_.GetAllInScopeObjectIDs(); + RAY_LOG(DEBUG) << "Sending " << active_object_ids.size() << " object IDs to raylet."; + if (active_object_ids.size() > RayConfig::instance().raylet_max_active_object_ids()) { + RAY_LOG(WARNING) << active_object_ids.size() << "object IDs are currently in scope. " << "This may lead to required objects being garbage collected."; } - std::unordered_set copy(active_object_ids_.begin(), active_object_ids_.end()); - if (!raylet_client_->ReportActiveObjectIDs(copy).ok()) { + + if (!raylet_client_->ReportActiveObjectIDs(active_object_ids).ok()) { RAY_LOG(ERROR) << "Raylet connection failed. Shutting down."; Shutdown(); } - // } // Reset the timer from the previous expiration time to avoid drift. heartbeat_timer_.expires_at( @@ -271,7 +258,6 @@ void CoreWorker::ReportActiveObjectIDs() { boost::asio::chrono::milliseconds( RayConfig::instance().worker_heartbeat_timeout_milliseconds())); heartbeat_timer_.async_wait(boost::bind(&CoreWorker::ReportActiveObjectIDs, this)); - active_object_ids_updated_ = false; } Status CoreWorker::SetClientOptions(std::string name, int64_t limit_bytes) { @@ -450,6 +436,33 @@ TaskID CoreWorker::GetCallerId() const { return caller_id; } +Status CoreWorker::SubmitTaskToRaylet(const TaskSpecification &task_spec) { + RAY_RETURN_NOT_OK(raylet_client_->SubmitTask(task_spec)); + + size_t num_returns = task_spec.NumReturns(); + if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) { + num_returns--; + } + + std::shared_ptr> task_deps = + std::make_shared>(); + for (size_t i = 0; i < task_spec.NumArgs(); i++) { + if (task_spec.ArgByRef(i)) { + for (size_t j = 0; j < task_spec.ArgIdCount(i); j++) { + task_deps->push_back(task_spec.ArgId(i, j)); + } + } + } + + if (task_deps->size() > 0) { + for (size_t i = 0; i < num_returns; i++) { + reference_counter_.SetDependencies(task_spec.ReturnId(i), task_deps); + } + } + + return Status::OK(); +} + Status CoreWorker::SubmitTask(const RayFunction &function, const std::vector &args, const TaskOptions &task_options, @@ -463,7 +476,7 @@ Status CoreWorker::SubmitTask(const RayFunction &function, worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), function, args, task_options.num_returns, task_options.resources, {}, TaskTransportType::RAYLET, return_ids); - return raylet_client_->SubmitTask(builder.Build()); + return SubmitTaskToRaylet(builder.Build()); } Status CoreWorker::CreateActor(const RayFunction &function, @@ -494,9 +507,8 @@ Status CoreWorker::CreateActor(const RayFunction &function, RAY_CHECK(AddActorHandle(std::move(actor_handle))) << "Actor " << actor_id << " already exists"; - RAY_RETURN_NOT_OK(raylet_client_->SubmitTask(builder.Build())); *return_actor_id = actor_id; - return Status::OK(); + return SubmitTaskToRaylet(builder.Build()); } Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &function, @@ -534,7 +546,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f if (is_direct_call) { status = direct_actor_submitter_->SubmitTask(builder.Build()); } else { - status = raylet_client_->SubmitTask(builder.Build()); + status = SubmitTaskToRaylet(builder.Build()); } return status; } @@ -699,6 +711,15 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, } } + if (task_spec.IsNormalTask() && reference_counter_.NumObjectIDsInScope() != 0) { + RAY_LOG(ERROR) + << "There were " << reference_counter_.NumObjectIDsInScope() + << " ObjectIDs left in scope after executing task " << task_spec.TaskId() + << ". This is either caused by keeping references to ObjectIDs in Python between " + "tasks (e.g., in global variables) or indicates a problem with Ray's " + "reference counting, and may cause problems in the object store."; + } + SetCurrentTaskId(TaskID::Nil()); worker_context_.ResetCurrentTask(task_spec); return status; diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 4cb78a030..15789b051 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -3,7 +3,6 @@ #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "absl/synchronization/mutex.h" #include "ray/common/buffer.h" @@ -11,6 +10,7 @@ #include "ray/core_worker/common.h" #include "ray/core_worker/context.h" #include "ray/core_worker/profiling.h" +#include "ray/core_worker/reference_count.h" #include "ray/core_worker/store_provider/memory_store_provider.h" #include "ray/core_worker/store_provider/plasma_store_provider.h" #include "ray/core_worker/transport/direct_actor_transport.h" @@ -97,13 +97,19 @@ class CoreWorker { actor_id_ = actor_id; } - // Add this object ID to the set of active object IDs that is sent to the raylet - // in the heartbeat messsage. - void AddActiveObjectID(const ObjectID &object_id) LOCKS_EXCLUDED(object_ref_mu_); + /// Increase the reference count for this object ID. + /// + /// \param[in] object_id The object ID to increase the reference count for. + void AddObjectIDReference(const ObjectID &object_id) { + reference_counter_.AddReference(object_id); + } - // Remove this object ID from the set of active object IDs that is sent to the raylet - // in the heartbeat messsage. - void RemoveActiveObjectID(const ObjectID &object_id) LOCKS_EXCLUDED(object_ref_mu_); + /// Decrease the reference count for this object ID. + /// + /// \param[in] object_id The object ID to decrease the reference count for. + void RemoveObjectIDReference(const ObjectID &object_id) { + reference_counter_.RemoveReference(object_id); + } /// /// Public methods related to storing and retrieving objects. @@ -343,12 +349,15 @@ class CoreWorker { void Shutdown(); /// Send the list of active object IDs to the raylet. - void ReportActiveObjectIDs() LOCKS_EXCLUDED(object_ref_mu_); + void ReportActiveObjectIDs(); /// /// Private methods related to task submission. /// + /// Submit the task to the raylet and add its dependencies to the reference counter. + Status SubmitTaskToRaylet(const TaskSpecification &task_spec); + /// Give this worker a handle to an actor. /// /// This handle will remain as long as the current actor or task is @@ -451,21 +460,8 @@ class CoreWorker { // Thread that runs a boost::asio service to process IO events. std::thread io_thread_; - /// - /// Fields related to ref counting objects. - /// - - /// Protects access to the set of active object ids. Since this set is updated - /// very frequently, it is faster to lock around accesses rather than serialize - /// accesses via the event loop. - absl::Mutex object_ref_mu_; - - /// Set of object IDs that are in scope in the language worker. - absl::flat_hash_set active_object_ids_ GUARDED_BY(object_ref_mu_); - - /// Indicates whether or not the active_object_ids map has changed since the - /// last time it was sent to the raylet. - bool active_object_ids_updated_ GUARDED_BY(object_ref_mu_) = false; + // Keeps track of object ID reference counts. + ReferenceCounter reference_counter_; /// /// Fields related to storing and retrieving objects. diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc new file mode 100644 index 000000000..9e1bd178f --- /dev/null +++ b/src/ray/core_worker/reference_count.cc @@ -0,0 +1,100 @@ +#include "ray/core_worker/reference_count.h" + +namespace ray { + +void ReferenceCounter::AddReference(const ObjectID &object_id) { + absl::MutexLock lock(&mutex_); + AddReferenceInternal(object_id); +} + +void ReferenceCounter::AddReferenceInternal(const ObjectID &object_id) { + auto entry = object_id_refs_.find(object_id); + if (entry == object_id_refs_.end()) { + object_id_refs_[object_id] = std::make_pair(1, nullptr); + } else { + entry->second.first++; + } +} + +void ReferenceCounter::SetDependencies( + const ObjectID &object_id, std::shared_ptr> dependencies) { + absl::MutexLock lock(&mutex_); + + auto entry = object_id_refs_.find(object_id); + if (entry == object_id_refs_.end()) { + // If the entry doesn't exist, we initialize the direct reference count to zero + // because this corresponds to a submitted task whose return ObjectID will be created + // in the frontend language, incrementing the reference count. + object_id_refs_[object_id] = std::make_pair(0, dependencies); + } else { + RAY_CHECK(!entry->second.second); + entry->second.second = dependencies; + } + + for (const ObjectID &dependency_id : *dependencies) { + AddReferenceInternal(dependency_id); + } +} + +void ReferenceCounter::RemoveReference(const ObjectID &object_id) { + absl::MutexLock lock(&mutex_); + RemoveReferenceRecursive(object_id); +} + +void ReferenceCounter::RemoveReferenceRecursive(const ObjectID &object_id) { + auto entry = object_id_refs_.find(object_id); + if (entry == object_id_refs_.end()) { + RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: " + << object_id; + return; + } + if (--entry->second.first == 0) { + // If the reference count reached 0, decrease the reference count for each dependency. + if (entry->second.second) { + for (const ObjectID &pending_task_object_id : *entry->second.second) { + RemoveReferenceRecursive(pending_task_object_id); + } + } + object_id_refs_.erase(object_id); + } +} + +size_t ReferenceCounter::NumObjectIDsInScope() const { + absl::MutexLock lock(&mutex_); + return object_id_refs_.size(); +} + +std::unordered_set ReferenceCounter::GetAllInScopeObjectIDs() const { + absl::MutexLock lock(&mutex_); + std::unordered_set in_scope_object_ids; + in_scope_object_ids.reserve(object_id_refs_.size()); + for (auto it : object_id_refs_) { + in_scope_object_ids.insert(it.first); + } + return in_scope_object_ids; +} + +void ReferenceCounter::LogDebugString() const { + absl::MutexLock lock(&mutex_); + + RAY_LOG(DEBUG) << "ReferenceCounter state:"; + if (object_id_refs_.empty()) { + RAY_LOG(DEBUG) << "\tEMPTY"; + return; + } + + for (const auto &entry : object_id_refs_) { + RAY_LOG(DEBUG) << "\t" << entry.first.Hex(); + RAY_LOG(DEBUG) << "\t\treference count: " << entry.second.first; + RAY_LOG(DEBUG) << "\t\tdependencies: "; + if (!entry.second.second) { + RAY_LOG(DEBUG) << "\t\t\tNULL"; + } else { + for (const ObjectID &pending_task_object_id : *entry.second.second) { + RAY_LOG(DEBUG) << "\t\t\t" << pending_task_object_id.Hex(); + } + } + } +} + +} // namespace ray diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h new file mode 100644 index 000000000..bde2f2ecd --- /dev/null +++ b/src/ray/core_worker/reference_count.h @@ -0,0 +1,70 @@ +#ifndef RAY_CORE_WORKER_REF_COUNT_H +#define RAY_CORE_WORKER_REF_COUNT_H + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" + +#include "ray/common/id.h" +#include "ray/util/logging.h" + +namespace ray { + +/// Class used by the core worker to keep track of ObjectID reference counts for garbage +/// collection. This class is thread safe. +class ReferenceCounter { + public: + ReferenceCounter() {} + + ~ReferenceCounter() {} + + /// Increase the reference count for the ObjectID by one. If there is no + /// entry for the ObjectID, one will be created with no dependencies. + void AddReference(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_); + + /// Decrease the reference count for the ObjectID by one. If the reference count reaches + /// zero, it will be erased from the map and the reference count for all of its + /// dependencies will be decreased be one. + void RemoveReference(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_); + + /// Set the dependencies for the ObjectID. Dependencies for each ObjectID must be + /// set at most once. The direct reference count for the ObjectID is set to zero and the + /// reference count for each dependency is incremented. + void SetDependencies(const ObjectID &object_id, + std::shared_ptr> dependencies) + LOCKS_EXCLUDED(mutex_); + + /// Returns the total number of ObjectIDs currently in scope. + size_t NumObjectIDsInScope() const LOCKS_EXCLUDED(mutex_); + + /// Returns a set of all ObjectIDs currently in scope (i.e., nonzero reference count). + std::unordered_set GetAllInScopeObjectIDs() const LOCKS_EXCLUDED(mutex_); + + /// Dumps information about all currently tracked references to RAY_LOG(DEBUG). + void LogDebugString() const LOCKS_EXCLUDED(mutex_); + + private: + /// Helper function with the same semantics as AddReference to allow adding a reference + /// while already holding mutex_. + void AddReferenceInternal(const ObjectID &object_id) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + /// Recursive helper function for decreasing reference counts. Will recursively call + /// itself on any dependencies whose reference count reaches zero as a result of + /// removing the reference. + void RemoveReferenceRecursive(const ObjectID &object_id) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + /// Protects access to the reference counting state. + mutable absl::Mutex mutex_; + + /// Holds all direct reference counts and dependency information for tracked ObjectIDs. + /// Dependencies are stored as shared_ptrs because the same set of dependencies can be + /// shared among multiple entries. For example, when a task has multiple return values, + /// the entry for each return ObjectID depends on all task dependencies. + absl::flat_hash_map>>> + object_id_refs_ GUARDED_BY(mutex_); +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_REF_COUNT_H diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc new file mode 100644 index 000000000..7a0f640f2 --- /dev/null +++ b/src/ray/core_worker/reference_count_test.cc @@ -0,0 +1,131 @@ +#include + +#include "gtest/gtest.h" +#include "ray/core_worker/reference_count.h" + +namespace ray { + +class ReferenceCountTest : public ::testing::Test { + protected: + std::unique_ptr rc; + virtual void SetUp() { rc = std::unique_ptr(new ReferenceCounter); } + + virtual void TearDown() {} +}; + +// Tests basic incrementing/decrementing of direct reference counts. An entry should only +// be removed once its reference count reaches zero. +TEST_F(ReferenceCountTest, TestBasic) { + ObjectID id = ObjectID::FromRandom(); + rc->AddReference(id); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + rc->AddReference(id); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + rc->AddReference(id); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + rc->RemoveReference(id); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + rc->RemoveReference(id); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + rc->RemoveReference(id); + ASSERT_EQ(rc->NumObjectIDsInScope(), 0); +} + +// Tests the basic logic for dependencies - when an ObjectID with dependencies +// goes out of scope (i.e., reference count reaches zero), all of its dependencies +// should have their reference count decremented and be removed if it reaches zero. +TEST_F(ReferenceCountTest, TestDependencies) { + ObjectID id1 = ObjectID::FromRandom(); + ObjectID id2 = ObjectID::FromRandom(); + ObjectID id3 = ObjectID::FromRandom(); + + std::shared_ptr> deps = std::make_shared>(); + deps->push_back(id2); + deps->push_back(id3); + rc->SetDependencies(id1, deps); + + rc->AddReference(id1); + rc->AddReference(id1); + rc->AddReference(id3); + ASSERT_EQ(rc->NumObjectIDsInScope(), 3); + + rc->RemoveReference(id1); + ASSERT_EQ(rc->NumObjectIDsInScope(), 3); + rc->RemoveReference(id1); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + + rc->RemoveReference(id3); + ASSERT_EQ(rc->NumObjectIDsInScope(), 0); +} + +// Tests the case where two entries share the same set of dependencies. When one +// entry goes out of scope, it should decrease the reference count for the dependencies +// but they should still be nonzero until the second entry goes out of scope and all +// direct dependencies to the dependencies are removed. +TEST_F(ReferenceCountTest, TestSharedDependencies) { + ObjectID id1 = ObjectID::FromRandom(); + ObjectID id2 = ObjectID::FromRandom(); + ObjectID id3 = ObjectID::FromRandom(); + ObjectID id4 = ObjectID::FromRandom(); + + std::shared_ptr> deps = std::make_shared>(); + deps->push_back(id3); + deps->push_back(id4); + rc->SetDependencies(id1, deps); + rc->SetDependencies(id2, deps); + + rc->AddReference(id1); + rc->AddReference(id2); + rc->AddReference(id4); + ASSERT_EQ(rc->NumObjectIDsInScope(), 4); + + rc->RemoveReference(id1); + ASSERT_EQ(rc->NumObjectIDsInScope(), 3); + rc->RemoveReference(id2); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + + rc->RemoveReference(id4); + ASSERT_EQ(rc->NumObjectIDsInScope(), 0); +} + +// Tests the case when an entry has a dependency that itself has a +// dependency. In this case, when the first entry goes out of scope +// it should decrease the reference count for its dependency, causing +// that entry to go out of scope and decrease its dependencies' reference counts. +TEST_F(ReferenceCountTest, TestRecursiveDependencies) { + ObjectID id1 = ObjectID::FromRandom(); + ObjectID id2 = ObjectID::FromRandom(); + ObjectID id3 = ObjectID::FromRandom(); + ObjectID id4 = ObjectID::FromRandom(); + + std::shared_ptr> deps1 = + std::make_shared>(); + deps1->push_back(id2); + rc->SetDependencies(id1, deps1); + + std::shared_ptr> deps2 = + std::make_shared>(); + deps2->push_back(id3); + deps2->push_back(id4); + rc->SetDependencies(id2, deps2); + + rc->AddReference(id1); + rc->AddReference(id2); + rc->AddReference(id4); + ASSERT_EQ(rc->NumObjectIDsInScope(), 4); + + rc->RemoveReference(id2); + ASSERT_EQ(rc->NumObjectIDsInScope(), 4); + rc->RemoveReference(id1); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + + rc->RemoveReference(id4); + ASSERT_EQ(rc->NumObjectIDsInScope(), 0); +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}