mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 17:49:47 +08:00
[Streaming] add barrier helper tests (#11107)
This commit is contained in:
@@ -278,6 +278,15 @@ cc_test(
|
||||
deps = test_common_deps,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "barrier_helper_tests",
|
||||
srcs = [
|
||||
"src/test/barrier_helper_tests.cc",
|
||||
],
|
||||
copts = COPTS,
|
||||
deps = test_common_deps,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "streaming_message_serialization_tests",
|
||||
srcs = [
|
||||
|
||||
@@ -7,11 +7,92 @@
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
class StreamingBarrierHelper {
|
||||
class StreamingBarrierHelper final {
|
||||
using BarrierIdQueue = std::shared_ptr<std::queue<uint64_t>>;
|
||||
|
||||
public:
|
||||
StreamingBarrierHelper() {}
|
||||
/// No duplicated barrier helper should be loaded in data writer or data
|
||||
/// reader, so we mark BarrierHelper as a nocopyable object.
|
||||
StreamingBarrierHelper(const StreamingBarrierHelper &barrier_helper) = delete;
|
||||
|
||||
StreamingBarrierHelper operator=(const StreamingBarrierHelper &barrier_helper) = delete;
|
||||
|
||||
virtual ~StreamingBarrierHelper() = default;
|
||||
|
||||
/// Get barrier id from queue-barrier map by given seq-id.
|
||||
/// \param_in q_id, channel id
|
||||
/// \param_in barrier_id, barrier or checkpoint of long runtime job
|
||||
/// \param_out msg_id, message id of barrier
|
||||
StreamingStatus GetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id,
|
||||
uint64_t &msg_id);
|
||||
|
||||
/// Append new message id to queue-barrier map.
|
||||
/// \param_in q_id, channel id
|
||||
/// \param_in barrier_id, barrier or checkpoint of long running job
|
||||
/// \param_in msg_id, message id of barrier
|
||||
void SetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, uint64_t msg_id);
|
||||
|
||||
/// Check whether barrier id in queue-barrier map.
|
||||
/// \param_in barrier_id, barrier id or checkpoint id
|
||||
bool Contains(uint64_t barrier_id);
|
||||
|
||||
/// Remove barrier info from queue-barrier map by given seq id.
|
||||
void ReleaseBarrierMapById(uint64_t barrier_id);
|
||||
|
||||
/// Remove all barrier info from queue-barrier map.
|
||||
void ReleaseAllBarrierMap();
|
||||
|
||||
/// Fetch barrier id list from queue-barrier map.
|
||||
void GetAllBarrier(std::vector<uint64_t> &barrier_id_vec);
|
||||
|
||||
/// Get barrier map capacity of current version.
|
||||
uint32_t GetBarrierMapSize();
|
||||
|
||||
/// We assume there are multiple barriers in one checkpoint, so barrier id
|
||||
/// should belong to a checkpoint id.
|
||||
/// \param_in barrier_id, barrier id
|
||||
/// \param_in checkpoint_id, checkpoint id
|
||||
void MapBarrierToCheckpoint(uint64_t barrier_id, uint64_t checkpoint_id);
|
||||
|
||||
/// Get checkpoint id by given barrier id
|
||||
/// \param_in barrier_id, barrier id
|
||||
/// \param_out checkpoint_id, checkpoint id
|
||||
StreamingStatus GetCheckpointIdByBarrierId(uint64_t barrier_id,
|
||||
uint64_t &checkpoint_id);
|
||||
|
||||
/// Clear barrier-checkpoint relation if elements of barrier id vector are
|
||||
/// equal to or less than given barrier id.
|
||||
/// \param_in barrier_id
|
||||
void ReleaseBarrierMapCheckpointByBarrierId(const uint64_t barrier_id);
|
||||
|
||||
/// Get barrier id by lastest message id and channel
|
||||
/// \param_in q_id, channel id
|
||||
/// \param_in message_id, lastest message id of barrier data
|
||||
/// \param_out barrier_id, barrier id
|
||||
/// \param_in is_pop, whether pop out from queue
|
||||
StreamingStatus GetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
|
||||
uint64_t &barrier_id, bool is_pop = false);
|
||||
|
||||
/// Put new barrier id in map by channel index and lastest message id.
|
||||
/// \param_in q_id, channel id
|
||||
/// \param_in message_id, lastest message id of barrier data
|
||||
/// \param_in barrier_id, barrier id
|
||||
void SetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
|
||||
uint64_t barrier_id);
|
||||
|
||||
/// \param_in q_id, channel id
|
||||
/// \param_in checkpoint_id, checkpoint id of long running job
|
||||
void GetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
|
||||
uint64_t &checkpoint_id) const;
|
||||
|
||||
/// \param_in q_id, channel id
|
||||
/// \param_in checkpoint_id, checkpoint id of long running job
|
||||
void SetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
|
||||
const uint64_t checkpoint_id);
|
||||
|
||||
private:
|
||||
// Global barrier map set (global barrier id -> (channel id -> msg id))
|
||||
// Global barrier map set (global barrier id -> (channel id -> seq id))
|
||||
std::unordered_map<uint64_t, std::unordered_map<ObjectID, uint64_t>>
|
||||
global_barrier_map_;
|
||||
|
||||
@@ -34,32 +115,6 @@ class StreamingBarrierHelper {
|
||||
std::mutex global_barrier_mutex_;
|
||||
|
||||
std::mutex barrier_map_checkpoint_mutex_;
|
||||
|
||||
public:
|
||||
StreamingStatus GetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id,
|
||||
uint64_t &msg_id);
|
||||
void SetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, uint64_t seq_id);
|
||||
bool Contains(uint64_t barrier_id);
|
||||
void ReleaseBarrierMapById(uint64_t barrier_id);
|
||||
void ReleaseAllBarrierMap();
|
||||
void GetAllBarrier(std::vector<uint64_t> &barrier_id_vec);
|
||||
uint32_t GetBarrierMapSize();
|
||||
|
||||
void MapBarrierToCheckpoint(uint64_t barrier_id, uint64_t checkpoint);
|
||||
StreamingStatus GetCheckpointIdByBarrierId(uint64_t barrier_id,
|
||||
uint64_t &checkpoint_id);
|
||||
void ReleaseBarrierMapCheckpointByBarrierId(const uint64_t barrier_id);
|
||||
|
||||
StreamingStatus GetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
|
||||
uint64_t &barrier_id, bool is_pop = false);
|
||||
void SetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id,
|
||||
uint64_t barrier_id);
|
||||
|
||||
void GetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
|
||||
uint64_t &checkpoint_id) const;
|
||||
|
||||
void SetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id,
|
||||
const uint64_t checkpoint_id);
|
||||
};
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
#include "gtest/gtest.h"
|
||||
#include "reliability/barrier_helper.h"
|
||||
|
||||
using namespace ray::streaming;
|
||||
using namespace ray;
|
||||
|
||||
class StreamingBarrierHelperTest : public ::testing::Test {
|
||||
public:
|
||||
void SetUp() { barrier_helper_.reset(new StreamingBarrierHelper()); }
|
||||
void TearDown() { barrier_helper_.release(); }
|
||||
|
||||
protected:
|
||||
std::unique_ptr<StreamingBarrierHelper> barrier_helper_;
|
||||
const ObjectID random_id = ray::ObjectID::FromRandom();
|
||||
const ObjectID another_random_id = ray::ObjectID::FromRandom();
|
||||
};
|
||||
|
||||
TEST_F(StreamingBarrierHelperTest, MsgIdByBarrierId) {
|
||||
ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 0);
|
||||
uint64_t msg_id = 0;
|
||||
uint64_t init_msg_id = 10;
|
||||
ASSERT_EQ(StreamingStatus::NoSuchItem,
|
||||
barrier_helper_->GetMsgIdByBarrierId(random_id, 1, msg_id));
|
||||
|
||||
barrier_helper_->SetMsgIdByBarrierId(random_id, 1, init_msg_id);
|
||||
|
||||
ASSERT_EQ(StreamingStatus::QueueIdNotFound,
|
||||
barrier_helper_->GetMsgIdByBarrierId(another_random_id, 1, msg_id));
|
||||
|
||||
ASSERT_EQ(StreamingStatus::OK,
|
||||
barrier_helper_->GetMsgIdByBarrierId(random_id, 1, msg_id));
|
||||
ASSERT_EQ(init_msg_id, msg_id);
|
||||
|
||||
barrier_helper_->SetMsgIdByBarrierId(random_id, 2, init_msg_id + 1);
|
||||
|
||||
ASSERT_EQ(StreamingStatus::OK,
|
||||
barrier_helper_->GetMsgIdByBarrierId(random_id, 2, msg_id));
|
||||
ASSERT_EQ(init_msg_id + 1, msg_id);
|
||||
|
||||
ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 2);
|
||||
barrier_helper_->ReleaseBarrierMapById(1);
|
||||
ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 1);
|
||||
barrier_helper_->ReleaseAllBarrierMap();
|
||||
ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 0);
|
||||
}
|
||||
|
||||
TEST_F(StreamingBarrierHelperTest, BarrierIdByLastMessageId) {
|
||||
uint64_t barrier_id = 0;
|
||||
ASSERT_EQ(StreamingStatus::NoSuchItem,
|
||||
barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id));
|
||||
|
||||
barrier_helper_->SetBarrierIdByLastMessageId(random_id, 1, 10);
|
||||
|
||||
ASSERT_EQ(
|
||||
StreamingStatus::QueueIdNotFound,
|
||||
barrier_helper_->GetBarrierIdByLastMessageId(another_random_id, 1, barrier_id));
|
||||
|
||||
ASSERT_EQ(StreamingStatus::OK,
|
||||
barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id));
|
||||
ASSERT_EQ(barrier_id, 10);
|
||||
|
||||
barrier_helper_->SetBarrierIdByLastMessageId(random_id, 1, 11);
|
||||
ASSERT_EQ(StreamingStatus::OK,
|
||||
barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true));
|
||||
ASSERT_EQ(barrier_id, 10);
|
||||
ASSERT_EQ(StreamingStatus::OK,
|
||||
barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true));
|
||||
ASSERT_EQ(barrier_id, 11);
|
||||
ASSERT_EQ(StreamingStatus::NoSuchItem,
|
||||
barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true));
|
||||
}
|
||||
|
||||
TEST_F(StreamingBarrierHelperTest, CheckpointId) {
|
||||
uint64_t checkpoint_id = static_cast<uint64_t>(-1);
|
||||
barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id);
|
||||
ASSERT_EQ(checkpoint_id, 0);
|
||||
barrier_helper_->SetCurrentMaxCheckpointIdInQueue(random_id, 2);
|
||||
barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id);
|
||||
ASSERT_EQ(checkpoint_id, 2);
|
||||
barrier_helper_->SetCurrentMaxCheckpointIdInQueue(random_id, 3);
|
||||
barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id);
|
||||
ASSERT_EQ(checkpoint_id, 3);
|
||||
}
|
||||
|
||||
TEST(BarrierHelper, barrier_map_get_set) {
|
||||
StreamingBarrierHelper barrier_helper;
|
||||
ray::ObjectID channel_id = ray::ObjectID::FromRandom();
|
||||
uint64_t msg_id;
|
||||
auto status = barrier_helper.GetMsgIdByBarrierId(channel_id, 0, msg_id);
|
||||
EXPECT_TRUE(status == StreamingStatus::NoSuchItem);
|
||||
EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 0);
|
||||
|
||||
msg_id = 1;
|
||||
barrier_helper.SetMsgIdByBarrierId(channel_id, 0, msg_id);
|
||||
|
||||
uint64_t fetched_msg_id;
|
||||
status = barrier_helper.GetMsgIdByBarrierId(channel_id, 0, fetched_msg_id);
|
||||
EXPECT_TRUE(status == StreamingStatus::OK);
|
||||
EXPECT_TRUE(fetched_msg_id == msg_id);
|
||||
EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 1);
|
||||
|
||||
uint64_t fetched_no_barrier_id;
|
||||
status = barrier_helper.GetMsgIdByBarrierId(channel_id, 1, fetched_no_barrier_id);
|
||||
EXPECT_TRUE(status == StreamingStatus::NoSuchItem);
|
||||
|
||||
ray::ObjectID other_channel_id = ray::ObjectID::FromRandom();
|
||||
status = barrier_helper.GetMsgIdByBarrierId(other_channel_id, 0, fetched_msg_id);
|
||||
EXPECT_TRUE(status == StreamingStatus::QueueIdNotFound);
|
||||
|
||||
EXPECT_TRUE(barrier_helper.Contains(0));
|
||||
EXPECT_TRUE(!barrier_helper.Contains(1));
|
||||
|
||||
msg_id = 10;
|
||||
barrier_helper.SetMsgIdByBarrierId(channel_id, 1, msg_id);
|
||||
EXPECT_TRUE(barrier_helper.Contains(1));
|
||||
EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 2);
|
||||
|
||||
barrier_helper.ReleaseBarrierMapById(0);
|
||||
EXPECT_TRUE(!barrier_helper.Contains(0));
|
||||
EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 1);
|
||||
|
||||
msg_id = 20;
|
||||
barrier_helper.SetMsgIdByBarrierId(channel_id, 2, msg_id);
|
||||
std::vector<uint64_t> barrier_id_vec;
|
||||
barrier_helper.GetAllBarrier(barrier_id_vec);
|
||||
EXPECT_TRUE(barrier_id_vec.size() == 2);
|
||||
barrier_helper.ReleaseAllBarrierMap();
|
||||
EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 0);
|
||||
}
|
||||
|
||||
TEST(BarrierHelper, barrier_checkpoint_mapping) {
|
||||
StreamingBarrierHelper barrier_helper;
|
||||
ray::ObjectID channel_id = ray::ObjectID::FromRandom();
|
||||
uint64_t msg_id = 1;
|
||||
uint64_t barrier_id = 0;
|
||||
barrier_helper.SetMsgIdByBarrierId(channel_id, barrier_id, msg_id);
|
||||
uint64_t checkpoint_id = 100;
|
||||
barrier_helper.MapBarrierToCheckpoint(barrier_id, checkpoint_id);
|
||||
uint64_t fetched_checkpoint_id;
|
||||
barrier_helper.GetCheckpointIdByBarrierId(barrier_id, fetched_checkpoint_id);
|
||||
EXPECT_TRUE(fetched_checkpoint_id == checkpoint_id);
|
||||
|
||||
barrier_id = 2;
|
||||
barrier_helper.MapBarrierToCheckpoint(barrier_id, checkpoint_id);
|
||||
barrier_helper.GetCheckpointIdByBarrierId(barrier_id, fetched_checkpoint_id);
|
||||
EXPECT_TRUE(fetched_checkpoint_id == checkpoint_id);
|
||||
barrier_helper.ReleaseBarrierMapCheckpointByBarrierId(barrier_id);
|
||||
|
||||
auto status1 = barrier_helper.GetCheckpointIdByBarrierId(0, fetched_checkpoint_id);
|
||||
auto status2 = barrier_helper.GetCheckpointIdByBarrierId(2, fetched_checkpoint_id);
|
||||
EXPECT_TRUE(status1 == status2 && status1 == StreamingStatus::NoSuchItem);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
Reference in New Issue
Block a user