From 7f52d019ca6d0a417a7c1cdb62205da93f3bbb09 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 21 Nov 2019 10:13:53 -0800 Subject: [PATCH] Inline memory_store_provider into memory_store (#6217) --- src/ray/core_worker/core_worker.cc | 26 +++--- src/ray/core_worker/core_worker.h | 7 +- src/ray/core_worker/reference_count_test.cc | 4 +- .../memory_store/memory_store.cc | 12 ++- .../memory_store/memory_store.h | 53 +++++++++++- .../store_provider/memory_store_provider.cc | 82 ------------------- .../store_provider/memory_store_provider.h | 52 ------------ src/ray/core_worker/test/core_worker_test.cc | 9 +- .../test/direct_actor_transport_test.cc | 7 +- .../test/direct_task_transport_test.cc | 38 +++------ .../transport/dependency_resolver.h | 10 +-- .../transport/direct_actor_transport.cc | 7 +- .../transport/direct_actor_transport.h | 20 ++--- .../transport/direct_task_transport.h | 14 ++-- 14 files changed, 118 insertions(+), 223 deletions(-) delete mode 100644 src/ray/core_worker/store_provider/memory_store_provider.cc delete mode 100644 src/ray/core_worker/store_provider/memory_store_provider.h diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index bf610a7cc..8201afc45 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -166,7 +166,6 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id)); }, ref_counting_enabled ? reference_counter_ : nullptr, raylet_client_)); - memory_store_provider_.reset(new CoreWorkerMemoryStoreProvider(memory_store_)); // Create an entry for the driver task in the task table. This task is // added immediately with status RUNNING. This allows us to push errors @@ -195,7 +194,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, new rpc::CoreWorkerClient(addr.first, addr.second, *client_call_manager_)); }; direct_actor_submitter_ = std::unique_ptr( - new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_provider_)); + new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_)); direct_task_submitter_ = std::unique_ptr(new CoreWorkerDirectTaskSubmitter( @@ -206,7 +205,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, return std::shared_ptr( new RayletClient(std::move(grpc_client))); }, - memory_store_provider_)); + memory_store_)); } CoreWorker::~CoreWorker() { @@ -345,9 +344,8 @@ Status CoreWorker::Get(const std::vector &ids, const int64_t timeout_m local_timeout_ms = std::max(static_cast(0), timeout_ms - (current_time_ms() - start_time)); } - RAY_RETURN_NOT_OK(memory_store_provider_->Get(memory_object_ids, local_timeout_ms, - worker_context_, &result_map, - &got_exception)); + RAY_RETURN_NOT_OK(memory_store_->Get(memory_object_ids, local_timeout_ms, + worker_context_, &result_map, &got_exception)); } // If any of the objects have been promoted to plasma, then we retry their @@ -400,7 +398,7 @@ Status CoreWorker::Contains(const ObjectID &object_id, bool *has_object) { if (object_id.IsDirectCallType()) { // Note that the memory store returns false if the object value is // ErrorType::OBJECT_IN_PLASMA. - RAY_RETURN_NOT_OK(memory_store_provider_->Contains(object_id, &found)); + found = memory_store_->Contains(object_id); } if (!found) { // We check plasma as a fallback in all cases, since a direct call object @@ -451,9 +449,9 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, if (static_cast(ready.size()) < num_objects && memory_object_ids.size() > 0) { // TODO(ekl) for memory objects that are ErrorType::OBJECT_IN_PLASMA, we should // consider waiting on them in plasma as well to ensure they are local. - RAY_RETURN_NOT_OK(memory_store_provider_->Wait( - memory_object_ids, num_objects - static_cast(ready.size()), - /*timeout_ms=*/0, worker_context_, &ready)); + RAY_RETURN_NOT_OK(memory_store_->Wait(memory_object_ids, + num_objects - static_cast(ready.size()), + /*timeout_ms=*/0, worker_context_, &ready)); } RAY_CHECK(static_cast(ready.size()) <= num_objects); @@ -474,9 +472,9 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, std::max(0, static_cast(timeout_ms - (current_time_ms() - start_time))); } if (static_cast(ready.size()) < num_objects && memory_object_ids.size() > 0) { - RAY_RETURN_NOT_OK(memory_store_provider_->Wait( - memory_object_ids, num_objects - static_cast(ready.size()), timeout_ms, - worker_context_, &ready)); + RAY_RETURN_NOT_OK(memory_store_->Wait(memory_object_ids, + num_objects - static_cast(ready.size()), + timeout_ms, worker_context_, &ready)); } RAY_CHECK(static_cast(ready.size()) <= num_objects); } @@ -498,7 +496,7 @@ Status CoreWorker::Delete(const std::vector &object_ids, bool local_on RAY_RETURN_NOT_OK(plasma_store_provider_->Delete(plasma_object_ids, local_only, delete_creating_tasks)); - RAY_RETURN_NOT_OK(memory_store_provider_->Delete(memory_object_ids)); + memory_store_->Delete(memory_object_ids); return Status::OK(); } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 67cae483f..c0299499d 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -9,7 +9,7 @@ #include "ray/core_worker/context.h" #include "ray/core_worker/profiling.h" #include "ray/core_worker/reference_count.h" -#include "ray/core_worker/store_provider/memory_store_provider.h" +#include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/store_provider/plasma_store_provider.h" #include "ray/core_worker/transport/direct_actor_transport.h" #include "ray/core_worker/transport/direct_task_transport.h" @@ -495,15 +495,12 @@ class CoreWorker { /// Fields related to storing and retrieving objects. /// - /// In-memory store for return objects. This is used for `MEMORY` store provider. + /// In-memory store for return objects. std::shared_ptr memory_store_; /// Plasma store interface. std::shared_ptr plasma_store_provider_; - /// In-memory store interface. - std::shared_ptr memory_store_provider_; - /// /// Fields related to task submission. /// diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index c904f22db..004d74efb 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -153,12 +153,12 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { CoreWorkerMemoryStore store(nullptr, rc); // Tests putting an object with no references is ignored. - RAY_CHECK_OK(store.Put(id2, buffer)); + RAY_CHECK_OK(store.Put(buffer, id2)); ASSERT_EQ(store.Size(), 0); // Tests ref counting overrides remove after get option. rc->AddReference(id1); - RAY_CHECK_OK(store.Put(id1, buffer)); + RAY_CHECK_OK(store.Put(buffer, id1)); ASSERT_EQ(store.Size(), 1); std::vector> results; WorkerContext ctx(WorkerType::WORKER, JobID::Nil()); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 1154645bc..e8525f788 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -2,7 +2,6 @@ #include "ray/common/ray_config.h" #include "ray/core_worker/context.h" #include "ray/core_worker/core_worker.h" -#include "ray/core_worker/store_provider/memory_store_provider.h" namespace ray { @@ -150,7 +149,7 @@ std::shared_ptr CoreWorkerMemoryStore::GetOrPromoteToPlasma( return nullptr; } -Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &object) { +Status CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_id) { RAY_CHECK(object_id.IsDirectCallType()); std::vector)>> async_callbacks; auto object_entry = @@ -161,7 +160,7 @@ Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &ob auto iter = objects_.find(object_id); if (iter != objects_.end()) { - return Status::ObjectExists("object already exists in the memory store"); + return Status::OK(); // Object already exists in the store, which is fine. } auto async_callback_it = object_async_get_requests_.find(object_id); @@ -313,6 +312,13 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, } } +void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set &object_ids) { + absl::MutexLock lock(&mu_); + for (const auto &object_id : object_ids) { + objects_.erase(object_id); + } +} + void CoreWorkerMemoryStore::Delete(const std::vector &object_ids) { absl::MutexLock lock(&mu_); for (const auto &object_id : object_ids) { diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index ef94e3373..ab6fe592a 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -34,10 +34,10 @@ class CoreWorkerMemoryStore { /// Put an object with specified ID into object store. /// - /// \param[in] object_id Object ID specified by user. /// \param[in] object The ray object. + /// \param[in] object_id Object ID specified by user. /// \return Status. - Status Put(const ObjectID &object_id, const RayObject &object); + Status Put(const RayObject &object, const ObjectID &object_id); /// Get a list of objects from the object store. /// @@ -53,6 +53,49 @@ class CoreWorkerMemoryStore { const WorkerContext &ctx, bool remove_after_get, std::vector> *results); + /// Convenience wrapper around Get() that stores results in a given result map. + Status Get(const absl::flat_hash_set &object_ids, int64_t timeout_ms, + const WorkerContext &ctx, + absl::flat_hash_map> *results, + bool *got_exception) { + const std::vector id_vector(object_ids.begin(), object_ids.end()); + std::vector> result_objects; + RAY_RETURN_NOT_OK( + Get(id_vector, id_vector.size(), timeout_ms, ctx, true, &result_objects)); + + for (size_t i = 0; i < id_vector.size(); i++) { + if (result_objects[i] != nullptr) { + (*results)[id_vector[i]] = result_objects[i]; + if (result_objects[i]->IsException()) { + *got_exception = true; + } + } + } + return Status::OK(); + } + + /// Convenience wrapper around Get() that stores ready objects in a given result set. + Status Wait(const absl::flat_hash_set &object_ids, int num_objects, + int64_t timeout_ms, const WorkerContext &ctx, + absl::flat_hash_set *ready) { + std::vector id_vector(object_ids.begin(), object_ids.end()); + std::vector> result_objects; + RAY_CHECK(object_ids.size() == id_vector.size()); + auto status = Get(id_vector, num_objects, timeout_ms, ctx, false, &result_objects); + // Ignore TimedOut statuses since we return ready objects explicitly. + if (!status.IsTimedOut()) { + RAY_RETURN_NOT_OK(status); + } + + for (size_t i = 0; i < id_vector.size(); i++) { + if (result_objects[i] != nullptr) { + ready->insert(id_vector[i]); + } + } + + return Status::OK(); + } + /// Asynchronously get an object from the object store. The object will not be removed /// from storage after GetAsync (TODO(ekl): integrate this with object GC). /// @@ -70,6 +113,12 @@ class CoreWorkerMemoryStore { /// \return pointer to the local object, or nullptr if promoted to plasma. std::shared_ptr GetOrPromoteToPlasma(const ObjectID &object_id); + /// Delete a list of objects from the object store. + /// + /// \param[in] object_ids IDs of the objects to delete. + /// \return Void. + void Delete(const absl::flat_hash_set &object_ids); + /// Delete a list of objects from the object store. /// /// \param[in] object_ids IDs of the objects to delete. diff --git a/src/ray/core_worker/store_provider/memory_store_provider.cc b/src/ray/core_worker/store_provider/memory_store_provider.cc deleted file mode 100644 index 773c8c2dd..000000000 --- a/src/ray/core_worker/store_provider/memory_store_provider.cc +++ /dev/null @@ -1,82 +0,0 @@ -#include "ray/core_worker/store_provider/memory_store_provider.h" -#include -#include "ray/common/ray_config.h" -#include "ray/core_worker/context.h" -#include "ray/core_worker/core_worker.h" - -namespace ray { - -CoreWorkerMemoryStoreProvider::CoreWorkerMemoryStoreProvider( - std::shared_ptr store) - : store_(store) { - RAY_CHECK(store != nullptr); -} - -Status CoreWorkerMemoryStoreProvider::Put(const RayObject &object, - const ObjectID &object_id) { - RAY_CHECK(object_id.IsDirectCallType()); - Status status = store_->Put(object_id, object); - if (status.IsObjectExists()) { - // Object already exists in store, treat it as ok. - return Status::OK(); - } - return status; -} - -Status CoreWorkerMemoryStoreProvider::Get( - const absl::flat_hash_set &object_ids, int64_t timeout_ms, - const WorkerContext &ctx, - absl::flat_hash_map> *results, - bool *got_exception) { - const std::vector id_vector(object_ids.begin(), object_ids.end()); - std::vector> result_objects; - RAY_RETURN_NOT_OK( - store_->Get(id_vector, id_vector.size(), timeout_ms, ctx, true, &result_objects)); - - for (size_t i = 0; i < id_vector.size(); i++) { - if (result_objects[i] != nullptr) { - (*results)[id_vector[i]] = result_objects[i]; - if (result_objects[i]->IsException()) { - *got_exception = true; - } - } - } - return Status::OK(); -} - -Status CoreWorkerMemoryStoreProvider::Contains(const ObjectID &object_id, - bool *has_object) { - *has_object = store_->Contains(object_id); - return Status::OK(); -} - -Status CoreWorkerMemoryStoreProvider::Wait( - const absl::flat_hash_set &object_ids, int num_objects, int64_t timeout_ms, - const WorkerContext &ctx, absl::flat_hash_set *ready) { - std::vector id_vector(object_ids.begin(), object_ids.end()); - std::vector> result_objects; - RAY_CHECK(object_ids.size() == id_vector.size()); - auto status = - store_->Get(id_vector, num_objects, timeout_ms, ctx, false, &result_objects); - // Ignore TimedOut statuses since we return ready objects explicitly. - if (!status.IsTimedOut()) { - RAY_RETURN_NOT_OK(status); - } - - for (size_t i = 0; i < id_vector.size(); i++) { - if (result_objects[i] != nullptr) { - ready->insert(id_vector[i]); - } - } - - return Status::OK(); -} - -Status CoreWorkerMemoryStoreProvider::Delete( - const absl::flat_hash_set &object_ids) { - std::vector object_id_vector(object_ids.begin(), object_ids.end()); - store_->Delete(object_id_vector); - return Status::OK(); -} - -} // namespace ray diff --git a/src/ray/core_worker/store_provider/memory_store_provider.h b/src/ray/core_worker/store_provider/memory_store_provider.h deleted file mode 100644 index ce07f6633..000000000 --- a/src/ray/core_worker/store_provider/memory_store_provider.h +++ /dev/null @@ -1,52 +0,0 @@ -#ifndef RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H -#define RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "ray/common/buffer.h" -#include "ray/common/id.h" -#include "ray/common/status.h" -#include "ray/core_worker/common.h" -#include "ray/core_worker/context.h" -#include "ray/core_worker/store_provider/memory_store/memory_store.h" - -namespace ray { - -/// The class provides implementations for accessing local process memory store. -/// An example usage for this is to retrieve the returned objects from direct -/// actor call (see direct_actor_transport.cc). -/// See `CoreWorkerStoreProvider` for the semantics of public methods. -class CoreWorkerMemoryStoreProvider { - public: - CoreWorkerMemoryStoreProvider(std::shared_ptr store); - - void GetAsync(const ObjectID &object_id, - std::function)> callback) { - store_->GetAsync(object_id, callback); - } - - Status Put(const RayObject &object, const ObjectID &object_id); - - Status Get(const absl::flat_hash_set &object_ids, int64_t timeout_ms, - const WorkerContext &ctx, - absl::flat_hash_map> *results, - bool *got_exception); - - Status Contains(const ObjectID &object_id, bool *has_object); - - /// Note that `num_objects` must equal to number of items in `object_ids`. - Status Wait(const absl::flat_hash_set &object_ids, int num_objects, - int64_t timeout_ms, const WorkerContext &ctx, - absl::flat_hash_set *ready); - - /// Note that `local_only` must be true, and `delete_creating_tasks` must be false here. - Status Delete(const absl::flat_hash_set &object_ids); - - private: - /// Implementation. - std::shared_ptr store_; -}; - -} // namespace ray - -#endif // RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index fe16a624f..8fde21c06 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -10,7 +10,7 @@ #include "ray/core_worker/core_worker.h" #include "ray/core_worker/transport/direct_actor_transport.h" -#include "ray/core_worker/store_provider/memory_store_provider.h" +#include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/raylet/raylet_client.h" #include "src/ray/protobuf/core_worker.pb.h" @@ -619,11 +619,8 @@ TEST_F(ZeroNodeTest, TestActorHandle) { } TEST_F(SingleNodeTest, TestMemoryStoreProvider) { - std::shared_ptr memory_store = + std::shared_ptr provider_ptr = std::make_shared(); - std::unique_ptr provider_ptr = - std::unique_ptr( - new CoreWorkerMemoryStoreProvider(memory_store)); auto &provider = *provider_ptr; @@ -682,7 +679,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { // clear the reference held. results.clear(); - RAY_CHECK_OK(provider.Delete(ids_set)); + provider.Delete(ids_set); usleep(200 * 1000); ASSERT_TRUE(provider.Get(ids_set, 0, ctx, &results, &got_exception).IsTimedOut()); diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index 0a8707f07..1ef7a006d 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -2,7 +2,6 @@ #include "ray/common/task/task_spec.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" -#include "ray/core_worker/store_provider/memory_store_provider.h" #include "ray/core_worker/transport/direct_task_transport.h" #include "ray/raylet/raylet_client.h" #include "ray/rpc/worker/core_worker_client.h" @@ -38,14 +37,12 @@ class DirectActorTransportTest : public ::testing::Test { public: DirectActorTransportTest() : worker_client_(std::shared_ptr(new MockWorkerClient())), - ptr_(std::shared_ptr(new CoreWorkerMemoryStore())), - store_(std::make_shared(ptr_)), + store_(std::shared_ptr(new CoreWorkerMemoryStore())), submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; }, store_) {} std::shared_ptr worker_client_; - std::shared_ptr ptr_; - std::shared_ptr store_; + std::shared_ptr store_; CoreWorkerDirectActorTaskSubmitter submitter_; }; diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 1bb35d0f0..0eaed321c 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -2,7 +2,6 @@ #include "ray/common/task/task_spec.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" -#include "ray/core_worker/store_provider/memory_store_provider.h" #include "ray/core_worker/transport/direct_task_transport.h" #include "ray/raylet/raylet_client.h" #include "ray/rpc/worker/core_worker_client.h" @@ -74,7 +73,7 @@ TEST(TestMemoryStore, TestPromoteToPlasma) { ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); auto data = GenerateRandomObject(); - ASSERT_TRUE(mem->Put(obj1, *data).ok()); + ASSERT_TRUE(mem->Put(*data, obj1).ok()); // Test getting an already existing object. ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj1) != nullptr); @@ -83,7 +82,7 @@ TEST(TestMemoryStore, TestPromoteToPlasma) { // Testing getting an object that doesn't exist yet causes promotion. ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) == nullptr); ASSERT_TRUE(num_plasma_puts == 0); - ASSERT_TRUE(mem->Put(obj2, *data).ok()); + ASSERT_TRUE(mem->Put(*data, obj2).ok()); ASSERT_TRUE(num_plasma_puts == 1); // The next time you get it, it's already there so no need to promote. @@ -92,8 +91,7 @@ TEST(TestMemoryStore, TestPromoteToPlasma) { } TEST(LocalDependencyResolverTest, TestNoDependencies) { - auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - auto store = std::make_shared(ptr); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); LocalDependencyResolver resolver(store); TaskSpecification task; bool ok = false; @@ -102,8 +100,7 @@ TEST(LocalDependencyResolverTest, TestNoDependencies) { } TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) { - auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - auto store = std::make_shared(ptr); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); LocalDependencyResolver resolver(store); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET); TaskSpecification task; @@ -116,8 +113,7 @@ TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) { } TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { - auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - auto store = std::make_shared(ptr); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); LocalDependencyResolver resolver(store); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); @@ -138,8 +134,7 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { } TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { - auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - auto store = std::make_shared(ptr); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); LocalDependencyResolver resolver(store); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); @@ -162,8 +157,7 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { } TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { - auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - auto store = std::make_shared(ptr); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); LocalDependencyResolver resolver(store); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); @@ -190,8 +184,7 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { TEST(DirectTaskTransportTest, TestSubmitOneTask) { auto raylet_client = std::make_shared(); auto worker_client = std::shared_ptr(new MockWorkerClient()); - auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - auto store = std::make_shared(ptr); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task; @@ -212,8 +205,7 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) { TEST(DirectTaskTransportTest, TestHandleTaskFailure) { auto raylet_client = std::make_shared(); auto worker_client = std::shared_ptr(new MockWorkerClient()); - auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - auto store = std::make_shared(ptr); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task; @@ -230,8 +222,7 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) { TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { auto raylet_client = std::make_shared(); auto worker_client = std::shared_ptr(new MockWorkerClient()); - auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - auto store = std::make_shared(ptr); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task1; @@ -271,8 +262,7 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { TEST(DirectTaskTransportTest, TestReuseWorkerLease) { auto raylet_client = std::make_shared(); auto worker_client = std::shared_ptr(new MockWorkerClient()); - auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - auto store = std::make_shared(ptr); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task1; @@ -314,8 +304,7 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { auto raylet_client = std::make_shared(); auto worker_client = std::shared_ptr(new MockWorkerClient()); - auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - auto store = std::make_shared(ptr); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task1; @@ -347,8 +336,7 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { TEST(DirectTaskTransportTest, TestSpillback) { auto raylet_client = std::make_shared(); auto worker_client = std::shared_ptr(new MockWorkerClient()); - auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - auto store = std::make_shared(ptr); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; std::unordered_map> remote_lease_clients; diff --git a/src/ray/core_worker/transport/dependency_resolver.h b/src/ray/core_worker/transport/dependency_resolver.h index f6ee49d9d..16644c2bc 100644 --- a/src/ray/core_worker/transport/dependency_resolver.h +++ b/src/ray/core_worker/transport/dependency_resolver.h @@ -5,15 +5,15 @@ #include "ray/common/id.h" #include "ray/common/task/task_spec.h" -#include "ray/core_worker/store_provider/memory_store_provider.h" +#include "ray/core_worker/store_provider/memory_store/memory_store.h" namespace ray { // This class is thread-safe. class LocalDependencyResolver { public: - LocalDependencyResolver(std::shared_ptr store_provider) - : in_memory_store_(store_provider), num_pending_(0) {} + LocalDependencyResolver(std::shared_ptr store) + : in_memory_store_(store), num_pending_(0) {} /// Resolve all local and remote dependencies for the task, calling the specified /// callback when done. Direct call ids in the task specification will be resolved @@ -30,8 +30,8 @@ class LocalDependencyResolver { int NumPendingTasks() const { return num_pending_; } private: - /// The store provider. - std::shared_ptr in_memory_store_; + /// The in-memory store. + std::shared_ptr in_memory_store_; /// Number of tasks pending dependency resolution. std::atomic num_pending_; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 539808b65..4e19b0559 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -11,7 +11,7 @@ int64_t GetRequestNumber(const std::unique_ptr &request) { void TreatTaskAsFailed(const TaskID &task_id, int num_returns, const rpc::ErrorType &error_type, - std::shared_ptr &in_memory_store) { + std::shared_ptr &in_memory_store) { RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id << ", error_type: " << ErrorType_Name(error_type); for (int i = 0; i < num_returns; i++) { @@ -25,9 +25,8 @@ void TreatTaskAsFailed(const TaskID &task_id, int num_returns, } } -void WriteObjectsToMemoryStore( - const rpc::PushTaskReply &reply, - std::shared_ptr &in_memory_store) { +void WriteObjectsToMemoryStore(const rpc::PushTaskReply &reply, + std::shared_ptr &in_memory_store) { for (int i = 0; i < reply.return_objects_size(); i++) { const auto &return_object = reply.return_objects(i); ObjectID object_id = ObjectID::FromBinary(return_object.object_id()); diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 5b0b3acf2..a1f52da1d 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -14,7 +14,7 @@ #include "ray/common/id.h" #include "ray/common/ray_object.h" #include "ray/core_worker/context.h" -#include "ray/core_worker/store_provider/memory_store_provider.h" +#include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/transport/dependency_resolver.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/grpc_server.h" @@ -36,16 +36,15 @@ const int kMaxReorderWaitSeconds = 30; /// \return Void. void TreatTaskAsFailed(const TaskID &task_id, int num_returns, const rpc::ErrorType &error_type, - std::shared_ptr &in_memory_store); + std::shared_ptr &in_memory_store); /// Write return objects to the memory store. /// /// \param[in] reply Proto response to a direct actor or task call. /// \param[in] in_memory_store The memory store to write to. /// \return Void. -void WriteObjectsToMemoryStore( - const rpc::PushTaskReply &reply, - std::shared_ptr &in_memory_store); +void WriteObjectsToMemoryStore(const rpc::PushTaskReply &reply, + std::shared_ptr &in_memory_store); /// In direct actor call task submitter and receiver, a task is directly submitted /// to the actor that will execute it. @@ -65,11 +64,10 @@ struct ActorStateData { // This class is thread-safe. class CoreWorkerDirectActorTaskSubmitter { public: - CoreWorkerDirectActorTaskSubmitter( - rpc::ClientFactoryFn client_factory, - std::shared_ptr store_provider) + CoreWorkerDirectActorTaskSubmitter(rpc::ClientFactoryFn client_factory, + std::shared_ptr store) : client_factory_(client_factory), - in_memory_store_(store_provider), + in_memory_store_(store), resolver_(in_memory_store_) {} /// Submit a task to an actor for execution. @@ -142,8 +140,8 @@ class CoreWorkerDirectActorTaskSubmitter { /// Map from actor id to the tasks that are waiting for reply. std::unordered_map> waiting_reply_tasks_; - /// The store provider. - std::shared_ptr in_memory_store_; + /// The in-memory store. + std::shared_ptr in_memory_store_; /// Resolve direct call object dependencies; LocalDependencyResolver resolver_; diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index 079fa5958..e672ed8d5 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -6,7 +6,7 @@ #include "ray/common/id.h" #include "ray/common/ray_object.h" #include "ray/core_worker/context.h" -#include "ray/core_worker/store_provider/memory_store_provider.h" +#include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/transport/dependency_resolver.h" #include "ray/core_worker/transport/direct_actor_transport.h" #include "ray/raylet/raylet_client.h" @@ -20,14 +20,14 @@ typedef std::function(const rpc::Address & // This class is thread-safe. class CoreWorkerDirectTaskSubmitter { public: - CoreWorkerDirectTaskSubmitter( - std::shared_ptr lease_client, - rpc::ClientFactoryFn client_factory, LeaseClientFactoryFn lease_client_factory, - std::shared_ptr store_provider) + CoreWorkerDirectTaskSubmitter(std::shared_ptr lease_client, + rpc::ClientFactoryFn client_factory, + LeaseClientFactoryFn lease_client_factory, + std::shared_ptr store) : local_lease_client_(lease_client), client_factory_(client_factory), lease_client_factory_(lease_client_factory), - in_memory_store_(store_provider), + in_memory_store_(store), resolver_(in_memory_store_) {} /// Schedule a task for direct submission to a worker. @@ -80,7 +80,7 @@ class CoreWorkerDirectTaskSubmitter { LeaseClientFactoryFn lease_client_factory_; /// The store provider. - std::shared_ptr in_memory_store_; + std::shared_ptr in_memory_store_; /// Resolve local and remote dependencies; LocalDependencyResolver resolver_;