[Streaming] add barrier helper tests (#11107)

This commit is contained in:
Lingxuan Zuo
2020-10-13 09:55:55 +08:00
committed by GitHub
parent 6426fb3fff
commit c84a9b457c
3 changed files with 249 additions and 28 deletions
+9
View File
@@ -278,6 +278,15 @@ cc_test(
deps = test_common_deps,
)
cc_test(
name = "barrier_helper_tests",
srcs = [
"src/test/barrier_helper_tests.cc",
],
copts = COPTS,
deps = test_common_deps,
)
cc_test(
name = "streaming_message_serialization_tests",
srcs = [
+83 -28
View File
@@ -7,11 +7,92 @@
namespace ray {
namespace streaming {
class StreamingBarrierHelper {
class StreamingBarrierHelper final {
using BarrierIdQueue = std::shared_ptr<std::queue<uint64_t>>;
public:
StreamingBarrierHelper() {}
/// No duplicated barrier helper should be loaded in data writer or data
/// reader, so we mark BarrierHelper as a nocopyable object.
StreamingBarrierHelper(const StreamingBarrierHelper &barrier_helper) = delete;
StreamingBarrierHelper operator=(const StreamingBarrierHelper &barrier_helper) = delete;
virtual ~StreamingBarrierHelper() = default;
/// Get barrier id from queue-barrier map by given seq-id.
/// \param_in q_id, channel id
/// \param_in barrier_id, barrier or checkpoint of long runtime job
/// \param_out msg_id, message id of barrier
StreamingStatus GetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id,
uint64_t &msg_id);
/// Append new message id to queue-barrier map.
/// \param_in q_id, channel id
/// \param_in barrier_id, barrier or checkpoint of long running job
/// \param_in msg_id, message id of barrier
void SetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, uint64_t msg_id);
/// Check whether barrier id in queue-barrier map.
/// \param_in barrier_id, barrier id or checkpoint id
bool Contains(uint64_t barrier_id);
/// Remove barrier info from queue-barrier map by given seq id.
void ReleaseBarrierMapById(uint64_t barrier_id);
/// Remove all barrier info from queue-barrier map.
void ReleaseAllBarrierMap();
/// Fetch barrier id list from queue-barrier map.
void GetAllBarrier(std::vector<uint64_t> &barrier_id_vec);
/// Get barrier map capacity of current version.
uint32_t GetBarrierMapSize();
/// We assume there are multiple barriers in one checkpoint, so barrier id
/// should belong to a checkpoint id.
/// \param_in barrier_id, barrier id
/// \param_in checkpoint_id, checkpoint id
void MapBarrierToCheckpoint(uint64_t barrier_id, uint64_t checkpoint_id);
/// Get checkpoint id by given barrier id
/// \param_in barrier_id, barrier id
/// \param_out checkpoint_id, checkpoint id
StreamingStatus GetCheckpointIdByBarrierId(uint64_t barrier_id,
uint64_t &checkpoint_id);
/// Clear barrier-checkpoint relation if elements of barrier id vector are
/// equal to or less than given barrier id.
/// \param_in barrier_id
void ReleaseBarrierMapCheckpointByBarrierId(const uint64_t barrier_id);
/// Get barrier id by lastest message id and channel
/// \param_in q_id, channel id
/// \param_in message_id, lastest message id of barrier data
/// \param_out barrier_id, barrier id
/// \param_in is_pop, whether pop out from queue
StreamingStatus GetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
uint64_t &barrier_id, bool is_pop = false);
/// Put new barrier id in map by channel index and lastest message id.
/// \param_in q_id, channel id
/// \param_in message_id, lastest message id of barrier data
/// \param_in barrier_id, barrier id
void SetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
uint64_t barrier_id);
/// \param_in q_id, channel id
/// \param_in checkpoint_id, checkpoint id of long running job
void GetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
uint64_t &checkpoint_id) const;
/// \param_in q_id, channel id
/// \param_in checkpoint_id, checkpoint id of long running job
void SetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
const uint64_t checkpoint_id);
private:
// Global barrier map set (global barrier id -> (channel id -> msg id))
// Global barrier map set (global barrier id -> (channel id -> seq id))
std::unordered_map<uint64_t, std::unordered_map<ObjectID, uint64_t>>
global_barrier_map_;
@@ -34,32 +115,6 @@ class StreamingBarrierHelper {
std::mutex global_barrier_mutex_;
std::mutex barrier_map_checkpoint_mutex_;
public:
StreamingStatus GetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id,
uint64_t &msg_id);
void SetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, uint64_t seq_id);
bool Contains(uint64_t barrier_id);
void ReleaseBarrierMapById(uint64_t barrier_id);
void ReleaseAllBarrierMap();
void GetAllBarrier(std::vector<uint64_t> &barrier_id_vec);
uint32_t GetBarrierMapSize();
void MapBarrierToCheckpoint(uint64_t barrier_id, uint64_t checkpoint);
StreamingStatus GetCheckpointIdByBarrierId(uint64_t barrier_id,
uint64_t &checkpoint_id);
void ReleaseBarrierMapCheckpointByBarrierId(const uint64_t barrier_id);
StreamingStatus GetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
uint64_t &barrier_id, bool is_pop = false);
void SetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
uint64_t barrier_id);
void GetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
uint64_t &checkpoint_id) const;
void SetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
const uint64_t checkpoint_id);
};
} // namespace streaming
} // namespace ray
+157
View File
@@ -0,0 +1,157 @@
#include "gtest/gtest.h"
#include "reliability/barrier_helper.h"
using namespace ray::streaming;
using namespace ray;
class StreamingBarrierHelperTest : public ::testing::Test {
public:
void SetUp() { barrier_helper_.reset(new StreamingBarrierHelper()); }
void TearDown() { barrier_helper_.release(); }
protected:
std::unique_ptr<StreamingBarrierHelper> barrier_helper_;
const ObjectID random_id = ray::ObjectID::FromRandom();
const ObjectID another_random_id = ray::ObjectID::FromRandom();
};
TEST_F(StreamingBarrierHelperTest, MsgIdByBarrierId) {
ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 0);
uint64_t msg_id = 0;
uint64_t init_msg_id = 10;
ASSERT_EQ(StreamingStatus::NoSuchItem,
barrier_helper_->GetMsgIdByBarrierId(random_id, 1, msg_id));
barrier_helper_->SetMsgIdByBarrierId(random_id, 1, init_msg_id);
ASSERT_EQ(StreamingStatus::QueueIdNotFound,
barrier_helper_->GetMsgIdByBarrierId(another_random_id, 1, msg_id));
ASSERT_EQ(StreamingStatus::OK,
barrier_helper_->GetMsgIdByBarrierId(random_id, 1, msg_id));
ASSERT_EQ(init_msg_id, msg_id);
barrier_helper_->SetMsgIdByBarrierId(random_id, 2, init_msg_id + 1);
ASSERT_EQ(StreamingStatus::OK,
barrier_helper_->GetMsgIdByBarrierId(random_id, 2, msg_id));
ASSERT_EQ(init_msg_id + 1, msg_id);
ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 2);
barrier_helper_->ReleaseBarrierMapById(1);
ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 1);
barrier_helper_->ReleaseAllBarrierMap();
ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 0);
}
TEST_F(StreamingBarrierHelperTest, BarrierIdByLastMessageId) {
uint64_t barrier_id = 0;
ASSERT_EQ(StreamingStatus::NoSuchItem,
barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id));
barrier_helper_->SetBarrierIdByLastMessageId(random_id, 1, 10);
ASSERT_EQ(
StreamingStatus::QueueIdNotFound,
barrier_helper_->GetBarrierIdByLastMessageId(another_random_id, 1, barrier_id));
ASSERT_EQ(StreamingStatus::OK,
barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id));
ASSERT_EQ(barrier_id, 10);
barrier_helper_->SetBarrierIdByLastMessageId(random_id, 1, 11);
ASSERT_EQ(StreamingStatus::OK,
barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true));
ASSERT_EQ(barrier_id, 10);
ASSERT_EQ(StreamingStatus::OK,
barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true));
ASSERT_EQ(barrier_id, 11);
ASSERT_EQ(StreamingStatus::NoSuchItem,
barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true));
}
TEST_F(StreamingBarrierHelperTest, CheckpointId) {
uint64_t checkpoint_id = static_cast<uint64_t>(-1);
barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id);
ASSERT_EQ(checkpoint_id, 0);
barrier_helper_->SetCurrentMaxCheckpointIdInQueue(random_id, 2);
barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id);
ASSERT_EQ(checkpoint_id, 2);
barrier_helper_->SetCurrentMaxCheckpointIdInQueue(random_id, 3);
barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id);
ASSERT_EQ(checkpoint_id, 3);
}
TEST(BarrierHelper, barrier_map_get_set) {
StreamingBarrierHelper barrier_helper;
ray::ObjectID channel_id = ray::ObjectID::FromRandom();
uint64_t msg_id;
auto status = barrier_helper.GetMsgIdByBarrierId(channel_id, 0, msg_id);
EXPECT_TRUE(status == StreamingStatus::NoSuchItem);
EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 0);
msg_id = 1;
barrier_helper.SetMsgIdByBarrierId(channel_id, 0, msg_id);
uint64_t fetched_msg_id;
status = barrier_helper.GetMsgIdByBarrierId(channel_id, 0, fetched_msg_id);
EXPECT_TRUE(status == StreamingStatus::OK);
EXPECT_TRUE(fetched_msg_id == msg_id);
EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 1);
uint64_t fetched_no_barrier_id;
status = barrier_helper.GetMsgIdByBarrierId(channel_id, 1, fetched_no_barrier_id);
EXPECT_TRUE(status == StreamingStatus::NoSuchItem);
ray::ObjectID other_channel_id = ray::ObjectID::FromRandom();
status = barrier_helper.GetMsgIdByBarrierId(other_channel_id, 0, fetched_msg_id);
EXPECT_TRUE(status == StreamingStatus::QueueIdNotFound);
EXPECT_TRUE(barrier_helper.Contains(0));
EXPECT_TRUE(!barrier_helper.Contains(1));
msg_id = 10;
barrier_helper.SetMsgIdByBarrierId(channel_id, 1, msg_id);
EXPECT_TRUE(barrier_helper.Contains(1));
EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 2);
barrier_helper.ReleaseBarrierMapById(0);
EXPECT_TRUE(!barrier_helper.Contains(0));
EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 1);
msg_id = 20;
barrier_helper.SetMsgIdByBarrierId(channel_id, 2, msg_id);
std::vector<uint64_t> barrier_id_vec;
barrier_helper.GetAllBarrier(barrier_id_vec);
EXPECT_TRUE(barrier_id_vec.size() == 2);
barrier_helper.ReleaseAllBarrierMap();
EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 0);
}
TEST(BarrierHelper, barrier_checkpoint_mapping) {
StreamingBarrierHelper barrier_helper;
ray::ObjectID channel_id = ray::ObjectID::FromRandom();
uint64_t msg_id = 1;
uint64_t barrier_id = 0;
barrier_helper.SetMsgIdByBarrierId(channel_id, barrier_id, msg_id);
uint64_t checkpoint_id = 100;
barrier_helper.MapBarrierToCheckpoint(barrier_id, checkpoint_id);
uint64_t fetched_checkpoint_id;
barrier_helper.GetCheckpointIdByBarrierId(barrier_id, fetched_checkpoint_id);
EXPECT_TRUE(fetched_checkpoint_id == checkpoint_id);
barrier_id = 2;
barrier_helper.MapBarrierToCheckpoint(barrier_id, checkpoint_id);
barrier_helper.GetCheckpointIdByBarrierId(barrier_id, fetched_checkpoint_id);
EXPECT_TRUE(fetched_checkpoint_id == checkpoint_id);
barrier_helper.ReleaseBarrierMapCheckpointByBarrierId(barrier_id);
auto status1 = barrier_helper.GetCheckpointIdByBarrierId(0, fetched_checkpoint_id);
auto status2 = barrier_helper.GetCheckpointIdByBarrierId(2, fetched_checkpoint_id);
EXPECT_TRUE(status1 == status2 && status1 == StreamingStatus::NoSuchItem);
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}