From 2f3ec4ef758ad482e46b3f4c72be9d10f46d7ae2 Mon Sep 17 00:00:00 2001 From: Lingxuan Zuo Date: Fri, 16 Oct 2020 14:42:44 +0800 Subject: [PATCH] add streaming data writer unit tests (#11387) --- streaming/BUILD.bazel | 9 ++ streaming/src/data_writer.h | 3 + streaming/src/test/data_writer_tests.cc | 138 ++++++++++++++++++++++++ 3 files changed, 150 insertions(+) create mode 100644 streaming/src/test/data_writer_tests.cc diff --git a/streaming/BUILD.bazel b/streaming/BUILD.bazel index 97e73d788..0cb747f80 100644 --- a/streaming/BUILD.bazel +++ b/streaming/BUILD.bazel @@ -331,6 +331,15 @@ cc_test( deps = test_common_deps, ) +cc_test( + name = "data_writer_tests", + srcs = [ + "src/test/data_writer_tests.cc", + ], + copts = COPTS, + deps = test_common_deps, +) + python_proto_compile( name = "streaming_py_proto", deps = [":streaming_proto"], diff --git a/streaming/src/data_writer.h b/streaming/src/data_writer.h index e2a18c334..07a7abb4d 100644 --- a/streaming/src/data_writer.h +++ b/streaming/src/data_writer.h @@ -34,6 +34,9 @@ class ReliabilityHelper; /// buffers have no data in that moment. class DataWriter { public: + // For mock writer accessing inner fields. + friend class MockWriter; + explicit DataWriter(std::shared_ptr &runtime_context); virtual ~DataWriter(); diff --git a/streaming/src/test/data_writer_tests.cc b/streaming/src/test/data_writer_tests.cc new file mode 100644 index 000000000..61594e98e --- /dev/null +++ b/streaming/src/test/data_writer_tests.cc @@ -0,0 +1,138 @@ +#include "data_writer.h" +#include "gtest/gtest.h" + +namespace ray { +namespace streaming { +void GenRandomChannelIdVector(std::vector &input_ids, int n) { + for (int i = 0; i < n; ++i) { + input_ids.push_back(ObjectID::FromRandom()); + } +} + +class MockWriter : public DataWriter { + public: + friend class MockWriterTest; + MockWriter(std::shared_ptr runtime_context) + : DataWriter(runtime_context) {} + void Init(const std::vector &input_channel_vec) { + output_queue_ids_ = input_channel_vec; + for (size_t i = 0; i < input_channel_vec.size(); ++i) { + const ChannelCreationParameter param; + InitChannel(input_channel_vec[i], param, 0, 0xfff); + } + reliability_helper_ = ReliabilityHelperFactory::CreateReliabilityHelper( + runtime_context_->GetConfig(), barrier_helper_, this, nullptr); + event_service_ = std::make_shared(); + runtime_context_->SetRuntimeStatus(RuntimeStatus::Running); + event_service_->Run(); + } + + void Destroy() { + event_service_->Stop(); + event_service_.reset(); + } + + bool IsMessageAvailableInBuffer(const ObjectID &id) { + return DataWriter::IsMessageAvailableInBuffer(channel_info_map_[id]); + } + + std::unordered_map &GetChannelInfoMap() { + return channel_info_map_; + }; + + bool CollectFromRingBuffer(const ObjectID &id, uint64_t &buffer_remain) { + return DataWriter::CollectFromRingBuffer(channel_info_map_[id], buffer_remain); + } + + StreamingStatus WriteBufferToChannel(const ObjectID &id, uint64_t &buffer_remain) { + return DataWriter::WriteBufferToChannel(channel_info_map_[id], buffer_remain); + } + + void BroadcastBarrier(uint64_t barrier_id) { + static const uint8_t barrier_data[] = {1, 2, 3, 4}; + DataWriter::BroadcastBarrier(barrier_id, barrier_data, 4); + } + + uint64_t WriteMessageToBufferRing(const ObjectID &channel_id, uint8_t *data, + uint32_t data_size) { + return DataWriter::WriteMessageToBufferRing(channel_id, data, data_size); + } +}; + +class MockWriterTest : public ::testing::Test { + protected: + virtual void SetUp() override { + runtime_context.reset(new RuntimeContext()); + runtime_context->SetConfig(config); + runtime_context->MarkMockTest(); + mock_writer.reset(new MockWriter(runtime_context)); + } + virtual void TearDown() override { mock_writer->Destroy(); } + + protected: + std::shared_ptr runtime_context; + StreamingConfig config; + std::shared_ptr mock_writer; + std::vector input_ids; +}; + +TEST_F(MockWriterTest, test_message_avaliablie_in_buffer) { + int channel_num = 5; + GenRandomChannelIdVector(input_ids, channel_num); + mock_writer->Init(input_ids); + for (const auto &id : input_ids) { + EXPECT_TRUE(!mock_writer->IsMessageAvailableInBuffer(id)); + } + mock_writer->BroadcastBarrier(0); + for (const auto &id : input_ids) { + EXPECT_TRUE(mock_writer->IsMessageAvailableInBuffer(id)); + } +} + +uint8_t data[] = {0x01, 0x02, 0x0f, 0xe, 0x00}; +uint32_t data_size = 5; + +TEST_F(MockWriterTest, test_write_message_to_buffer_ring) { + int channel_num = 2; + GenRandomChannelIdVector(input_ids, channel_num); + mock_writer->Init(input_ids); + for (const auto &id : input_ids) { + EXPECT_TRUE(!mock_writer->IsMessageAvailableInBuffer(id)); + } + mock_writer->WriteMessageToBufferRing(input_ids[0], data, data_size); + EXPECT_TRUE(mock_writer->IsMessageAvailableInBuffer(input_ids[0])); + EXPECT_TRUE(!mock_writer->IsMessageAvailableInBuffer(input_ids[1])); +} + +TEST_F(MockWriterTest, test_collecting_buffer) { + int channel_num = 1; + GenRandomChannelIdVector(input_ids, channel_num); + mock_writer->Init(input_ids); + mock_writer->WriteMessageToBufferRing(input_ids[0], data, data_size); + uint64_t buffer_remain; + mock_writer->CollectFromRingBuffer(input_ids[0], buffer_remain); + EXPECT_TRUE(buffer_remain == 0); + EXPECT_TRUE(mock_writer->IsMessageAvailableInBuffer(input_ids[0])); + EXPECT_TRUE(mock_writer->GetChannelInfoMap()[input_ids[0]] + .writer_ring_buffer->IsTransientAvaliable()); +} + +TEST_F(MockWriterTest, test_write_to_transfer) { + int channel_num = 1; + GenRandomChannelIdVector(input_ids, channel_num); + mock_writer->Init(input_ids); + mock_writer->WriteMessageToBufferRing(input_ids[0], data, data_size); + uint64_t buffer_remain; + EXPECT_EQ(mock_writer->WriteBufferToChannel(input_ids[0], buffer_remain), + StreamingStatus::OK); + EXPECT_TRUE(buffer_remain == 0); + EXPECT_TRUE(!mock_writer->IsMessageAvailableInBuffer(input_ids[0])); +} + +} // namespace streaming +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}