From c84a9b457c122e03b3d2fb0d58acedf96eae85f8 Mon Sep 17 00:00:00 2001
From: Lingxuan Zuo
Date: Tue, 13 Oct 2020 09:55:55 +0800
Subject: [PATCH] [Streaming] add barrier helper tests (#11107)

---
 streaming/BUILD.bazel                      |   9 ++
 streaming/src/reliability/barrier_helper.h | 111 +++++++++++----
 streaming/src/test/barrier_helper_tests.cc | 157 +++++++++++++++++++++
 3 files changed, 249 insertions(+), 28 deletions(-)
 create mode 100644 streaming/src/test/barrier_helper_tests.cc

diff --git a/streaming/BUILD.bazel b/streaming/BUILD.bazel
index 175643069..97e73d788 100644
--- a/streaming/BUILD.bazel
+++ b/streaming/BUILD.bazel
@@ -278,6 +278,15 @@ cc_test(
     deps = test_common_deps,
 )
 
+cc_test(
+    name = "barrier_helper_tests",
+    srcs = [
+        "src/test/barrier_helper_tests.cc",
+    ],
+    copts = COPTS,
+    deps = test_common_deps,
+)
+
 cc_test(
     name = "streaming_message_serialization_tests",
     srcs = [
diff --git a/streaming/src/reliability/barrier_helper.h b/streaming/src/reliability/barrier_helper.h
index 1ceacc47a..23f0e11c8 100644
--- a/streaming/src/reliability/barrier_helper.h
+++ b/streaming/src/reliability/barrier_helper.h
@@ -7,11 +7,92 @@
 
 namespace ray {
 namespace streaming {
-class StreamingBarrierHelper {
+class StreamingBarrierHelper final {
   using BarrierIdQueue = std::shared_ptr<std::queue<uint64_t>>;
 
+ public:
+  StreamingBarrierHelper() {}
+  /// No duplicated barrier helper should be loaded in data writer or data
+  /// reader, so we mark BarrierHelper as a non-copyable object.
+  StreamingBarrierHelper(const StreamingBarrierHelper &) = delete;
+
+  StreamingBarrierHelper &operator=(const StreamingBarrierHelper &) = delete;
+
+  ~StreamingBarrierHelper() = default;
+
+  /// Get message id from queue-barrier map by given channel and barrier id.
+  /// \param[in] q_id channel id
+  /// \param[in] barrier_id barrier or checkpoint of long running job
+  /// \param[out] msg_id message id of barrier
+  StreamingStatus GetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id,
+                                      uint64_t &msg_id);
+
+  /// Append new message id to queue-barrier map.
+  /// \param[in] q_id channel id
+  /// \param[in] barrier_id barrier or checkpoint of long running job
+  /// \param[in] msg_id message id of barrier
+  void SetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, uint64_t msg_id);
+
+  /// Check whether barrier id is in queue-barrier map.
+  /// \param[in] barrier_id barrier id or checkpoint id
+  bool Contains(uint64_t barrier_id);
+
+  /// Remove barrier info from queue-barrier map by given barrier id.
+  void ReleaseBarrierMapById(uint64_t barrier_id);
+
+  /// Remove all barrier info from queue-barrier map.
+  void ReleaseAllBarrierMap();
+
+  /// Fetch barrier id list from queue-barrier map.
+  void GetAllBarrier(std::vector<uint64_t> &barrier_id_vec);
+
+  /// Get the number of barriers currently held in queue-barrier map.
+  uint32_t GetBarrierMapSize();
+
+  /// We assume there are multiple barriers in one checkpoint, so barrier id
+  /// should belong to a checkpoint id.
+  /// \param[in] barrier_id barrier id
+  /// \param[in] checkpoint_id checkpoint id
+  void MapBarrierToCheckpoint(uint64_t barrier_id, uint64_t checkpoint_id);
+
+  /// Get checkpoint id by given barrier id.
+  /// \param[in] barrier_id barrier id
+  /// \param[out] checkpoint_id checkpoint id
+  StreamingStatus GetCheckpointIdByBarrierId(uint64_t barrier_id,
+                                             uint64_t &checkpoint_id);
+
+  /// Clear barrier-checkpoint relation if elements of barrier id vector are
+  /// equal to or less than given barrier id.
+  /// \param[in] barrier_id barrier id
+  void ReleaseBarrierMapCheckpointByBarrierId(const uint64_t barrier_id);
+
+  /// Get barrier id by latest message id and channel.
+  /// \param[in] q_id channel id
+  /// \param[in] message_id latest message id of barrier data
+  /// \param[out] barrier_id barrier id
+  /// \param[in] is_pop whether pop out from queue
+  StreamingStatus GetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
+                                              uint64_t &barrier_id, bool is_pop = false);
+
+  /// Put new barrier id in map by channel index and latest message id.
+  /// \param[in] q_id channel id
+  /// \param[in] message_id latest message id of barrier data
+  /// \param[in] barrier_id barrier id
+  void SetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
+                                   uint64_t barrier_id);
+
+  /// \param[in] q_id channel id
+  /// \param[out] checkpoint_id current max checkpoint id recorded for the channel
+  void GetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
+                                        uint64_t &checkpoint_id) const;
+
+  /// \param[in] q_id channel id
+  /// \param[in] checkpoint_id checkpoint id of long running job
+  void SetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
+                                        const uint64_t checkpoint_id);
+
  private:
-  // Global barrier map set (global barrier id -> (channel id -> msg id))
+  // Global barrier map set (global barrier id -> (channel id -> seq id))
   std::unordered_map<uint64_t, std::unordered_map<ObjectID, uint64_t>>
       global_barrier_map_;
 
@@ -34,32 +115,6 @@ class StreamingBarrierHelper {
 
   std::mutex global_barrier_mutex_;
   std::mutex barrier_map_checkpoint_mutex_;
-
- public:
-  StreamingStatus GetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id,
-                                      uint64_t &msg_id);
-  void SetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, uint64_t seq_id);
-  bool Contains(uint64_t barrier_id);
-  void ReleaseBarrierMapById(uint64_t barrier_id);
-  void ReleaseAllBarrierMap();
-  void GetAllBarrier(std::vector<uint64_t> &barrier_id_vec);
-  uint32_t GetBarrierMapSize();
-
-  void MapBarrierToCheckpoint(uint64_t barrier_id, uint64_t checkpoint);
-  StreamingStatus GetCheckpointIdByBarrierId(uint64_t barrier_id,
-                                             uint64_t &checkpoint_id);
-  void ReleaseBarrierMapCheckpointByBarrierId(const uint64_t barrier_id);
-
-  StreamingStatus GetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
-                                              uint64_t &barrier_id, bool is_pop = false);
-  void SetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
-                                   uint64_t barrier_id);
-
-  void GetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
-                                        uint64_t &checkpoint_id) const;
-
-  void SetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
-                                        const uint64_t checkpoint_id);
 };
 }  // namespace streaming
 }  // namespace ray
