From da7bdacea544e666e2f93c6007beb9c705ff4342 Mon Sep 17 00:00:00 2001 From: micafan <550435771@qq.com> Date: Tue, 20 Aug 2019 20:32:53 +0800 Subject: [PATCH] support for subscription to an actor (#5269) --- BUILD.bazel | 12 ++ src/ray/gcs/accessor_test_base.h | 76 +++++++ src/ray/gcs/actor_state_accessor.cc | 28 ++- src/ray/gcs/actor_state_accessor.h | 31 ++- src/ray/gcs/actor_state_accessor_test.cc | 90 ++------- src/ray/gcs/gcs_client_interface.h | 2 +- src/ray/gcs/redis_gcs_client.h | 1 + src/ray/gcs/redis_gcs_client_test.cc | 26 +-- src/ray/gcs/subscription_executor.cc | 139 +++++++++++++ src/ray/gcs/subscription_executor.h | 85 ++++++++ src/ray/gcs/subscription_executor_test.cc | 201 +++++++++++++++++++ src/ray/gcs/tables.cc | 36 +++- src/ray/gcs/tables.h | 15 +- src/ray/object_manager/object_directory.cc | 6 +- src/ray/raylet/lineage_cache.cc | 6 +- src/ray/raylet/lineage_cache_test.cc | 9 +- src/ray/raylet/reconstruction_policy.cc | 8 +- src/ray/raylet/reconstruction_policy_test.cc | 7 +- src/ray/test/run_gcs_tests.sh | 3 +- src/ray/util/test_util.h | 4 +- 20 files changed, 655 insertions(+), 130 deletions(-) create mode 100644 src/ray/gcs/accessor_test_base.h create mode 100644 src/ray/gcs/subscription_executor.cc create mode 100644 src/ray/gcs/subscription_executor.h create mode 100644 src/ray/gcs/subscription_executor_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index 32e57bf85..1dd14b209 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -640,6 +640,7 @@ cc_library( ], ) +# TODO(micafan) Replace cc_binary with cc_test for GCS test. cc_binary( name = "redis_gcs_client_test", testonly = 1, @@ -662,6 +663,17 @@ cc_binary( ], ) +cc_binary( + name = "subscription_executor_test", + testonly = 1, + srcs = ["src/ray/gcs/subscription_executor_test.cc"], + copts = COPTS, + deps = [ + ":gcs", + "@com_google_googletest//:gtest_main", + ], +) + cc_binary( name = "asio_test", testonly = 1, diff --git a/src/ray/gcs/accessor_test_base.h b/src/ray/gcs/accessor_test_base.h new file mode 100644 index 000000000..c817d8019 --- /dev/null +++ b/src/ray/gcs/accessor_test_base.h @@ -0,0 +1,76 @@ +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "ray/gcs/redis_gcs_client.h" +#include "ray/util/test_util.h" + +namespace ray { + +namespace gcs { + +template +class AccessorTestBase : public ::testing::Test { + public: + AccessorTestBase() : options_("127.0.0.1", 6379, "", true) {} + + virtual ~AccessorTestBase() {} + + virtual void SetUp() { + GenTestData(); + + gcs_client_.reset(new RedisGcsClient(options_)); + RAY_CHECK_OK(gcs_client_->Connect(io_service_)); + + work_thread.reset(new std::thread([this] { + std::unique_ptr work( + new boost::asio::io_service::work(io_service_)); + io_service_.run(); + })); + } + + virtual void TearDown() { + gcs_client_->Disconnect(); + + io_service_.stop(); + work_thread->join(); + work_thread.reset(); + + gcs_client_.reset(); + + ClearTestData(); + } + + protected: + virtual void GenTestData() = 0; + + void ClearTestData() { id_to_data_.clear(); } + + void WaitPendingDone(std::chrono::milliseconds timeout) { + WaitPendingDone(pending_count_, timeout); + } + + void WaitPendingDone(std::atomic &pending_count, + std::chrono::milliseconds timeout) { + auto condition = [&pending_count]() { return pending_count == 0; }; + EXPECT_TRUE(WaitForCondition(condition, timeout.count())); + } + + protected: + GcsClientOptions options_; + std::unique_ptr gcs_client_; + + boost::asio::io_service io_service_; + std::unique_ptr work_thread; + + std::unordered_map> id_to_data_; + + std::atomic pending_count_{0}; + std::chrono::milliseconds wait_pending_timeout_{10000}; +}; + +} // namespace gcs + +} // namespace ray diff --git a/src/ray/gcs/actor_state_accessor.cc b/src/ray/gcs/actor_state_accessor.cc index 6d868c301..f76af65cb 100644 --- a/src/ray/gcs/actor_state_accessor.cc +++ b/src/ray/gcs/actor_state_accessor.cc @@ -8,7 +8,7 @@ namespace ray { namespace gcs { ActorStateAccessor::ActorStateAccessor(RedisGcsClient &client_impl) - : client_impl_(client_impl) {} + : client_impl_(client_impl), actor_sub_executor_(client_impl_.actor_table()) {} Status ActorStateAccessor::AsyncGet(const ActorID &actor_id, const MultiItemCallback &callback) { @@ -90,23 +90,19 @@ Status ActorStateAccessor::AsyncSubscribe( const SubscribeCallback &subscribe, const StatusCallback &done) { RAY_CHECK(subscribe != nullptr); - auto on_subscribe = [subscribe](RedisGcsClient *client, const ActorID &actor_id, - const std::vector &data) { - if (!data.empty()) { - // We only need the last entry, because it represents the latest state of - // this actor. - subscribe(actor_id, data.back()); - } - }; + return actor_sub_executor_.AsyncSubscribe(ClientID::Nil(), subscribe, done); +} - auto on_done = [done](RedisGcsClient *client) { - if (done != nullptr) { - done(Status::OK()); - } - }; +Status ActorStateAccessor::AsyncSubscribe( + const ActorID &actor_id, const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_CHECK(subscribe != nullptr); + return actor_sub_executor_.AsyncSubscribe(node_id_, actor_id, subscribe, done); +} - ActorTable &actor_table = client_impl_.actor_table(); - return actor_table.Subscribe(JobID::Nil(), ClientID::Nil(), on_subscribe, on_done); +Status ActorStateAccessor::AsyncUnsubscribe(const ActorID &actor_id, + const StatusCallback &done) { + return actor_sub_executor_.AsyncUnsubscribe(node_id_, actor_id, done); } } // namespace gcs diff --git a/src/ray/gcs/actor_state_accessor.h b/src/ray/gcs/actor_state_accessor.h index 273c7aa88..c6a812f4d 100644 --- a/src/ray/gcs/actor_state_accessor.h +++ b/src/ray/gcs/actor_state_accessor.h @@ -3,6 +3,7 @@ #include "ray/common/id.h" #include "ray/gcs/callback.h" +#include "ray/gcs/subscription_executor.h" #include "ray/gcs/tables.h" namespace ray { @@ -50,7 +51,7 @@ class ActorStateAccessor { const std::shared_ptr &data_ptr, const StatusCallback &callback); - /// Subscribe to any register operations of actors. + /// Subscribe to any register or update operations of actors. /// /// \param subscribe Callback that will be called each time when an actor is registered /// or updated. @@ -60,8 +61,36 @@ class ActorStateAccessor { Status AsyncSubscribe(const SubscribeCallback &subscribe, const StatusCallback &done); + /// Subscribe to any update operations of an actor. + /// + /// \param actor_id The ID of actor to be subscribed to. + /// \param subscribe Callback that will be called each time when the actor is updated. + /// \param done Callback that will be called when subscription is complete. + /// \return Status + Status AsyncSubscribe(const ActorID &actor_id, + const SubscribeCallback &subscribe, + const StatusCallback &done); + + /// Cancel subscription to an actor. + /// + /// \param actor_id The ID of the actor to be unsubscribed to. + /// \param done Callback that will be called when unsubscribe is complete. + /// \return Status + Status AsyncUnsubscribe(const ActorID &actor_id, const StatusCallback &done); + private: RedisGcsClient &client_impl_; + // Use a random ClientID for actor subscription. Because: + // If we use ClientID::Nil, GCS will still send all actors' updates to this GCS Client. + // Even we can filter out irrelevant updates, but there will be extra overhead. + // And because the new GCS Client will no longer hold the local ClientID, so we use + // random ClientID instead. + // TODO(micafan): Remove this random id, once GCS becomes a service. + ClientID node_id_{ClientID::FromRandom()}; + + typedef SubscriptionExecutor + ActorSubscriptionExecutor; + ActorSubscriptionExecutor actor_sub_executor_; }; } // namespace gcs diff --git a/src/ray/gcs/actor_state_accessor_test.cc b/src/ray/gcs/actor_state_accessor_test.cc index c6b3d45c0..5a5ab1475 100644 --- a/src/ray/gcs/actor_state_accessor_test.cc +++ b/src/ray/gcs/actor_state_accessor_test.cc @@ -4,6 +4,7 @@ #include #include #include "gtest/gtest.h" +#include "ray/gcs/accessor_test_base.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/util/test_util.h" @@ -11,39 +12,9 @@ namespace ray { namespace gcs { -class ActorStateAccessorTest : public ::testing::Test { - public: - ActorStateAccessorTest() : options_("127.0.0.1", 6379, "", true) {} - - virtual void SetUp() { - GenTestData(); - - gcs_client_.reset(new RedisGcsClient(options_)); - RAY_CHECK_OK(gcs_client_->Connect(io_service_)); - - work_thread.reset(new std::thread([this] { - std::unique_ptr work( - new boost::asio::io_service::work(io_service_)); - io_service_.run(); - })); - } - - virtual void TearDown() { - gcs_client_->Disconnect(); - - io_service_.stop(); - work_thread->join(); - work_thread.reset(); - - gcs_client_.reset(); - - ClearTestData(); - } - +class ActorStateAccessorTest : public AccessorTestBase { protected: - void GenTestData() { GenActorData(); } - - void GenActorData() { + virtual void GenTestData() { for (size_t i = 0; i < 100; ++i) { std::shared_ptr actor = std::make_shared(); actor->set_max_reconstructions(1); @@ -53,42 +24,15 @@ class ActorStateAccessorTest : public ::testing::Test { actor->set_state(ActorTableData::ALIVE); ActorID actor_id = ActorID::Of(job_id, RandomTaskId(), /*parent_task_counter=*/i); actor->set_actor_id(actor_id.Binary()); - actor_datas_[actor_id] = actor; + id_to_data_[actor_id] = actor; } } - - void ClearTestData() { actor_datas_.clear(); } - - void WaitPendingDone(std::chrono::milliseconds timeout) { - WaitPendingDone(pending_count_, timeout); - } - - void WaitPendingDone(std::atomic &pending_count, - std::chrono::milliseconds timeout) { - while (pending_count != 0 && timeout.count() > 0) { - std::chrono::milliseconds interval(10); - std::this_thread::sleep_for(interval); - timeout -= interval; - } - EXPECT_EQ(pending_count, 0); - } - - protected: - GcsClientOptions options_; - std::unique_ptr gcs_client_; - - boost::asio::io_service io_service_; - std::unique_ptr work_thread; - - std::unordered_map> actor_datas_; - - std::atomic pending_count_{0}; }; TEST_F(ActorStateAccessorTest, RegisterAndGet) { ActorStateAccessor &actor_accessor = gcs_client_->Actors(); // register - for (const auto &elem : actor_datas_) { + for (const auto &elem : id_to_data_) { const auto &actor = elem.second; ++pending_count_; RAY_CHECK_OK(actor_accessor.AsyncRegister(actor, [this](Status status) { @@ -97,35 +41,33 @@ TEST_F(ActorStateAccessorTest, RegisterAndGet) { })); } - std::chrono::milliseconds timeout(10000); - WaitPendingDone(timeout); + WaitPendingDone(wait_pending_timeout_); // get - for (const auto &elem : actor_datas_) { + for (const auto &elem : id_to_data_) { ++pending_count_; RAY_CHECK_OK(actor_accessor.AsyncGet( elem.first, [this](Status status, std::vector datas) { ASSERT_EQ(datas.size(), 1U); ActorID actor_id = ActorID::FromBinary(datas[0].actor_id()); - auto it = actor_datas_.find(actor_id); - ASSERT_TRUE(it != actor_datas_.end()); + auto it = id_to_data_.find(actor_id); + ASSERT_TRUE(it != id_to_data_.end()); --pending_count_; })); } - WaitPendingDone(timeout); + WaitPendingDone(wait_pending_timeout_); } TEST_F(ActorStateAccessorTest, Subscribe) { ActorStateAccessor &actor_accessor = gcs_client_->Actors(); - std::chrono::milliseconds timeout(10000); // subscribe std::atomic sub_pending_count(0); std::atomic do_sub_pending_count(0); auto subscribe = [this, &sub_pending_count](const ActorID &actor_id, const ActorTableData &data) { - const auto it = actor_datas_.find(actor_id); - ASSERT_TRUE(it != actor_datas_.end()); + const auto it = id_to_data_.find(actor_id); + ASSERT_TRUE(it != id_to_data_.end()); --sub_pending_count; }; auto done = [&do_sub_pending_count](Status status) { @@ -136,11 +78,11 @@ TEST_F(ActorStateAccessorTest, Subscribe) { ++do_sub_pending_count; RAY_CHECK_OK(actor_accessor.AsyncSubscribe(subscribe, done)); // Wait until subscribe finishes. - WaitPendingDone(do_sub_pending_count, timeout); + WaitPendingDone(do_sub_pending_count, wait_pending_timeout_); // register std::atomic register_pending_count(0); - for (const auto &elem : actor_datas_) { + for (const auto &elem : id_to_data_) { const auto &actor = elem.second; ++sub_pending_count; ++register_pending_count; @@ -151,10 +93,10 @@ TEST_F(ActorStateAccessorTest, Subscribe) { })); } // Wait until register finishes. - WaitPendingDone(register_pending_count, timeout); + WaitPendingDone(register_pending_count, wait_pending_timeout_); // Wait for all subscribe notifications. - WaitPendingDone(sub_pending_count, timeout); + WaitPendingDone(sub_pending_count, wait_pending_timeout_); } } // namespace gcs diff --git a/src/ray/gcs/gcs_client_interface.h b/src/ray/gcs/gcs_client_interface.h index f62ac9f39..dbf2b1818 100644 --- a/src/ray/gcs/gcs_client_interface.h +++ b/src/ray/gcs/gcs_client_interface.h @@ -95,7 +95,7 @@ class GcsClientInterface : public std::enable_shared_from_this actor_accessor_; diff --git a/src/ray/gcs/redis_gcs_client.h b/src/ray/gcs/redis_gcs_client.h index 52bf9f628..42192b294 100644 --- a/src/ray/gcs/redis_gcs_client.h +++ b/src/ray/gcs/redis_gcs_client.h @@ -19,6 +19,7 @@ class RedisContext; class RAY_EXPORT RedisGcsClient : public GcsClientInterface { friend class ActorStateAccessor; + friend class SubscriptionExecutorTest; public: /// Constructor of RedisGcsClient. diff --git a/src/ray/gcs/redis_gcs_client_test.cc b/src/ray/gcs/redis_gcs_client_test.cc index 61ca14ee6..2ede4616d 100644 --- a/src/ray/gcs/redis_gcs_client_test.cc +++ b/src/ray/gcs/redis_gcs_client_test.cc @@ -741,7 +741,7 @@ void TestTableSubscribeId(const JobID &job_id, num_modifications](gcs::RedisGcsClient *client) { // Request notifications for one of the keys. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id2, client->client_table().GetLocalClientId())); + job_id, task_id2, client->client_table().GetLocalClientId(), nullptr)); // Write both keys. We should only receive notifications for the key that // we requested them for. for (uint64_t i = 0; i < num_modifications; i++) { @@ -814,7 +814,7 @@ void TestLogSubscribeId(const JobID &job_id, job_ids2](gcs::RedisGcsClient *client) { // Request notifications for one of the keys. RAY_CHECK_OK(client->job_table().RequestNotifications( - job_id, job_id2, client->client_table().GetLocalClientId())); + job_id, job_id2, client->client_table().GetLocalClientId(), nullptr)); // Write both keys. We should only receive notifications for the key that // we requested them for. auto remaining = std::vector(++job_ids1.begin(), job_ids1.end()); @@ -890,7 +890,7 @@ void TestSetSubscribeId(const JobID &job_id, managers2](gcs::RedisGcsClient *client) { // Request notifications for one of the keys. RAY_CHECK_OK(client->object_table().RequestNotifications( - job_id, object_id2, client->client_table().GetLocalClientId())); + job_id, object_id2, client->client_table().GetLocalClientId(), nullptr)); // Write both keys. We should only receive notifications for the key that // we requested them for. auto remaining = std::vector(++managers1.begin(), managers1.end()); @@ -964,9 +964,9 @@ void TestTableSubscribeCancel(const JobID &job_id, // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id, client->client_table().GetLocalClientId())); + job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); RAY_CHECK_OK(client->raylet_task_table().CancelNotifications( - job_id, task_id, client->client_table().GetLocalClientId())); + job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); // Write to the key. Since we canceled notifications, we should not receive // a notification for these writes. for (uint64_t i = 1; i < num_modifications; i++) { @@ -976,7 +976,7 @@ void TestTableSubscribeCancel(const JobID &job_id, // Request notifications again. We should receive a notification for the // current value at the key. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id, client->client_table().GetLocalClientId())); + job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); }; // Subscribe to notifications for this client. This allows us to request and @@ -1033,9 +1033,9 @@ void TestLogSubscribeCancel(const JobID &job_id, // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. RAY_CHECK_OK(client->job_table().RequestNotifications( - job_id, random_job_id, client->client_table().GetLocalClientId())); + job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr)); RAY_CHECK_OK(client->job_table().CancelNotifications( - job_id, random_job_id, client->client_table().GetLocalClientId())); + job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr)); // Append to the key. Since we canceled notifications, we should not // receive a notification for these writes. auto remaining = std::vector(++job_ids.begin(), job_ids.end()); @@ -1047,7 +1047,7 @@ void TestLogSubscribeCancel(const JobID &job_id, // Request notifications again. We should receive a notification for the // current values at the key. RAY_CHECK_OK(client->job_table().RequestNotifications( - job_id, random_job_id, client->client_table().GetLocalClientId())); + job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr)); }; // Subscribe to notifications for this client. This allows us to request and @@ -1115,9 +1115,9 @@ void TestSetSubscribeCancel(const JobID &job_id, // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. RAY_CHECK_OK(client->object_table().RequestNotifications( - job_id, object_id, client->client_table().GetLocalClientId())); + job_id, object_id, client->client_table().GetLocalClientId(), nullptr)); RAY_CHECK_OK(client->object_table().CancelNotifications( - job_id, object_id, client->client_table().GetLocalClientId())); + job_id, object_id, client->client_table().GetLocalClientId(), nullptr)); // Add to the key. Since we canceled notifications, we should not // receive a notification for these writes. auto remaining = std::vector(++managers.begin(), managers.end()); @@ -1129,7 +1129,7 @@ void TestSetSubscribeCancel(const JobID &job_id, // Request notifications again. We should receive a notification for the // current values at the key. RAY_CHECK_OK(client->object_table().RequestNotifications( - job_id, object_id, client->client_table().GetLocalClientId())); + job_id, object_id, client->client_table().GetLocalClientId(), nullptr)); }; // Subscribe to notifications for this client. This allows us to request and @@ -1342,7 +1342,7 @@ void TestHashTable(const JobID &job_id, std::shared_ptr cli RAY_CHECK_OK(client->resource_table().Subscribe( job_id, ClientID::Nil(), notification_callback, subscribe_callback)); RAY_CHECK_OK(client->resource_table().RequestNotifications( - job_id, client_id, client->client_table().GetLocalClientId())); + job_id, client_id, client->client_table().GetLocalClientId(), nullptr)); // Step 1: Add elements to the hash table. auto update_callback1 = [data_map1, compare_test]( diff --git a/src/ray/gcs/subscription_executor.cc b/src/ray/gcs/subscription_executor.cc new file mode 100644 index 000000000..c55660c3c --- /dev/null +++ b/src/ray/gcs/subscription_executor.cc @@ -0,0 +1,139 @@ +#include "ray/gcs/subscription_executor.h" + +namespace ray { + +namespace gcs { + +template +Status SubscriptionExecutor::AsyncSubscribe( + const ClientID &client_id, const SubscribeCallback &subscribe, + const StatusCallback &done) { + // TODO(micafan) Optimize the lock when necessary. + // Consider avoiding locking in single-threaded processes. + std::lock_guard lock(mutex_); + + if (subscribe_all_callback_ != nullptr) { + RAY_LOG(DEBUG) << "Duplicate subscription! Already subscribed to all elements."; + return Status::Invalid("Duplicate subscription!"); + } + + if (registered_) { + if (subscribe != nullptr) { + RAY_LOG(DEBUG) << "Duplicate subscription! Already subscribed to specific elements" + ", can't subscribe to all elements."; + return Status::Invalid("Duplicate subscription!"); + } + return Status::OK(); + } + + auto on_subscribe = [this](RedisGcsClient *client, const ID &id, + const std::vector &result) { + if (result.empty()) { + return; + } + + RAY_LOG(DEBUG) << "Subscribe received update of id " << id; + + SubscribeCallback sub_one_callback = nullptr; + SubscribeCallback sub_all_callback = nullptr; + { + std::lock_guard lock(mutex_); + const auto it = id_to_callback_map_.find(id); + if (it != id_to_callback_map_.end()) { + sub_one_callback = it->second; + } + sub_all_callback = subscribe_all_callback_; + } + if (sub_one_callback != nullptr) { + sub_one_callback(id, result.back()); + } + if (sub_all_callback != nullptr) { + RAY_CHECK(sub_one_callback == nullptr); + sub_all_callback(id, result.back()); + } + }; + + auto on_done = [done](RedisGcsClient *client) { + if (done != nullptr) { + done(Status::OK()); + } + }; + + Status status = table_.Subscribe(JobID::Nil(), client_id, on_subscribe, on_done); + if (status.ok()) { + registered_ = true; + subscribe_all_callback_ = subscribe; + } + + return status; +} + +template +Status SubscriptionExecutor::AsyncSubscribe( + const ClientID &client_id, const ID &id, const SubscribeCallback &subscribe, + const StatusCallback &done) { + Status status = AsyncSubscribe(client_id, nullptr, nullptr); + if (!status.ok()) { + return status; + } + + auto on_done = [this, done, id](Status status) { + if (!status.ok()) { + std::lock_guard lock(mutex_); + id_to_callback_map_.erase(id); + } + if (done != nullptr) { + done(status); + } + }; + + { + std::lock_guard lock(mutex_); + const auto it = id_to_callback_map_.find(id); + if (it != id_to_callback_map_.end()) { + RAY_LOG(DEBUG) << "Duplicate subscription to id " << id << " client_id " + << client_id; + return Status::Invalid("Duplicate subscription to element!"); + } + status = table_.RequestNotifications(JobID::Nil(), id, client_id, on_done); + if (status.ok()) { + id_to_callback_map_[id] = subscribe; + } + } + + return status; +} + +template +Status SubscriptionExecutor::AsyncUnsubscribe( + const ClientID &client_id, const ID &id, const StatusCallback &done) { + { + std::lock_guard lock(mutex_); + const auto it = id_to_callback_map_.find(id); + if (it == id_to_callback_map_.end()) { + RAY_LOG(DEBUG) << "Invalid Unsubscribe! id " << id << " client_id " << client_id; + return Status::Invalid("Invalid Unsubscribe, no existing subscription found."); + } + } + + auto on_done = [this, id, done](Status status) { + if (status.ok()) { + std::lock_guard lock(mutex_); + const auto it = id_to_callback_map_.find(id); + if (it != id_to_callback_map_.end()) { + id_to_callback_map_.erase(it); + } + } + if (done != nullptr) { + done(status); + } + }; + + return table_.CancelNotifications(JobID::Nil(), id, client_id, on_done); +} + +template class SubscriptionExecutor; + +} // namespace gcs + +} // namespace ray diff --git a/src/ray/gcs/subscription_executor.h b/src/ray/gcs/subscription_executor.h new file mode 100644 index 000000000..167e1f274 --- /dev/null +++ b/src/ray/gcs/subscription_executor.h @@ -0,0 +1,85 @@ +#ifndef RAY_GCS_SUBSCRIPTION_EXECUTOR_H +#define RAY_GCS_SUBSCRIPTION_EXECUTOR_H + +#include +#include +#include "ray/gcs/callback.h" +#include "ray/gcs/tables.h" + +namespace ray { + +namespace gcs { + +/// \class SubscriptionExecutor +/// SubscriptionExecutor class encapsulates the implementation details of +/// subscribe/unsubscribe to elements (e.g.: actors or tasks or objects or nodes). +/// Support subscribing to a specific element or subscribing to all elements. +template +class SubscriptionExecutor { + public: + SubscriptionExecutor(Table &table) : table_(table) {} + + ~SubscriptionExecutor() {} + + /// Subscribe to operations of all elements. + /// Repeated subscription will return a failure. + /// + /// \param client_id The type of update to listen to. If this is nil, then a + /// message for each update will be received. Else, only + /// messages for the given client will be received. + /// \param subscribe Callback that will be called each time when an element + /// is registered or updated. + /// \param done Callback that will be called when subscription is complete. + /// \return Status + Status AsyncSubscribe(const ClientID &client_id, + const SubscribeCallback &subscribe, + const StatusCallback &done); + + /// Subscribe to operations of an element. + /// Repeated subscription to an element will return a failure. + /// + /// \param client_id The type of update to listen to. If this is nil, then a + /// message for each update will be received. Else, only + /// messages for the given client will be received. + /// \param id The id of the element to be subscribe to. + /// \param subscribe Callback that will be called each time when the element + /// is registered or updated. + /// \param done Callback that will be called when subscription is complete. + /// \return Status + Status AsyncSubscribe(const ClientID &client_id, const ID &id, + const SubscribeCallback &subscribe, + const StatusCallback &done); + + /// Cancel subscription to an element. + /// Unsubscribing can only be called after the subscription request is completed. + /// + /// \param client_id The type of update to listen to. If this is nil, then a + /// message for each update will be received. Else, only + /// messages for the given client will be received. + /// \param id The id of the element to be unsubscribed to. + /// \param done Callback that will be called when cancel subscription is complete. + /// \return Status + Status AsyncUnsubscribe(const ClientID &client_id, const ID &id, + const StatusCallback &done); + + private: + Table &table_; + + std::mutex mutex_; + + /// Whether successfully registered subscription to GCS. + bool registered_{false}; + + /// Subscribe Callback of all elements. + SubscribeCallback subscribe_all_callback_{nullptr}; + + /// A mapping from element ID to subscription callback. + typedef std::unordered_map> IDToCallbackMap; + IDToCallbackMap id_to_callback_map_; +}; + +} // namespace gcs + +} // namespace ray + +#endif // RAY_GCS_SUBSCRIPTION_EXECUTOR_H diff --git a/src/ray/gcs/subscription_executor_test.cc b/src/ray/gcs/subscription_executor_test.cc new file mode 100644 index 000000000..6f477173e --- /dev/null +++ b/src/ray/gcs/subscription_executor_test.cc @@ -0,0 +1,201 @@ +#include "gtest/gtest.h" +#include "ray/gcs/accessor_test_base.h" +#include "ray/gcs/callback.h" +#include "ray/gcs/redis_gcs_client.h" + +namespace ray { + +namespace gcs { + +class SubscriptionExecutorTest : public AccessorTestBase { + public: + typedef SubscriptionExecutor ActorSubExecutor; + + virtual void SetUp() { + AccessorTestBase::SetUp(); + + actor_sub_executor_.reset(new ActorSubExecutor(gcs_client_->actor_table())); + + subscribe_ = [this](const ActorID &id, const ActorTableData &data) { + const auto it = id_to_data_.find(id); + ASSERT_TRUE(it != id_to_data_.end()); + --sub_pending_count_; + }; + + sub_done_ = [this](Status status) { + ASSERT_TRUE(status.ok()) << status; + --do_sub_pending_count_; + }; + + unsub_done_ = [this](Status status) { + ASSERT_TRUE(status.ok()) << status; + --do_unsub_pending_count_; + }; + } + + virtual void TearDown() { + AccessorTestBase::TearDown(); + ASSERT_EQ(sub_pending_count_, 0); + ASSERT_EQ(do_sub_pending_count_, 0); + ASSERT_EQ(do_unsub_pending_count_, 0); + } + + protected: + virtual void GenTestData() { + for (size_t i = 0; i < 2; ++i) { + std::shared_ptr actor = std::make_shared(); + actor->set_max_reconstructions(1); + actor->set_remaining_reconstructions(1); + JobID job_id = JobID::FromInt(i); + actor->set_job_id(job_id.Binary()); + actor->set_state(ActorTableData::ALIVE); + ActorID actor_id = ActorID::Of(job_id, RandomTaskId(), /*parent_task_counter=*/i); + actor->set_actor_id(actor_id.Binary()); + id_to_data_[actor_id] = actor; + } + } + + size_t AsyncRegisterActorToGcs() { + ActorStateAccessor &actor_accessor = gcs_client_->Actors(); + for (const auto &elem : id_to_data_) { + const auto &actor = elem.second; + auto done = [this](Status status) { + ASSERT_TRUE(status.ok()); + --pending_count_; + }; + ++pending_count_; + Status status = actor_accessor.AsyncRegister(actor, done); + RAY_CHECK_OK(status); + } + return id_to_data_.size(); + } + + protected: + std::unique_ptr actor_sub_executor_; + + std::atomic sub_pending_count_{0}; + std::atomic do_sub_pending_count_{0}; + std::atomic do_unsub_pending_count_{0}; + + SubscribeCallback subscribe_{nullptr}; + StatusCallback sub_done_{nullptr}; + StatusCallback unsub_done_{nullptr}; +}; + +TEST_F(SubscriptionExecutorTest, SubscribeAllTest) { + ++do_sub_pending_count_; + Status status = + actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), subscribe_, sub_done_); + WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_); + ASSERT_TRUE(status.ok()); + sub_pending_count_ = id_to_data_.size(); + AsyncRegisterActorToGcs(); + status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), subscribe_, sub_done_); + ASSERT_TRUE(status.IsInvalid()); + WaitPendingDone(sub_pending_count_, wait_pending_timeout_); +} + +TEST_F(SubscriptionExecutorTest, SubscribeOneTest) { + Status status; + for (const auto &item : id_to_data_) { + ++do_sub_pending_count_; + status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_, + sub_done_); + ASSERT_TRUE(status.ok()); + } + WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_); + sub_pending_count_ = id_to_data_.size(); + AsyncRegisterActorToGcs(); + for (const auto &item : id_to_data_) { + status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_, + sub_done_); + ASSERT_TRUE(status.IsInvalid()); + } + WaitPendingDone(sub_pending_count_, wait_pending_timeout_); +} + +TEST_F(SubscriptionExecutorTest, SubscribeOneWithClientIDTest) { + const auto &item = id_to_data_.begin(); + ++do_sub_pending_count_; + ++sub_pending_count_; + Status status = actor_sub_executor_->AsyncSubscribe(ClientID::FromRandom(), item->first, + subscribe_, sub_done_); + WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_); + ASSERT_TRUE(status.ok()); + AsyncRegisterActorToGcs(); + WaitPendingDone(sub_pending_count_, wait_pending_timeout_); +} + +TEST_F(SubscriptionExecutorTest, SubscribeAllAndSubscribeOneTest) { + ++do_sub_pending_count_; + Status status = + actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), subscribe_, sub_done_); + ASSERT_TRUE(status.ok()); + WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_); + for (const auto &item : id_to_data_) { + status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_, + sub_done_); + ASSERT_FALSE(status.ok()); + } + sub_pending_count_ = id_to_data_.size(); + AsyncRegisterActorToGcs(); + WaitPendingDone(sub_pending_count_, wait_pending_timeout_); +} + +TEST_F(SubscriptionExecutorTest, UnsubscribeTest) { + Status status; + for (const auto &item : id_to_data_) { + status = + actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), item.first, unsub_done_); + ASSERT_TRUE(status.IsInvalid()); + } + + for (const auto &item : id_to_data_) { + ++do_sub_pending_count_; + status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_, + sub_done_); + ASSERT_TRUE(status.ok()); + } + WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_); + for (const auto &item : id_to_data_) { + ++do_unsub_pending_count_; + status = + actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), item.first, unsub_done_); + ASSERT_TRUE(status.ok()); + } + WaitPendingDone(do_unsub_pending_count_, wait_pending_timeout_); + for (const auto &item : id_to_data_) { + status = + actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), item.first, unsub_done_); + ASSERT_TRUE(!status.ok()); + } + + for (const auto &item : id_to_data_) { + ++do_sub_pending_count_; + status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_, + sub_done_); + ASSERT_TRUE(status.ok()); + } + WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_); + for (const auto &item : id_to_data_) { + ++do_unsub_pending_count_; + status = + actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), item.first, unsub_done_); + ASSERT_TRUE(status.ok()); + } + WaitPendingDone(do_unsub_pending_count_, wait_pending_timeout_); + for (const auto &item : id_to_data_) { + ++do_sub_pending_count_; + status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_, + sub_done_); + ASSERT_TRUE(status.ok()); + } + WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_); + sub_pending_count_ = id_to_data_.size(); + AsyncRegisterActorToGcs(); + WaitPendingDone(sub_pending_count_, wait_pending_timeout_); +} + +} // namespace gcs + +} // namespace ray diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 5c99aaac8..dc003a76a 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -162,22 +162,44 @@ Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, template Status Log::RequestNotifications(const JobID &job_id, const ID &id, - const ClientID &client_id) { + const ClientID &client_id, + const StatusCallback &done) { RAY_CHECK(subscribe_callback_index_ >= 0) << "Client requested notifications on a key before Subscribe completed"; + + RedisCallback callback = nullptr; + if (done != nullptr) { + callback = [done](const CallbackReply &reply) { + const auto status = reply.IsNil() + ? Status::OK() + : Status::RedisError("request notifications failed."); + done(status); + }; + } + return GetRedisContext(id)->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, client_id.Data(), client_id.Size(), prefix_, - pubsub_channel_, nullptr); + pubsub_channel_, callback); } template Status Log::CancelNotifications(const JobID &job_id, const ID &id, - const ClientID &client_id) { + const ClientID &client_id, + const StatusCallback &done) { RAY_CHECK(subscribe_callback_index_ >= 0) << "Client canceled notifications on a key before Subscribe completed"; + + RedisCallback callback = nullptr; + if (done != nullptr) { + callback = [done](const CallbackReply &reply) { + const auto status = reply.ReadAsStatus(); + done(status); + }; + } + return GetRedisContext(id)->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, client_id.Data(), client_id.Size(), prefix_, - pubsub_channel_, nullptr); + pubsub_channel_, callback); } template @@ -621,7 +643,8 @@ Status ClientTable::Connect(const GcsNodeInfo &local_node_info) { // Callback to request notifications from the client table once we've // successfully subscribed. auto subscription_callback = [this](RedisGcsClient *c) { - RAY_CHECK_OK(RequestNotifications(JobID::Nil(), client_log_key_, node_id_)); + RAY_CHECK_OK(RequestNotifications(JobID::Nil(), client_log_key_, node_id_, + /*done*/ nullptr)); }; // Subscribe to the client table. RAY_CHECK_OK( @@ -636,7 +659,8 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { auto add_callback = [this, callback](RedisGcsClient *client, const ClientID &id, const GcsNodeInfo &data) { HandleConnected(client, data); - RAY_CHECK_OK(CancelNotifications(JobID::Nil(), client_log_key_, id)); + RAY_CHECK_OK( + CancelNotifications(JobID::Nil(), client_log_key_, id, /*done*/ nullptr)); if (callback != nullptr) { callback(); } diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 60d7a6dc6..919ff24ea 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -11,6 +11,7 @@ #include "ray/common/status.h" #include "ray/util/logging.h" +#include "ray/gcs/callback.h" #include "ray/gcs/redis_context.h" #include "ray/protobuf/gcs.pb.h" @@ -56,9 +57,11 @@ template class PubsubInterface { public: virtual Status RequestNotifications(const JobID &job_id, const ID &id, - const ClientID &client_id) = 0; + const ClientID &client_id, + const StatusCallback &done) = 0; virtual Status CancelNotifications(const JobID &job_id, const ID &id, - const ClientID &client_id) = 0; + const ClientID &client_id, + const StatusCallback &done) = 0; virtual ~PubsubInterface(){}; }; @@ -182,20 +185,22 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param job_id The ID of the job. /// \param id The ID of the key to request notifications for. /// \param client_id The client who is requesting notifications. Before + /// \param done Callback that is called when request notifications is complete. /// notifications can be requested, a call to `Subscribe` to this /// table with the same `client_id` must complete successfully. /// \return Status Status RequestNotifications(const JobID &job_id, const ID &id, - const ClientID &client_id); + const ClientID &client_id, const StatusCallback &done); /// Cancel notifications about a key in this table. /// /// \param job_id The ID of the job. /// \param id The ID of the key to request notifications for. /// \param client_id The client who originally requested notifications. + /// \param done Callback that is called when cancel notifications is complete. /// \return Status - Status CancelNotifications(const JobID &job_id, const ID &id, - const ClientID &client_id); + Status CancelNotifications(const JobID &job_id, const ID &id, const ClientID &client_id, + const StatusCallback &done); /// Delete an entire key from redis. /// diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index f7e468391..0a1376bc1 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -159,7 +159,8 @@ ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_i if (it == listeners_.end()) { it = listeners_.emplace(object_id, LocationListenerState()).first; status = gcs_client_->object_table().RequestNotifications( - JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId()); + JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId(), + /*done*/ nullptr); } auto &listener_state = it->second; // TODO(hme): Make this fatal after implementing Pull suppression. @@ -187,7 +188,8 @@ ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback entry->second.callbacks.erase(callback_id); if (entry->second.callbacks.empty()) { status = gcs_client_->object_table().CancelNotifications( - JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId()); + JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId(), + /*done*/ nullptr); listeners_.erase(entry); } return status; diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index e000d7e55..30d0d2742 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -291,7 +291,8 @@ bool LineageCache::SubscribeTask(const TaskID &task_id) { if (unsubscribed) { // Request notifications for the task if we haven't already requested // notifications for it. - RAY_CHECK_OK(task_pubsub_.RequestNotifications(JobID::Nil(), task_id, client_id_)); + RAY_CHECK_OK(task_pubsub_.RequestNotifications(JobID::Nil(), task_id, client_id_, + /*done*/ nullptr)); } // Return whether we were previously unsubscribed to this task and are now // subscribed. @@ -304,7 +305,8 @@ bool LineageCache::UnsubscribeTask(const TaskID &task_id) { if (subscribed) { // Cancel notifications for the task if we previously requested // notifications for it. - RAY_CHECK_OK(task_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_)); + RAY_CHECK_OK(task_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_, + /*done*/ nullptr)); subscribed_tasks_.erase(it); } // Return whether we were previously subscribed to this task and are now diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 1eb6c3cf4..e6760dd5f 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -7,8 +7,12 @@ #include "ray/common/task/task_execution_spec.h" #include "ray/common/task/task_spec.h" #include "ray/common/task/task_util.h" + +#include "ray/gcs/callback.h" + #include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/lineage_cache.h" + #include "ray/util/test_util.h" namespace ray { @@ -67,7 +71,8 @@ class MockGcs : public gcs::TableInterface, } Status RequestNotifications(const JobID &job_id, const TaskID &task_id, - const ClientID &client_id) { + const ClientID &client_id, + const gcs::StatusCallback &done) { subscribed_tasks_.insert(task_id); if (task_table_.count(task_id) == 1) { callbacks_.push_back({notification_callback_, task_id}); @@ -77,7 +82,7 @@ class MockGcs : public gcs::TableInterface, } Status CancelNotifications(const JobID &job_id, const TaskID &task_id, - const ClientID &client_id) { + const ClientID &client_id, const gcs::StatusCallback &done) { subscribed_tasks_.erase(task_id); return ray::Status::OK(); } diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index 23551b429..020f41437 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -53,7 +53,8 @@ void ReconstructionPolicy::SetTaskTimeout( // task is still required after this initial period, then we now // subscribe to task lease notifications. RAY_CHECK_OK(task_lease_pubsub_.RequestNotifications(JobID::Nil(), task_id, - client_id_)); + client_id_, + /*done*/ nullptr)); it->second.subscribed = true; } } else { @@ -200,8 +201,9 @@ void ReconstructionPolicy::Cancel(const ObjectID &object_id) { if (it->second.created_objects.empty()) { // Cancel notifications for the task lease if we were subscribed to them. if (it->second.subscribed) { - RAY_CHECK_OK( - task_lease_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_)); + RAY_CHECK_OK(task_lease_pubsub_.CancelNotifications(JobID::Nil(), task_id, + client_id_, + /*done*/ nullptr)); } listening_tasks_.erase(it); } diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 2052868b3..9a8ee6759 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -5,6 +5,8 @@ #include +#include "ray/gcs/callback.h" + #include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/reconstruction_policy.h" @@ -102,7 +104,8 @@ class MockGcs : public gcs::PubsubInterface, } Status RequestNotifications(const JobID &job_id, const TaskID &task_id, - const ClientID &client_id) { + const ClientID &client_id, + const gcs::StatusCallback &done) { subscribed_tasks_.insert(task_id); auto entry = task_lease_table_.find(task_id); if (entry == task_lease_table_.end()) { @@ -114,7 +117,7 @@ class MockGcs : public gcs::PubsubInterface, } Status CancelNotifications(const JobID &job_id, const TaskID &task_id, - const ClientID &client_id) { + const ClientID &client_id, const gcs::StatusCallback &done) { subscribed_tasks_.erase(task_id); return ray::Status::OK(); } diff --git a/src/ray/test/run_gcs_tests.sh b/src/ray/test/run_gcs_tests.sh index 4c780a1e3..d4c86f90f 100644 --- a/src/ray/test/run_gcs_tests.sh +++ b/src/ray/test/run_gcs_tests.sh @@ -6,7 +6,7 @@ set -e set -x -bazel build "//:redis_gcs_client_test" "//:actor_state_accessor_test" "//:asio_test" "//:libray_redis_module.so" +bazel build "//:redis_gcs_client_test" "//:actor_state_accessor_test" "//:subscription_executor_test" "//:asio_test" "//:libray_redis_module.so" # Start Redis. if [[ "${RAY_USE_NEW_GCS}" = "on" ]]; then @@ -25,6 +25,7 @@ sleep 1s ./bazel-bin/redis_gcs_client_test ./bazel-bin/actor_state_accessor_test +./bazel-bin/subscription_executor_test ./bazel-bin/asio_test ./bazel-genfiles/redis-cli -p 6379 shutdown diff --git a/src/ray/util/test_util.h b/src/ray/util/test_util.h index 37bddcc88..937dc0f06 100644 --- a/src/ray/util/test_util.h +++ b/src/ray/util/test_util.h @@ -19,8 +19,8 @@ bool WaitForCondition(std::function condition, int timeout_ms) { return true; } - // sleep 100ms. - const int wait_interval_ms = 100; + // sleep 10ms. + const int wait_interval_ms = 10; usleep(wait_interval_ms * 1000); wait_time += wait_interval_ms; if (wait_time > timeout_ms) {