diff --git a/streaming/src/test/barrier_helper_tests.cc b/streaming/src/test/barrier_helper_tests.cc
new file mode 100644
index 000000000..35fa5434d
--- /dev/null
+++ b/streaming/src/test/barrier_helper_tests.cc
@@ -0,0 +1,157 @@
+#include "gtest/gtest.h"
+#include "reliability/barrier_helper.h"
+
+using namespace ray::streaming;
+using namespace ray;
+
+class StreamingBarrierHelperTest : public ::testing::Test {
+ public:
+  void SetUp() override { barrier_helper_.reset(new StreamingBarrierHelper()); }
+  void TearDown() override { barrier_helper_.reset(); }  // reset() frees; release() leaked
+
+ protected:
+  std::unique_ptr<StreamingBarrierHelper> barrier_helper_;
+  const ObjectID random_id = ray::ObjectID::FromRandom();
+  const ObjectID another_random_id = ray::ObjectID::FromRandom();
+};
+
+TEST_F(StreamingBarrierHelperTest, MsgIdByBarrierId) {
+  ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 0);
+  uint64_t msg_id = 0;
+  uint64_t init_msg_id = 10;
+  ASSERT_EQ(StreamingStatus::NoSuchItem,
+            barrier_helper_->GetMsgIdByBarrierId(random_id, 1, msg_id));
+
+  barrier_helper_->SetMsgIdByBarrierId(random_id, 1, init_msg_id);
+
+  ASSERT_EQ(StreamingStatus::QueueIdNotFound,
+            barrier_helper_->GetMsgIdByBarrierId(another_random_id, 1, msg_id));
+
+  ASSERT_EQ(StreamingStatus::OK,
+            barrier_helper_->GetMsgIdByBarrierId(random_id, 1, msg_id));
+  ASSERT_EQ(init_msg_id, msg_id);
+
+  barrier_helper_->SetMsgIdByBarrierId(random_id, 2, init_msg_id + 1);
+
+  ASSERT_EQ(StreamingStatus::OK,
+            barrier_helper_->GetMsgIdByBarrierId(random_id, 2, msg_id));
+  ASSERT_EQ(init_msg_id + 1, msg_id);
+
+  ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 2);
+  barrier_helper_->ReleaseBarrierMapById(1);
+  ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 1);
+  barrier_helper_->ReleaseAllBarrierMap();
+  ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 0);
+}
+
+TEST_F(StreamingBarrierHelperTest, BarrierIdByLastMessageId) {
+  uint64_t barrier_id = 0;
+  ASSERT_EQ(StreamingStatus::NoSuchItem,
+            barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id));
+
+  barrier_helper_->SetBarrierIdByLastMessageId(random_id, 1, 10);
+
+  ASSERT_EQ(
+      StreamingStatus::QueueIdNotFound,
+      barrier_helper_->GetBarrierIdByLastMessageId(another_random_id, 1, barrier_id));
+
+  ASSERT_EQ(StreamingStatus::OK,
+            barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id));
+  ASSERT_EQ(barrier_id, 10);
+
+  barrier_helper_->SetBarrierIdByLastMessageId(random_id, 1, 11);
+  ASSERT_EQ(StreamingStatus::OK,
+            barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true));
+  ASSERT_EQ(barrier_id, 10);
+  ASSERT_EQ(StreamingStatus::OK,
+            barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true));
+  ASSERT_EQ(barrier_id, 11);
+  ASSERT_EQ(StreamingStatus::NoSuchItem,
+            barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true));
+}
+
+TEST_F(StreamingBarrierHelperTest, CheckpointId) {
+  uint64_t checkpoint_id = static_cast<uint64_t>(-1);
+  barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id);
+  ASSERT_EQ(checkpoint_id, 0);
+  barrier_helper_->SetCurrentMaxCheckpointIdInQueue(random_id, 2);
+  barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id);
+  ASSERT_EQ(checkpoint_id, 2);
+  barrier_helper_->SetCurrentMaxCheckpointIdInQueue(random_id, 3);
+  barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id);
+  ASSERT_EQ(checkpoint_id, 3);
+}
+
+TEST(BarrierHelper, barrier_map_get_set) {
+  StreamingBarrierHelper barrier_helper;
+  ray::ObjectID channel_id = ray::ObjectID::FromRandom();
+  uint64_t msg_id = 0;
+  auto status = barrier_helper.GetMsgIdByBarrierId(channel_id, 0, msg_id);
+  EXPECT_TRUE(status == StreamingStatus::NoSuchItem);
+  EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 0);
+
+  msg_id = 1;
+  barrier_helper.SetMsgIdByBarrierId(channel_id, 0, msg_id);
+
+  uint64_t fetched_msg_id = 0;
+  status = barrier_helper.GetMsgIdByBarrierId(channel_id, 0, fetched_msg_id);
+  EXPECT_TRUE(status == StreamingStatus::OK);
+  EXPECT_TRUE(fetched_msg_id == msg_id);
+  EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 1);
+
+  uint64_t fetched_no_barrier_id = 0;
+  status = barrier_helper.GetMsgIdByBarrierId(channel_id, 1, fetched_no_barrier_id);
+  EXPECT_TRUE(status == StreamingStatus::NoSuchItem);
+
+  ray::ObjectID other_channel_id = ray::ObjectID::FromRandom();
+  status = barrier_helper.GetMsgIdByBarrierId(other_channel_id, 0, fetched_msg_id);
+  EXPECT_TRUE(status == StreamingStatus::QueueIdNotFound);
+
+  EXPECT_TRUE(barrier_helper.Contains(0));
+  EXPECT_TRUE(!barrier_helper.Contains(1));
+
+  msg_id = 10;
+  barrier_helper.SetMsgIdByBarrierId(channel_id, 1, msg_id);
+  EXPECT_TRUE(barrier_helper.Contains(1));
+  EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 2);
+
+  barrier_helper.ReleaseBarrierMapById(0);
+  EXPECT_TRUE(!barrier_helper.Contains(0));
+  EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 1);
+
+  msg_id = 20;
+  barrier_helper.SetMsgIdByBarrierId(channel_id, 2, msg_id);
+  std::vector<uint64_t> barrier_id_vec;
+  barrier_helper.GetAllBarrier(barrier_id_vec);
+  EXPECT_TRUE(barrier_id_vec.size() == 2);
+  barrier_helper.ReleaseAllBarrierMap();
+  EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 0);
+}
+
+TEST(BarrierHelper, barrier_checkpoint_mapping) {
+  StreamingBarrierHelper barrier_helper;
+  ray::ObjectID channel_id = ray::ObjectID::FromRandom();
+  uint64_t msg_id = 1;
+  uint64_t barrier_id = 0;
+  barrier_helper.SetMsgIdByBarrierId(channel_id, barrier_id, msg_id);
+  uint64_t checkpoint_id = 100;
+  barrier_helper.MapBarrierToCheckpoint(barrier_id, checkpoint_id);
+  uint64_t fetched_checkpoint_id = 0;
+  barrier_helper.GetCheckpointIdByBarrierId(barrier_id, fetched_checkpoint_id);
+  EXPECT_TRUE(fetched_checkpoint_id == checkpoint_id);
+
+  barrier_id = 2;
+  barrier_helper.MapBarrierToCheckpoint(barrier_id, checkpoint_id);
+  barrier_helper.GetCheckpointIdByBarrierId(barrier_id, fetched_checkpoint_id);
+  EXPECT_TRUE(fetched_checkpoint_id == checkpoint_id);
+  barrier_helper.ReleaseBarrierMapCheckpointByBarrierId(barrier_id);
+
+  auto status1 = barrier_helper.GetCheckpointIdByBarrierId(0, fetched_checkpoint_id);
+  auto status2 = barrier_helper.GetCheckpointIdByBarrierId(2, fetched_checkpoint_id);
+  EXPECT_TRUE(status1 == status2 && status1 == StreamingStatus::NoSuchItem);
+}
+
+int main(int argc, char **argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}