diff --git a/streaming/BUILD.bazel b/streaming/BUILD.bazel index c7744528d..269e10433 100644 --- a/streaming/BUILD.bazel +++ b/streaming/BUILD.bazel @@ -239,6 +239,14 @@ cc_test( deps = test_common_deps, ) +cc_test( + name = "queue_protobuf_tests", + srcs = [ + "src/test/queue_protobuf_tests.cc", + ], + deps = test_common_deps, +) + python_proto_compile( name = "streaming_py_proto", deps = [":streaming_proto"], diff --git a/streaming/python/runtime/transfer.py b/streaming/python/runtime/transfer.py index 4aac89569..5a19bec9a 100644 --- a/streaming/python/runtime/transfer.py +++ b/streaming/python/runtime/transfer.py @@ -174,8 +174,8 @@ class ChannelCreationParametersBuilder: def __init__(self): self._parameters = [] - def build_input_queue_parameters(self, queue_ids_dict): - self.build_parameters(queue_ids_dict, + def build_input_queue_parameters(self, from_actors): + self.build_parameters(from_actors, self._java_writer_async_function_descriptor, self._java_writer_sync_function_descriptor, self._python_writer_async_function_descriptor, diff --git a/streaming/src/channel.cc b/streaming/src/channel.cc index a5652a7c7..6816bf972 100644 --- a/streaming/src/channel.cc +++ b/streaming/src/channel.cc @@ -57,7 +57,7 @@ StreamingStatus StreamingQueueProducer::CreateQueue() { << " data_size: " << channel_info_.queue_size; auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService(); if (upstream_handler->UpstreamQueueExists(channel_info_.channel_id)) { - RAY_LOG(INFO) << "StreamingQueueWriter::CreateQueue duplicate!!!"; + STREAMING_LOG(INFO) << "StreamingQueueProducer CreateQueue duplicate."; return StreamingStatus::OK; } @@ -100,8 +100,9 @@ StreamingStatus StreamingQueueProducer::NotifyChannelConsumed(uint64_t channel_o StreamingStatus StreamingQueueProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_size) { - Status status = - PushQueueItem(channel_info_.current_seq_id + 1, data, data_size, current_time_ms()); + /// TODO: Fix msg_id_start and msg_id_end + Status status = PushQueueItem(channel_info_.current_seq_id + 1, data, data_size, + current_time_ms(), 0, 0); if (status.code() != StatusCode::OK) { STREAMING_LOG(DEBUG) << channel_info_.channel_id << " => Queue is full" @@ -120,11 +121,13 @@ StreamingStatus StreamingQueueProducer::ProduceItemToChannel(uint8_t *data, } Status StreamingQueueProducer::PushQueueItem(uint64_t seq_id, uint8_t *data, - uint32_t data_size, uint64_t timestamp) { - STREAMING_LOG(INFO) << "StreamingQueueProducer::PushQueueItem:" - << " qid: " << channel_info_.channel_id << " seq_id: " << seq_id - << " data_size: " << data_size; - Status status = queue_->Push(seq_id, data, data_size, timestamp, false); + uint32_t data_size, uint64_t timestamp, + uint64_t msg_id_start, uint64_t msg_id_end) { + STREAMING_LOG(DEBUG) << "StreamingQueueProducer::PushQueueItem:" + << " qid: " << channel_info_.channel_id << " seq_id: " << seq_id + << " data_size: " << data_size; + Status status = + queue_->Push(seq_id, data, data_size, timestamp, msg_id_start, msg_id_end, false); if (status.IsOutOfMemory()) { status = queue_->TryEvictItems(); if (!status.ok()) { @@ -132,7 +135,8 @@ Status StreamingQueueProducer::PushQueueItem(uint64_t seq_id, uint8_t *data, return status; } - status = queue_->Push(seq_id, data, data_size, timestamp, false); + status = + queue_->Push(seq_id, data, data_size, timestamp, msg_id_start, msg_id_end, false); } queue_->Send(); @@ -149,24 +153,45 @@ StreamingQueueConsumer::~StreamingQueueConsumer() { STREAMING_LOG(INFO) << "Consumer Destroy"; } -StreamingStatus StreamingQueueConsumer::CreateTransferChannel() { +StreamingQueueStatus StreamingQueueConsumer::GetQueue( + const ObjectID &queue_id, uint64_t start_msg_id, + const ChannelCreationParameter &init_param) { + STREAMING_LOG(INFO) << "GetQueue qid: " << queue_id << " start_msg_id: " << start_msg_id + << " actor_id: " << init_param.actor_id; auto downstream_handler = ray::streaming::DownstreamQueueMessageHandler::GetService(); - STREAMING_LOG(INFO) << "GetQueue qid: " << channel_info_.channel_id - << " start_seq_id: " << channel_info_.current_seq_id + 1; - if (downstream_handler->DownstreamQueueExists(channel_info_.channel_id)) { - RAY_LOG(INFO) << "StreamingQueueReader::GetQueue duplicate!!!"; - return StreamingStatus::OK; + if (downstream_handler->DownstreamQueueExists(queue_id)) { + STREAMING_LOG(INFO) << "StreamingQueueReader:: Already got this queue."; + return StreamingQueueStatus::OK; } - downstream_handler->SetPeerActorID( - channel_info_.channel_id, channel_info_.parameter.actor_id, - *channel_info_.parameter.async_function, *channel_info_.parameter.sync_function); - STREAMING_LOG(INFO) << "Create ReaderQueue " << channel_info_.channel_id - << " pull from start_seq_id: " << channel_info_.current_seq_id + 1; - queue_ = downstream_handler->CreateDownstreamQueue(channel_info_.channel_id, - channel_info_.parameter.actor_id); + downstream_handler->SetPeerActorID(queue_id, channel_info_.parameter.actor_id, + *init_param.async_function, + *init_param.sync_function); + STREAMING_LOG(INFO) << "Create ReaderQueue " << queue_id + << " pull from start_msg_id: " << start_msg_id; + queue_ = downstream_handler->CreateDownstreamQueue(queue_id, init_param.actor_id); + STREAMING_CHECK(queue_ != nullptr); - return StreamingStatus::OK; + bool is_first_pull; + return downstream_handler->PullQueue(queue_id, start_msg_id, is_first_pull); +} + +TransferCreationStatus StreamingQueueConsumer::CreateTransferChannel() { + StreamingQueueStatus status = + GetQueue(channel_info_.channel_id, channel_info_.current_seq_id + 1, + channel_info_.parameter); + + if (status == StreamingQueueStatus::OK) { + return TransferCreationStatus::PullOk; + } else if (status == StreamingQueueStatus::NoValidData) { + return TransferCreationStatus::FreshStarted; + } else if (status == StreamingQueueStatus::Timeout) { + return TransferCreationStatus::Timeout; + } else if (status == StreamingQueueStatus::DataLost) { + return TransferCreationStatus::DataLost; + } + STREAMING_LOG(FATAL) << "Invalid StreamingQueueStatus, status=" << status; + return TransferCreationStatus::Invalid; } StreamingStatus StreamingQueueConsumer::DestroyTransferChannel() { diff --git a/streaming/src/channel.h b/streaming/src/channel.h index 2588687a4..ee52c2a74 100644 --- a/streaming/src/channel.h +++ b/streaming/src/channel.h @@ -9,6 +9,14 @@ namespace ray { namespace streaming { +enum class TransferCreationStatus : uint32_t { + FreshStarted = 0, + PullOk = 1, + Timeout = 2, + DataLost = 3, + Invalid = 999, +}; + struct StreamingQueueInfo { uint64_t first_seq_id = 0; uint64_t last_seq_id = 0; @@ -98,7 +106,7 @@ class ConsumerChannel { explicit ConsumerChannel(std::shared_ptr &transfer_config, ConsumerChannelInfo &c_channel_info); virtual ~ConsumerChannel() = default; - virtual StreamingStatus CreateTransferChannel() = 0; + virtual TransferCreationStatus CreateTransferChannel() = 0; virtual StreamingStatus DestroyTransferChannel() = 0; virtual StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, uint64_t checkpoint_offset) = 0; @@ -129,7 +137,7 @@ class StreamingQueueProducer : public ProducerChannel { private: StreamingStatus CreateQueue(); Status PushQueueItem(uint64_t seq_id, uint8_t *data, uint32_t data_size, - uint64_t timestamp); + uint64_t timestamp, uint64_t msg_id_start, uint64_t msg_id_end); private: std::shared_ptr queue_; @@ -140,7 +148,7 @@ class StreamingQueueConsumer : public ConsumerChannel { explicit StreamingQueueConsumer(std::shared_ptr &transfer_config, ConsumerChannelInfo &c_channel_info); ~StreamingQueueConsumer() override; - StreamingStatus CreateTransferChannel() override; + TransferCreationStatus CreateTransferChannel() override; StreamingStatus DestroyTransferChannel() override; StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, uint64_t checkpoint_offset) override; @@ -149,6 +157,10 @@ class StreamingQueueConsumer : public ConsumerChannel { uint32_t &data_size, uint32_t timeout) override; StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override; + private: + StreamingQueueStatus GetQueue(const ObjectID &queue_id, uint64_t start_msg_id, + const ChannelCreationParameter &init_param); + private: std::shared_ptr queue_; }; @@ -183,7 +195,7 @@ class MockConsumer : public ConsumerChannel { explicit MockConsumer(std::shared_ptr &transfer_config, ConsumerChannelInfo &c_channel_info) : ConsumerChannel(transfer_config, c_channel_info){}; - StreamingStatus CreateTransferChannel() override { return StreamingStatus::OK; } + TransferCreationStatus CreateTransferChannel() override { return TransferCreationStatus::PullOk; } StreamingStatus DestroyTransferChannel() override { return StreamingStatus::OK; } StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, uint64_t checkpoint_offset) override { diff --git a/streaming/src/data_reader.cc b/streaming/src/data_reader.cc index b1a98f8c0..ca59cf987 100644 --- a/streaming/src/data_reader.cc +++ b/streaming/src/data_reader.cc @@ -77,8 +77,8 @@ StreamingStatus DataReader::InitChannel() { } channel_map_.emplace(input_channel, channel); - StreamingStatus status = channel->CreateTransferChannel(); - if (StreamingStatus::OK != status) { + TransferCreationStatus status = channel->CreateTransferChannel(); + if (TransferCreationStatus::PullOk != status) { STREAMING_LOG(ERROR) << "Initialize queue failed, id => " << input_channel; } } diff --git a/streaming/src/protobuf/streaming_queue.proto b/streaming/src/protobuf/streaming_queue.proto index d0eea2c2c..0fb260001 100644 --- a/streaming/src/protobuf/streaming_queue.proto +++ b/streaming/src/protobuf/streaming_queue.proto @@ -9,41 +9,45 @@ enum StreamingQueueMessageType { StreamingQueueNotificationMsgType = 3; StreamingQueueTestInitMsgType = 4; StreamingQueueTestCheckStatusRspMsgType = 5; + StreamingQueuePullRequestMsgType = 6; + StreamingQueuePullResponseMsgType = 7; + StreamingQueueResendDataMsgType = 8; } enum StreamingQueueError { OK = 0; QUEUE_NOT_EXIST = 1; - NO_VALID_DATA_TO_PULL = 2; + DATA_LOST = 2; + NO_VALID_DATA = 3; } -message StreamingQueueDataMsg { +message MessageCommon { bytes src_actor_id = 1; bytes dst_actor_id = 2; bytes queue_id = 3; - uint64 seq_id = 4; +} + +message StreamingQueueDataMsg { + MessageCommon common = 1; + uint64 seq_id = 2; + uint64 msg_id_start = 3; + uint64 msg_id_end = 4; uint64 length = 5; bool raw = 6; } message StreamingQueueCheckMsg { - bytes src_actor_id = 1; - bytes dst_actor_id = 2; - bytes queue_id = 3; + MessageCommon common = 1; } message StreamingQueueCheckRspMsg { - bytes src_actor_id = 1; - bytes dst_actor_id = 2; - bytes queue_id = 3; - StreamingQueueError err_code = 4; + MessageCommon common = 1; + StreamingQueueError err_code = 2; } message StreamingQueueNotificationMsg { - bytes src_actor_id = 1; - bytes dst_actor_id = 2; - bytes queue_id = 3; - uint64 seq_id = 4; + MessageCommon common = 1; + uint64 seq_id = 2; } // for test @@ -67,4 +71,28 @@ message StreamingQueueTestInitMsg { message StreamingQueueTestCheckStatusRspMsg { string test_name = 1; bool status = 2; -} \ No newline at end of file +} + +message StreamingQueuePullRequestMsg { + MessageCommon common = 1; + uint64 msg_id = 2; +} + +message StreamingQueuePullResponseMsg { + MessageCommon common = 1; + uint64 seq_id = 2; + uint64 msg_id = 3; + StreamingQueueError err_code = 4; + bool is_upstream_first_pull = 5; +} + +message StreamingQueueResendDataMsg { + MessageCommon common = 1; + uint64 first_seq_id = 2; + uint64 last_seq_id = 3; + uint64 seq_id = 4; + uint64 msg_id_start = 5; + uint64 msg_id_end = 6; + uint64 length = 7; + bool raw = 8; +} diff --git a/streaming/src/queue/message.cc b/streaming/src/queue/message.cc index 2cee9d203..f7d4d7874 100644 --- a/streaming/src/queue/message.cc +++ b/streaming/src/queue/message.cc @@ -13,7 +13,7 @@ std::unique_ptr Message::ToBytes() { queue::protobuf::StreamingQueueMessageType type = Type(); size_t total_len = - sizeof(Message::MagicNum) + sizeof(type) + sizeof(fbs_length) + fbs_length; + kItemHeaderSize + fbs_length; if (buffer_ != nullptr) { total_len += buffer_->Size(); } @@ -45,28 +45,35 @@ std::unique_ptr Message::ToBytes() { return buffer; } + +void Message::FillMessageCommon(queue::protobuf::MessageCommon *common) { + common->set_src_actor_id(actor_id_.Binary()); + common->set_dst_actor_id(peer_actor_id_.Binary()); + common->set_queue_id(queue_id_.Binary()); +} + void DataMessage::ToProtobuf(std::string *output) { queue::protobuf::StreamingQueueDataMsg msg; - msg.set_src_actor_id(actor_id_.Binary()); - msg.set_dst_actor_id(peer_actor_id_.Binary()); - msg.set_queue_id(queue_id_.Binary()); + FillMessageCommon(msg.mutable_common()); msg.set_seq_id(seq_id_); + msg.set_msg_id_start(msg_id_start_); + msg.set_msg_id_end(msg_id_end_); msg.set_length(buffer_->Size()); msg.set_raw(raw_); msg.SerializeToString(output); } std::shared_ptr DataMessage::FromBytes(uint8_t *bytes) { - bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); - uint64_t *fbs_length = (uint64_t *)bytes; - bytes += sizeof(uint64_t); - + uint64_t *fbs_length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; std::string inputpb(reinterpret_cast(bytes), *fbs_length); queue::protobuf::StreamingQueueDataMsg message; message.ParseFromString(inputpb); - ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id()); - ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id()); - ObjectID queue_id = ObjectID::FromBinary(message.queue_id()); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); + uint64_t msg_id_start = message.msg_id_start(); + uint64_t msg_id_end = message.msg_id_end(); uint64_t seq_id = message.seq_id(); uint64_t length = message.length(); bool raw = message.raw(); @@ -76,32 +83,27 @@ std::shared_ptr DataMessage::FromBytes(uint8_t *bytes) { std::shared_ptr buffer = std::make_shared(bytes, (size_t)length, true); std::shared_ptr data_msg = std::make_shared( - src_actor_id, dst_actor_id, queue_id, seq_id, buffer, raw); + src_actor_id, dst_actor_id, queue_id, seq_id, msg_id_start, msg_id_end, buffer, raw); return data_msg; } void NotificationMessage::ToProtobuf(std::string *output) { queue::protobuf::StreamingQueueNotificationMsg msg; - msg.set_src_actor_id(actor_id_.Binary()); - msg.set_dst_actor_id(peer_actor_id_.Binary()); - msg.set_queue_id(queue_id_.Binary()); + FillMessageCommon(msg.mutable_common()); msg.set_seq_id(seq_id_); msg.SerializeToString(output); } std::shared_ptr NotificationMessage::FromBytes(uint8_t *bytes) { - bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); - uint64_t *length = (uint64_t *)bytes; - bytes += sizeof(uint64_t); - + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; std::string inputpb(reinterpret_cast(bytes), *length); queue::protobuf::StreamingQueueNotificationMsg message; message.ParseFromString(inputpb); - STREAMING_LOG(INFO) << "message.src_actor_id: " << message.src_actor_id(); - ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id()); - ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id()); - ObjectID queue_id = ObjectID::FromBinary(message.queue_id()); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); uint64_t seq_id = message.seq_id(); std::shared_ptr notify_msg = @@ -112,23 +114,19 @@ std::shared_ptr NotificationMessage::FromBytes(uint8_t *byt void CheckMessage::ToProtobuf(std::string *output) { queue::protobuf::StreamingQueueCheckMsg msg; - msg.set_src_actor_id(actor_id_.Binary()); - msg.set_dst_actor_id(peer_actor_id_.Binary()); - msg.set_queue_id(queue_id_.Binary()); + FillMessageCommon(msg.mutable_common()); msg.SerializeToString(output); } std::shared_ptr CheckMessage::FromBytes(uint8_t *bytes) { - bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); - uint64_t *length = (uint64_t *)bytes; - bytes += sizeof(uint64_t); - + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; std::string inputpb(reinterpret_cast(bytes), *length); queue::protobuf::StreamingQueueCheckMsg message; message.ParseFromString(inputpb); - ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id()); - ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id()); - ObjectID queue_id = ObjectID::FromBinary(message.queue_id()); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); std::shared_ptr check_msg = std::make_shared(src_actor_id, dst_actor_id, queue_id); @@ -138,24 +136,20 @@ std::shared_ptr CheckMessage::FromBytes(uint8_t *bytes) { void CheckRspMessage::ToProtobuf(std::string *output) { queue::protobuf::StreamingQueueCheckRspMsg msg; - msg.set_src_actor_id(actor_id_.Binary()); - msg.set_dst_actor_id(peer_actor_id_.Binary()); - msg.set_queue_id(queue_id_.Binary()); + FillMessageCommon(msg.mutable_common()); msg.set_err_code(err_code_); msg.SerializeToString(output); } std::shared_ptr CheckRspMessage::FromBytes(uint8_t *bytes) { - bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); - uint64_t *length = (uint64_t *)bytes; - bytes += sizeof(uint64_t); - + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; std::string inputpb(reinterpret_cast(bytes), *length); queue::protobuf::StreamingQueueCheckRspMsg message; message.ParseFromString(inputpb); - ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id()); - ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id()); - ObjectID queue_id = ObjectID::FromBinary(message.queue_id()); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); queue::protobuf::StreamingQueueError err_code = message.err_code(); std::shared_ptr check_rsp_msg = @@ -164,6 +158,117 @@ std::shared_ptr CheckRspMessage::FromBytes(uint8_t *bytes) { return check_rsp_msg; } +void PullRequestMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueuePullRequestMsg msg; + FillMessageCommon(msg.mutable_common()); + msg.set_msg_id(msg_id_); + msg.SerializeToString(output); +} + +std::shared_ptr PullRequestMessage::FromBytes(uint8_t *bytes) { + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueuePullRequestMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); + uint64_t msg_id = message.msg_id(); + STREAMING_LOG(DEBUG) << "src_actor_id:" << src_actor_id + << " dst_actor_id:" << dst_actor_id << " queue_id:" << queue_id + << " msg_id:" << msg_id; + + std::shared_ptr pull_msg = + std::make_shared(src_actor_id, dst_actor_id, queue_id, msg_id); + return pull_msg; +} + +void PullResponseMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueuePullResponseMsg msg; + FillMessageCommon(msg.mutable_common()); + msg.set_seq_id(seq_id_); + msg.set_msg_id(msg_id_); + msg.set_err_code(err_code_); + msg.set_is_upstream_first_pull(is_upstream_first_pull_); + msg.SerializeToString(output); +} + +std::shared_ptr PullResponseMessage::FromBytes(uint8_t *bytes) { + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueuePullResponseMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); + uint64_t seq_id = message.seq_id(); + uint64_t msg_id = message.msg_id(); + queue::protobuf::StreamingQueueError err_code = message.err_code(); + bool is_upstream_first_pull = message.is_upstream_first_pull(); + + STREAMING_LOG(INFO) << "src_actor_id:" << src_actor_id + << " dst_actor_id:" << dst_actor_id << " queue_id:" << queue_id + << " seq_id: " << seq_id << " msg_id: " << msg_id << " err_code:" + << queue::protobuf::StreamingQueueError_Name(err_code) + << " is_upstream_first_pull: " << is_upstream_first_pull; + + std::shared_ptr pull_rsp_msg = + std::make_shared(src_actor_id, dst_actor_id, queue_id, seq_id, + msg_id, err_code, is_upstream_first_pull); + + return pull_rsp_msg; +} + +void ResendDataMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueResendDataMsg msg; + FillMessageCommon(msg.mutable_common()); + msg.set_first_seq_id(first_seq_id_); + msg.set_last_seq_id(last_seq_id_); + msg.set_seq_id(seq_id_); + msg.set_msg_id_start(msg_id_start_); + msg.set_msg_id_end(msg_id_end_); + msg.set_length(buffer_->Size()); + msg.set_raw(raw_); + msg.SerializeToString(output); +} + +std::shared_ptr ResendDataMessage::FromBytes(uint8_t *bytes) { + uint64_t *fbs_length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *fbs_length); + queue::protobuf::StreamingQueueResendDataMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); + uint64_t first_seq_id = message.first_seq_id(); + uint64_t last_seq_id = message.last_seq_id(); + uint64_t seq_id = message.seq_id(); + uint64_t msg_id_start = message.msg_id_start(); + uint64_t msg_id_end = message.msg_id_end(); + uint64_t length = message.length(); + bool raw = message.raw(); + + STREAMING_LOG(DEBUG) << "src_actor_id:" << src_actor_id + << " dst_actor_id:" << dst_actor_id + << " first_seq_id:" << first_seq_id << " seq_id:" << seq_id + << " msg_id_start: " << msg_id_start + << " msg_id_end: " << msg_id_end << " last_seq_id:" << last_seq_id + << " queue_id:" << queue_id << " length:" << length; + + bytes += *fbs_length; + /// COPY + std::shared_ptr buffer = + std::make_shared(bytes, (size_t)length, true); + std::shared_ptr pull_data_msg = std::make_shared( + src_actor_id, dst_actor_id, queue_id, first_seq_id, seq_id, msg_id_start, + msg_id_end, last_seq_id, buffer, raw); + + return pull_data_msg; +} + void TestInitMessage::ToProtobuf(std::string *output) { queue::protobuf::StreamingQueueTestInitMsg msg; msg.set_role(role_); @@ -183,10 +288,8 @@ void TestInitMessage::ToProtobuf(std::string *output) { } std::shared_ptr TestInitMessage::FromBytes(uint8_t *bytes) { - bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); - uint64_t *length = (uint64_t *)bytes; - bytes += sizeof(uint64_t); - + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; std::string inputpb(reinterpret_cast(bytes), *length); queue::protobuf::StreamingQueueTestInitMsg message; message.ParseFromString(inputpb); @@ -221,10 +324,8 @@ void TestCheckStatusRspMsg::ToProtobuf(std::string *output) { } std::shared_ptr TestCheckStatusRspMsg::FromBytes(uint8_t *bytes) { - bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); - uint64_t *length = (uint64_t *)bytes; - bytes += sizeof(uint64_t); - + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; std::string inputpb(reinterpret_cast(bytes), *length); queue::protobuf::StreamingQueueTestCheckStatusRspMsg message; message.ParseFromString(inputpb); @@ -237,4 +338,4 @@ std::shared_ptr TestCheckStatusRspMsg::FromBytes(uint8_t return test_check_msg; } } // namespace streaming -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/streaming/src/queue/message.h b/streaming/src/queue/message.h index 9f9205ae7..50c80a9fa 100644 --- a/streaming/src/queue/message.h +++ b/streaming/src/queue/message.h @@ -28,10 +28,10 @@ class Message { buffer_(buffer) {} Message() {} virtual ~Message() {} - ActorID ActorId() { return actor_id_; } - ActorID PeerActorId() { return peer_actor_id_; } - ObjectID QueueId() { return queue_id_; } - std::shared_ptr Buffer() { return buffer_; } + inline ActorID ActorId() { return actor_id_; } + inline ActorID PeerActorId() { return peer_actor_id_; } + inline ObjectID QueueId() { return queue_id_; } + inline std::shared_ptr Buffer() { return buffer_; } /// Serialize all meta data and data to a LocalMemoryBuffer, which can be sent through /// direct actor call. \return serialized buffer . @@ -44,6 +44,7 @@ class Message { /// All subclasses should implement `ToProtobuf` to serialize its own protobuf data. virtual void ToProtobuf(std::string *output) = 0; + void FillMessageCommon(queue::protobuf::MessageCommon *common); protected: ActorID actor_id_; ActorID peer_actor_id_; @@ -55,24 +56,36 @@ class Message { static const uint32_t MagicNum; }; +/// MagicNum + MessageType +constexpr uint32_t kItemMetaHeaderSize = + sizeof(Message::MagicNum) + sizeof(queue::protobuf::StreamingQueueMessageType); +/// kItemMetaHeaderSize + fbs length +constexpr uint32_t kItemHeaderSize = kItemMetaHeaderSize + sizeof(uint64_t); + /// Wrap StreamingQueueDataMsg in streaming_queue.proto. /// DataMessage encapsulates the memory buffer of QueueItem, a one-to-one relationship /// exists between DataMessage and QueueItem. class DataMessage : public Message { public: DataMessage(const ActorID &actor_id, const ActorID &peer_actor_id, ObjectID queue_id, - uint64_t seq_id, std::shared_ptr buffer, bool raw) - : Message(actor_id, peer_actor_id, queue_id, buffer), seq_id_(seq_id), raw_(raw) {} + uint64_t seq_id, uint64_t msg_id_start, uint64_t msg_id_end, std::shared_ptr buffer, bool raw) + : Message(actor_id, peer_actor_id, queue_id, buffer), seq_id_(seq_id), + msg_id_start_(msg_id_start), + msg_id_end_(msg_id_end),raw_(raw) {} virtual ~DataMessage() {} static std::shared_ptr FromBytes(uint8_t *bytes); virtual void ToProtobuf(std::string *output); - uint64_t SeqId() { return seq_id_; } - bool IsRaw() { return raw_; } - queue::protobuf::StreamingQueueMessageType Type() { return type_; } + inline uint64_t SeqId() { return seq_id_; } + inline uint64_t MsgIdStart() { return msg_id_start_; } + inline uint64_t MsgIdEnd() { return msg_id_end_; } + inline bool IsRaw() { return raw_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } private: uint64_t seq_id_; + uint64_t msg_id_start_; + uint64_t msg_id_end_; bool raw_; const queue::protobuf::StreamingQueueMessageType type_ = @@ -93,8 +106,8 @@ class NotificationMessage : public Message { static std::shared_ptr FromBytes(uint8_t *bytes); virtual void ToProtobuf(std::string *output); - uint64_t SeqId() { return seq_id_; } - queue::protobuf::StreamingQueueMessageType Type() { return type_; } + inline uint64_t SeqId() { return seq_id_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } private: uint64_t seq_id_; @@ -115,7 +128,7 @@ class CheckMessage : public Message { static std::shared_ptr FromBytes(uint8_t *bytes); virtual void ToProtobuf(std::string *output); - queue::protobuf::StreamingQueueMessageType Type() { return type_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } private: const queue::protobuf::StreamingQueueMessageType type_ = @@ -134,8 +147,8 @@ class CheckRspMessage : public Message { static std::shared_ptr FromBytes(uint8_t *bytes); virtual void ToProtobuf(std::string *output); - queue::protobuf::StreamingQueueMessageType Type() { return type_; } - queue::protobuf::StreamingQueueError Error() { return err_code_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + inline queue::protobuf::StreamingQueueError Error() { return err_code_; } private: queue::protobuf::StreamingQueueError err_code_; @@ -143,6 +156,91 @@ class CheckRspMessage : public Message { queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType; }; +class PullRequestMessage : public Message { + public: + PullRequestMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + const ObjectID &queue_id, uint64_t msg_id) + : Message(actor_id, peer_actor_id, queue_id), msg_id_(msg_id) {} + virtual ~PullRequestMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + inline uint64_t MsgId() { return msg_id_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + + private: + uint64_t msg_id_; + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueuePullRequestMsgType; +}; + +class PullResponseMessage : public Message { + public: + PullResponseMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + const ObjectID &queue_id, uint64_t seq_id, uint64_t msg_id, + queue::protobuf::StreamingQueueError err_code, + bool is_upstream_first_pull) + : Message(actor_id, peer_actor_id, queue_id), + seq_id_(seq_id), + msg_id_(msg_id), + is_upstream_first_pull_(is_upstream_first_pull), + err_code_(err_code) {} + virtual ~PullResponseMessage() = default; + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + inline uint64_t SeqId() { return seq_id_; } + inline uint64_t MsgId() { return msg_id_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + inline queue::protobuf::StreamingQueueError Error() { return err_code_; } + inline bool IsUpstreamFirstPull() { return is_upstream_first_pull_; } + + private: + uint64_t seq_id_; + uint64_t msg_id_; + bool is_upstream_first_pull_; + queue::protobuf::StreamingQueueError err_code_; + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueuePullResponseMsgType; +}; + +class ResendDataMessage : public Message { + public: + ResendDataMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + ObjectID queue_id, uint64_t first_seq_id, uint64_t seq_id, + uint64_t msg_id_start, uint64_t msg_id_end, uint64_t last_seq_id, + std::shared_ptr buffer, bool raw) + : Message(actor_id, peer_actor_id, queue_id, buffer), + first_seq_id_(first_seq_id), + last_seq_id_(last_seq_id), + seq_id_(seq_id), + msg_id_start_(msg_id_start), + msg_id_end_(msg_id_end), + raw_(raw) {} + virtual ~ResendDataMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + inline uint64_t MsgIdStart() { return msg_id_start_; } + inline uint64_t MsgIdEnd() { return msg_id_end_; } + inline uint64_t SeqId() { return seq_id_; } + inline uint64_t FirstSeqId() { return first_seq_id_; } + inline uint64_t LastSeqId() { return last_seq_id_; } + inline bool IsRaw() { return raw_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + + private: + uint64_t first_seq_id_; + uint64_t last_seq_id_; + uint64_t seq_id_; + uint64_t msg_id_start_; + uint64_t msg_id_end_; + bool raw_; + + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueResendDataMsgType; +}; + /// Wrap StreamingQueueTestInitMsg in streaming_queue.proto. /// TestInitMessage, used for test, driver sends to test workers to init test suite. class TestInitMessage : public Message { @@ -165,14 +263,14 @@ class TestInitMessage : public Message { static std::shared_ptr FromBytes(uint8_t *bytes); virtual void ToProtobuf(std::string *output); - queue::protobuf::StreamingQueueMessageType Type() { return type_; } - std::string ActorHandleSerialized() { return actor_handle_serialized_; } - queue::protobuf::StreamingQueueTestRole Role() { return role_; } - std::vector QueueIds() { return queue_ids_; } - std::vector RescaleQueueIds() { return rescale_queue_ids_; } - std::string TestSuiteName() { return test_suite_name_; } - std::string TestName() { return test_name_; } - uint64_t Param() { return param_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + inline std::string ActorHandleSerialized() { return actor_handle_serialized_; } + inline queue::protobuf::StreamingQueueTestRole Role() { return role_; } + inline std::vector QueueIds() { return queue_ids_; } + inline std::vector RescaleQueueIds() { return rescale_queue_ids_; } + inline std::string TestSuiteName() { return test_suite_name_; } + inline std::string TestName() { return test_name_; } + inline uint64_t Param() { return param_; } std::string ToString() { std::ostringstream os; @@ -218,9 +316,9 @@ class TestCheckStatusRspMsg : public Message { static std::shared_ptr FromBytes(uint8_t *bytes); virtual void ToProtobuf(std::string *output); - queue::protobuf::StreamingQueueMessageType Type() { return type_; } - std::string TestName() { return test_name_; } - bool Status() { return status_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + inline std::string TestName() { return test_name_; } + inline bool Status() { return status_; } private: const queue::protobuf::StreamingQueueMessageType type_ = diff --git a/streaming/src/queue/queue.cc b/streaming/src/queue/queue.cc index 3f9e60585..308ac6b76 100644 --- a/streaming/src/queue/queue.cc +++ b/streaming/src/queue/queue.cc @@ -69,30 +69,18 @@ QueueItem Queue::PopPendingBlockTimeout(uint64_t timeout_us) { return item; } else { - uint8_t data[1]; - return QueueItem(QUEUE_INVALID_SEQ_ID, data, 1, 0, true); + return InvalidQueueItem(); } } QueueItem Queue::BackPending() { std::unique_lock lock(mutex_); if (std::next(watershed_iter_) == buffer_queue_.end()) { - uint8_t data[1]; - return QueueItem(QUEUE_INVALID_SEQ_ID, data, 1, 0, true); + return InvalidQueueItem(); } return buffer_queue_.back(); } -bool Queue::IsPendingEmpty() { - std::unique_lock lock(mutex_); - return std::next(watershed_iter_) == buffer_queue_.end(); -} - -bool Queue::IsPendingFull(uint64_t data_size) { - std::unique_lock lock(mutex_); - return max_data_size_ < data_size + data_size_; -} - size_t Queue::ProcessedCount() { std::unique_lock lock(mutex_); if (watershed_iter_ == buffer_queue_.begin()) return 0; @@ -113,27 +101,28 @@ size_t Queue::PendingCount() { return begin->SeqId() - end->SeqId() + 1; } -Status WriterQueue::Push(uint64_t seq_id, uint8_t *data, uint32_t data_size, - uint64_t timestamp, bool raw) { - if (IsPendingFull(data_size)) { +Status WriterQueue::Push(uint64_t seq_id, uint8_t *buffer, uint32_t buffer_size, + uint64_t timestamp, uint64_t msg_id_start, uint64_t msg_id_end, bool raw) { + if (IsPendingFull(buffer_size)) { return Status::OutOfMemory("Queue Push OutOfMemory"); } - while (is_pulling_) { - STREAMING_LOG(INFO) << "This queue is sending pull data, wait."; + while (is_resending_) { + STREAMING_LOG(INFO) << "This queue is resending data, wait."; std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - QueueItem item(seq_id, data, data_size, timestamp, raw); + QueueItem item(seq_id, buffer, buffer_size, timestamp, msg_id_start, msg_id_end, raw); Queue::Push(item); + STREAMING_LOG(DEBUG) << "WriterQueue::Push seq_id_: " << seq_id_; + seq_id_++; return Status::OK(); } void WriterQueue::Send() { while (!IsPendingEmpty()) { - // FIXME: front -> send -> pop QueueItem item = PopPending(); - DataMessage msg(actor_id_, peer_actor_id_, queue_id_, item.SeqId(), item.Buffer(), + DataMessage msg(actor_id_, peer_actor_id_, queue_id_, item.SeqId(), item.MsgIdStart(), item.MsgIdEnd(), item.Buffer(), item.IsRaw()); std::unique_ptr buffer = msg.ToBytes(); STREAMING_CHECK(transport_ != nullptr); @@ -171,6 +160,115 @@ void WriterQueue::OnNotify(std::shared_ptr notify_msg) { min_consumed_id_ = notify_msg->SeqId(); } +void WriterQueue::ResendItem(QueueItem &item, uint64_t first_seq_id, + uint64_t last_seq_id) { + ResendDataMessage msg(actor_id_, peer_actor_id_, queue_id_, first_seq_id, item.SeqId(), + item.MsgIdStart(), item.MsgIdEnd(), last_seq_id, item.Buffer(), + item.IsRaw()); + STREAMING_CHECK(item.Buffer()->Data() != nullptr); + std::unique_ptr buffer = msg.ToBytes(); + + transport_->Send(std::move(buffer)); +} + +int WriterQueue::ResendItems(std::list::iterator start_iter, + uint64_t first_seq_id, uint64_t last_seq_id) { + std::unique_lock lock(mutex_); + int count = 0; + auto it = start_iter; + for (; it != watershed_iter_; it++) { + if (it->SeqId() > last_seq_id) { + break; + } + STREAMING_LOG(INFO) << "ResendItems send seq_id " << it->SeqId() << " to peer."; + ResendItem(*it, first_seq_id, last_seq_id); + count++; + } + + STREAMING_LOG(INFO) << "ResendItems total count: " << count; + is_resending_ = false; + return count; +} + +void WriterQueue::FindItem( + uint64_t target_msg_id, std::function greater_callback, std::function less_callback, + std::function::iterator, uint64_t, uint64_t)> equal_callback) { + auto last_one = std::prev(watershed_iter_); + bool last_item_too_small = + last_one != buffer_queue_.end() && last_one->MsgIdEnd() < target_msg_id; + + if (QUEUE_INITIAL_SEQ_ID == seq_id_ || last_item_too_small) { + greater_callback(); + return; + } + + auto begin = buffer_queue_.begin(); + uint64_t first_seq_id = (*begin).SeqId(); + uint64_t last_seq_id = first_seq_id + std::distance(begin, watershed_iter_) - 1; + STREAMING_LOG(INFO) << "FindItem last_seq_id: " << last_seq_id + << " first_seq_id: " << first_seq_id; + + auto target_item = std::find_if( + begin, watershed_iter_, + [&target_msg_id](QueueItem &item) { return item.InItem(target_msg_id); }); + + if (target_item != watershed_iter_) { + equal_callback(target_item, first_seq_id, last_seq_id); + } else { + less_callback(); + } +} + +void WriterQueue::OnPull( + std::shared_ptr pull_msg, boost::asio::io_service &service, + std::function)> callback) { + std::unique_lock lock(mutex_); + STREAMING_CHECK(peer_actor_id_ == pull_msg->ActorId()) + << peer_actor_id_ << " " << pull_msg->ActorId(); + + FindItem(pull_msg->MsgId(), + /// target_msg_id is too large. + [this, &pull_msg, &callback]() { + STREAMING_LOG(WARNING) + << "No valid data to pull, the writer has not push data yet. "; + PullResponseMessage msg(pull_msg->PeerActorId(), pull_msg->ActorId(), + pull_msg->QueueId(), QUEUE_INVALID_SEQ_ID, + QUEUE_INVALID_SEQ_ID, + queue::protobuf::StreamingQueueError::NO_VALID_DATA, + is_upstream_first_pull_); + std::unique_ptr buffer = msg.ToBytes(); + is_upstream_first_pull_ = false; + callback(std::move(buffer)); + }, + /// target_msg_id is too small. + [this, &pull_msg, &callback]() { + STREAMING_LOG(WARNING) << "Data lost."; + PullResponseMessage msg( + pull_msg->PeerActorId(), pull_msg->ActorId(), pull_msg->QueueId(), + QUEUE_INVALID_SEQ_ID, QUEUE_INVALID_SEQ_ID, + queue::protobuf::StreamingQueueError::DATA_LOST, is_upstream_first_pull_); + std::unique_ptr buffer = msg.ToBytes(); + + callback(std::move(buffer)); + }, + /// target_msg_id found. + [this, &pull_msg, &callback, &service]( + std::list::iterator target_item, uint64_t first_seq_id, + uint64_t last_seq_id) { + is_resending_ = true; + STREAMING_LOG(INFO) << "OnPull return"; + service.post(std::bind(&WriterQueue::ResendItems, this, target_item, + first_seq_id, last_seq_id)); + PullResponseMessage msg( + pull_msg->PeerActorId(), pull_msg->ActorId(), pull_msg->QueueId(), + target_item->SeqId(), pull_msg->MsgId(), + queue::protobuf::StreamingQueueError::OK, is_upstream_first_pull_); + std::unique_ptr buffer = msg.ToBytes(); + is_upstream_first_pull_ = false; + callback(std::move(buffer)); + }); +} + void ReaderQueue::OnConsumed(uint64_t seq_id) { STREAMING_LOG(INFO) << "OnConsumed: " << seq_id; QueueItem item = FrontProcessed(); @@ -195,17 +293,27 @@ void ReaderQueue::Notify(uint64_t seq_id) { void ReaderQueue::CreateNotifyTask(uint64_t seq_id, std::vector &task_args) {} void ReaderQueue::OnData(QueueItem &item) { - if (item.SeqId() != expect_seq_id_) { - STREAMING_LOG(WARNING) << "OnData ignore seq_id: " << item.SeqId() - << " expect_seq_id_: " << expect_seq_id_; - return; - } - last_recv_seq_id_ = item.SeqId(); STREAMING_LOG(DEBUG) << "ReaderQueue::OnData seq_id: " << last_recv_seq_id_; Push(item); - expect_seq_id_++; +} + +void ReaderQueue::OnResendData(std::shared_ptr msg) { + STREAMING_LOG(INFO) << "OnResendData queue_id: " << queue_id_ << " recv seq_id " + << msg->SeqId() << "(" << msg->FirstSeqId() << "/" + << msg->LastSeqId() << ")"; + QueueItem item(msg->SeqId(), msg->Buffer(), 0, msg->MsgIdStart(), msg->MsgIdEnd(), + msg->IsRaw()); + STREAMING_CHECK(msg->Buffer()->Data() != nullptr); + + Push(item); + STREAMING_CHECK(msg->SeqId() >= msg->FirstSeqId() && msg->SeqId() <= msg->LastSeqId()) + << "(" << msg->FirstSeqId() << "/" << msg->SeqId() << "/" << msg->LastSeqId() + << ")"; + if (msg->SeqId() == msg->LastSeqId()) { + STREAMING_LOG(INFO) << "Resend DATA Done"; + } } } // namespace streaming diff --git a/streaming/src/queue/queue.h b/streaming/src/queue/queue.h index f5625d0e2..32c1355b5 100644 --- a/streaming/src/queue/queue.h +++ b/streaming/src/queue/queue.h @@ -28,11 +28,17 @@ enum QueueType { UPSTREAM = 0, DOWNSTREAM }; /// using a watershed iterator to divided. class Queue { public: - /// \param[in] queue_id the unique identification of a pair of queues (upstream and - /// downstream). \param[in] size max size of the queue in bytes. \param[in] transport + /// \param actor_id, the actor id of upstream worker + /// \param peer_actor_id, the actor id of downstream worker + /// \param queue_id the unique identification of a pair of queues (upstream and + /// downstream). + /// \param size max size of the queue in bytes. + /// \param transport /// transport to send items to peer. - Queue(ObjectID queue_id, uint64_t size, std::shared_ptr transport) - : queue_id_(queue_id), max_data_size_(size), data_size_(0), data_size_sent_(0) { + Queue(const ActorID &actor_id, const ActorID &peer_actor_id, ObjectID queue_id, uint64_t size, std::shared_ptr transport) + : actor_id_(actor_id), + peer_actor_id_(peer_actor_id), + queue_id_(queue_id), max_data_size_(size), data_size_(0), data_size_sent_(0) { buffer_queue_.push_back(InvalidQueueItem()); watershed_iter_ = buffer_queue_.begin(); } @@ -61,20 +67,27 @@ class Queue { /// Return the last item in pending state. QueueItem BackPending(); - bool IsPendingEmpty(); - bool IsPendingFull(uint64_t data_size = 0); + inline bool IsPendingEmpty() { + std::unique_lock lock(mutex_); + return std::next(watershed_iter_) == buffer_queue_.end(); + }; + + inline bool IsPendingFull(uint64_t data_size = 0) { + std::unique_lock lock(mutex_); + return max_data_size_ < data_size + data_size_; + } /// Return the size in bytes of all items in queue. - uint64_t QueueSize() { return data_size_; } + inline uint64_t QueueSize() { return data_size_; } /// Return the size in bytes of all items in pending state. - uint64_t PendingDataSize() { return data_size_ - data_size_sent_; } + inline uint64_t PendingDataSize() { return data_size_ - data_size_sent_; } /// Return the size in bytes of all items in processed state. - uint64_t ProcessedDataSize() { return data_size_sent_; } + inline uint64_t ProcessedDataSize() { return data_size_sent_; } /// Return item count of the queue. - size_t Count() { return buffer_queue_.size(); } + inline size_t Count() { return buffer_queue_.size(); } /// Return item count in pending state. size_t PendingCount(); @@ -82,11 +95,16 @@ class Queue { /// Return item count in processed state. size_t ProcessedCount(); + inline ActorID GetActorID() { return actor_id_; } + inline ActorID GetPeerActorID() { return peer_actor_id_; } + inline ObjectID GetQueueID() { return queue_id_; } protected: - ObjectID queue_id_; std::list buffer_queue_; std::list::iterator watershed_iter_; + ActorID actor_id_; + ActorID peer_actor_id_; + ObjectID queue_id_; /// max data size in bytes uint64_t max_data_size_; uint64_t data_size_; @@ -107,26 +125,40 @@ class WriterQueue : public Queue { WriterQueue(const ObjectID &queue_id, const ActorID &actor_id, const ActorID &peer_actor_id, uint64_t size, std::shared_ptr transport) - : Queue(queue_id, size, transport), + : Queue(actor_id, peer_actor_id, queue_id, size, transport), actor_id_(actor_id), peer_actor_id_(peer_actor_id), + seq_id_(QUEUE_INITIAL_SEQ_ID), eviction_limit_(QUEUE_INVALID_SEQ_ID), min_consumed_id_(QUEUE_INVALID_SEQ_ID), peer_last_msg_id_(0), peer_last_seq_id_(QUEUE_INVALID_SEQ_ID), transport_(transport), - is_pulling_(false) {} + is_resending_(false), + is_upstream_first_pull_(true) {} - /// Push a continuous buffer into queue. - /// NOTE: the buffer should be copied. - Status Push(uint64_t seq_id, uint8_t *data, uint32_t data_size, uint64_t timestamp, - bool raw = false); + /// Push a continuous buffer into queue, the buffer consists of some messages packed by DataWriter. + /// \param data, the buffer address + /// \param data_size, buffer size + /// \param timestamp, the timestamp when the buffer pushed in + /// \param msg_id_start, the message id of the first message in the buffer + /// \param msg_id_end, the message id of the last message in the buffer + /// \param raw, whether this buffer is raw data, be True only in test + Status Push(uint64_t seq_id, uint8_t *buffer, uint32_t buffer_size, uint64_t timestamp, + uint64_t msg_id_start, uint64_t msg_id_end, bool raw = false); /// Callback function, will be called when downstream queue notifies /// it has consumed some items. /// NOTE: this callback function is called in queue thread. void OnNotify(std::shared_ptr notify_msg); + /// Callback function, will be called when downstream queue receives + /// resend items form upstream queue. + /// NOTE: this callback function is called in queue thread. + void OnPull(std::shared_ptr pull_msg, + boost::asio::io_service &service, + std::function)> callback); + /// Send items through direct call. void Send(); @@ -151,16 +183,42 @@ class WriterQueue : public Queue { uint64_t GetPeerLastSeqId() { return peer_last_seq_id_; } + private: + /// Resend an item to peer. + /// \param item, the item object reference to ben resend. + /// \param first_seq_id, the seq id of the first item in this resend sequence. + /// \param last_seq_id, the seq id of the last item in this resend sequence. + void ResendItem(QueueItem &item, uint64_t first_seq_id, uint64_t last_seq_id); + /// Resend items to peer from start_iter iterator to watershed_iter_. + /// \param start_iter, the starting list iterator. + /// \param first_seq_id, the seq id of the first item in this resend sequence. + /// \param last_seq_id, the seq id of the last item in this resend sequence. + int ResendItems(std::list::iterator start_iter, uint64_t first_seq_id, + uint64_t last_seq_id); + /// Find the item which the message with `target_msg_id` in. If the `target_msg_id` + /// is larger than the largest message id in the queue, the `greater_callback` callback + /// will be called; If the `target_message_id` is smaller than the smallest message id + /// in the queue, the `less_callback` callback will be called; If the `target_msg_id` is + /// found in the queue, the `found_callback` callback willbe called. + /// \param target_msg_id, the target message id to be found. + void FindItem( + uint64_t target_msg_id, + std::function greater_callback, + std::function less_callback, + std::function::iterator, uint64_t, uint64_t)> equal_callback); + private: ActorID actor_id_; ActorID peer_actor_id_; + uint64_t seq_id_; uint64_t eviction_limit_; uint64_t min_consumed_id_; uint64_t peer_last_msg_id_; uint64_t peer_last_seq_id_; std::shared_ptr transport_; - std::atomic is_pulling_; + std::atomic is_resending_; + bool is_upstream_first_pull_; }; /// Queue in downstream. @@ -173,12 +231,11 @@ class ReaderQueue : public Queue { /// NOTE: we do not restrict queue size of ReaderQueue ReaderQueue(const ObjectID &queue_id, const ActorID &actor_id, const ActorID &peer_actor_id, std::shared_ptr transport) - : Queue(queue_id, std::numeric_limits::max(), transport), + : Queue(actor_id, peer_actor_id, queue_id, std::numeric_limits::max(), transport), actor_id_(actor_id), peer_actor_id_(peer_actor_id), min_consumed_id_(QUEUE_INVALID_SEQ_ID), last_recv_seq_id_(QUEUE_INVALID_SEQ_ID), - expect_seq_id_(1), transport_(transport) {} /// Delete processed items whose seq id <= seq_id, @@ -186,13 +243,15 @@ class ReaderQueue : public Queue { void OnConsumed(uint64_t seq_id); void OnData(QueueItem &item); + /// Callback function, will be called when PullPeer DATA comes. + /// TODO: can be combined with OnData + /// NOTE: this callback function is called in queue thread. + void OnResendData(std::shared_ptr msg); uint64_t GetMinConsumedSeqID() { return min_consumed_id_; } uint64_t GetLastRecvSeqId() { return last_recv_seq_id_; } - void SetExpectSeqId(uint64_t expect) { expect_seq_id_ = expect; } - private: void Notify(uint64_t seq_id); void CreateNotifyTask(uint64_t seq_id, std::vector &task_args); @@ -202,7 +261,6 @@ class ReaderQueue : public Queue { ActorID peer_actor_id_; uint64_t min_consumed_id_; uint64_t last_recv_seq_id_; - uint64_t expect_seq_id_; std::shared_ptr promise_for_pull_; std::shared_ptr transport_; }; diff --git a/streaming/src/queue/queue_handler.cc b/streaming/src/queue/queue_handler.cc index 012442566..52cf48d78 100644 --- a/streaming/src/queue/queue_handler.cc +++ b/streaming/src/queue/queue_handler.cc @@ -39,6 +39,15 @@ std::shared_ptr QueueMessageHandler::ParseMessage( case queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType: message = CheckRspMessage::FromBytes(bytes); break; + case queue::protobuf::StreamingQueuePullRequestMsgType: + message = PullRequestMessage::FromBytes(bytes); + break; + case queue::protobuf::StreamingQueuePullResponseMsgType: + message = PullResponseMessage::FromBytes(bytes); + break; + case queue::protobuf::StreamingQueueResendDataMsgType: + message = ResendDataMessage::FromBytes(bytes); + break; default: STREAMING_CHECK(false) << "nonsupport message type: " << queue::protobuf::StreamingQueueMessageType_Name(*type); @@ -109,6 +118,21 @@ void QueueMessageHandler::Stop() { } } +void UpstreamQueueMessageHandler::Start() { + STREAMING_LOG(INFO) << "UpstreamQueueMessageHandler::Start"; + QueueMessageHandler::Start(); + handle_service_thread_ = std::thread([this] { handler_service_.run(); }); +} + +void UpstreamQueueMessageHandler::Stop() { + STREAMING_LOG(INFO) << "UpstreamQueueMessageHandler::Stop"; + handler_service_.stop(); + if (handle_service_thread_.joinable()) { + handle_service_thread_.join(); + } + QueueMessageHandler::Stop(); +} + std::shared_ptr UpstreamQueueMessageHandler::CreateService( const ActorID &actor_id) { if (nullptr == upstream_handler_) { @@ -203,7 +227,7 @@ void UpstreamQueueMessageHandler::DispatchMessageInternal( std::shared_ptr buffer, std::function)> callback) { std::shared_ptr msg = ParseMessage(buffer); - STREAMING_LOG(DEBUG) << "QueueMessageHandler::DispatchMessageInternal: " + STREAMING_LOG(DEBUG) << "UpstreamQueueMessageHandler::DispatchMessageInternal: " << " qid: " << msg->QueueId() << " actorid " << msg->ActorId() << " peer actorid: " << msg->PeerActorId() << " type: " << queue::protobuf::StreamingQueueMessageType_Name(msg->Type()); @@ -214,6 +238,13 @@ void UpstreamQueueMessageHandler::DispatchMessageInternal( } else if (msg->Type() == queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType) { STREAMING_CHECK(false) << "Should not receive StreamingQueueCheckRspMsg"; + } else if (msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueuePullRequestMsgType) { + STREAMING_CHECK(callback) << "StreamingQueuePullRequestMsg " + << " qid: " << msg->QueueId() << " actorid " + << msg->ActorId() + << " peer actorid: " << msg->PeerActorId(); + OnPullRequest(std::dynamic_pointer_cast(msg), callback); } else { STREAMING_CHECK(false) << "message type should be added: " << queue::protobuf::StreamingQueueMessageType_Name( @@ -235,6 +266,25 @@ void UpstreamQueueMessageHandler::OnNotify( queue->OnNotify(notify_msg); } +void UpstreamQueueMessageHandler::OnPullRequest( + std::shared_ptr pull_msg, + std::function)> callback) { + STREAMING_LOG(INFO) << "OnPullRequest"; + auto queue = upstream_queues_.find(pull_msg->QueueId()); + if (queue == upstream_queues_.end()) { + STREAMING_LOG(INFO) << "Can not find queue " << pull_msg->QueueId(); + PullResponseMessage msg(pull_msg->PeerActorId(), pull_msg->ActorId(), + pull_msg->QueueId(), QUEUE_INVALID_SEQ_ID, + QUEUE_INVALID_SEQ_ID, + queue::protobuf::StreamingQueueError::QUEUE_NOT_EXIST, false); + std::unique_ptr buffer = msg.ToBytes(); + callback(std::move(buffer)); + return; + } + + queue->second->OnPull(pull_msg, handler_service_, callback); +} + void UpstreamQueueMessageHandler::ReleaseAllUpQueues() { STREAMING_LOG(INFO) << "ReleaseAllUpQueues"; upstream_queues_.clear(); @@ -244,6 +294,8 @@ void UpstreamQueueMessageHandler::ReleaseAllUpQueues() { std::shared_ptr DownstreamQueueMessageHandler::CreateService(const ActorID &actor_id) { if (nullptr == downstream_handler_) { + STREAMING_LOG(INFO) << "DownstreamQueueMessageHandler::CreateService " + << " actorid: " << actor_id; downstream_handler_ = std::make_shared(actor_id); } return downstream_handler_; @@ -275,6 +327,25 @@ std::shared_ptr DownstreamQueueMessageHandler::CreateDownstreamQueu return queue; } +StreamingQueueStatus DownstreamQueueMessageHandler::PullQueue( + const ObjectID &queue_id, uint64_t start_msg_id, bool &is_upstream_first_pull, + uint64_t timeout_ms) { + STREAMING_LOG(INFO) << "PullQueue queue_id: " + << queue_id + << " start_msg_id: " << start_msg_id + << " is_upstream_first_pull: " << is_upstream_first_pull; + uint64_t start_time = current_time_ms(); + uint64_t current_time = start_time; + StreamingQueueStatus st = StreamingQueueStatus::OK; + while (current_time < start_time + timeout_ms && + (st = PullPeerAsync(queue_id, start_msg_id, is_upstream_first_pull, + timeout_ms)) == StreamingQueueStatus::Timeout) { + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + current_time = current_time_ms(); + } + return st; +} + std::shared_ptr DownstreamQueueMessageHandler::GetDownQueue( const ObjectID &queue_id) { auto it = downstream_queues_.find(queue_id); @@ -311,7 +382,7 @@ void DownstreamQueueMessageHandler::DispatchMessageInternal( std::shared_ptr buffer, std::function)> callback) { std::shared_ptr msg = ParseMessage(buffer); - STREAMING_LOG(DEBUG) << "QueueMessageHandler::DispatchMessageInternal: " + STREAMING_LOG(DEBUG) << "DownstreamQueueMessageHandler::DispatchMessageInternal: " << " qid: " << msg->QueueId() << " actorid " << msg->ActorId() << " peer actorid: " << msg->PeerActorId() << " type: " << queue::protobuf::StreamingQueueMessageType_Name(msg->Type()); @@ -326,6 +397,22 @@ void DownstreamQueueMessageHandler::DispatchMessageInternal( if (callback != nullptr) { callback(check_result); } + } else if (msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueueResendDataMsgType) { + auto queue = downstream_queues_.find(msg->QueueId()); + if (queue == downstream_queues_.end()) { + std::shared_ptr data_msg = + std::dynamic_pointer_cast(msg); + STREAMING_LOG(DEBUG) << "Can not find queue for " + << queue::protobuf::StreamingQueueMessageType_Name(msg->Type()) + << ", maybe queue has been destroyed, ignore it." + << " seq id: " << data_msg->SeqId(); + return; + } + std::shared_ptr resend_data_msg = + std::dynamic_pointer_cast(msg); + + queue->second->OnResendData(resend_data_msg); } else { STREAMING_CHECK(false) << "message type should be added: " << queue::protobuf::StreamingQueueMessageType_Name( @@ -347,5 +434,53 @@ void DownstreamQueueMessageHandler::OnData(std::shared_ptr msg) { queue->OnData(item); } +StreamingQueueStatus DownstreamQueueMessageHandler::PullPeerAsync( + const ObjectID &queue_id, uint64_t start_msg_id, bool &is_upstream_first_pull, + uint64_t timeout_ms) { + STREAMING_LOG(INFO) << "PullPeerAsync queue_id: " << queue_id + << " start_msg_id: " << start_msg_id; + auto queue = GetDownQueue(queue_id); + STREAMING_CHECK(queue != nullptr); + STREAMING_LOG(INFO) << "PullPeerAsync " + << " actorid: " << queue->GetActorID(); + PullRequestMessage msg(queue->GetActorID(), queue->GetPeerActorID(), queue_id, + start_msg_id); + std::unique_ptr buffer = msg.ToBytes(); + + auto transport_it = GetOutTransport(queue_id); + STREAMING_CHECK(transport_it != nullptr); + std::shared_ptr result_buffer = + transport_it->SendForResultWithRetry(std::move(buffer), 1, timeout_ms); + if (result_buffer == nullptr) { + return StreamingQueueStatus::Timeout; + } + + std::shared_ptr result_msg = ParseMessage(result_buffer); + STREAMING_CHECK(result_msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueuePullResponseMsgType); + std::shared_ptr response_msg = + std::dynamic_pointer_cast(result_msg); + + STREAMING_LOG(INFO) << "PullPeerAsync error: " + << queue::protobuf::StreamingQueueError_Name( + response_msg->Error()) + << " start_msg_id: " << start_msg_id; + + is_upstream_first_pull = response_msg->IsUpstreamFirstPull(); + if (response_msg->Error() == queue::protobuf::StreamingQueueError::OK) { + STREAMING_LOG(INFO) << "Set queue " << queue_id << " expect_seq_id to " + << response_msg->SeqId(); + return StreamingQueueStatus::OK; + } else if (response_msg->Error() == + queue::protobuf::StreamingQueueError::DATA_LOST) { + return StreamingQueueStatus::DataLost; + } else if (response_msg->Error() == + queue::protobuf::StreamingQueueError::NO_VALID_DATA) { + return StreamingQueueStatus::NoValidData; + } else { // QUEUE_NOT_EXIST + return StreamingQueueStatus::Timeout; + } +} + } // namespace streaming } // namespace ray \ No newline at end of file diff --git a/streaming/src/queue/queue_handler.h b/streaming/src/queue/queue_handler.h index f05d71f05..b1d68ba18 100644 --- a/streaming/src/queue/queue_handler.h +++ b/streaming/src/queue/queue_handler.h @@ -11,6 +11,21 @@ namespace ray { namespace streaming { +enum class StreamingQueueStatus : uint32_t { + OK = 0, + Timeout = 1, + DataLost = 2, // The data in upstream has been evicted when downstream try to pull data + // from upstream. + NoValidData = 3, // There is no data written into queue, or start_msg_id is bigger than + // all items in queue now. +}; + +static inline std::ostream &operator<<(std::ostream &os, + const StreamingQueueStatus &status) { + os << static_cast::type>(status); + return os; +} + /// Base class of UpstreamQueueMessageHandler and DownstreamQueueMessageHandler. /// A queue service manages a group of queues, upstream queues or downstream queues of /// the current actor. Each queue service holds a boost.asio io_service, to handle @@ -26,7 +41,6 @@ class QueueMessageHandler { /// \param[in] actor_id actor id of current actor. QueueMessageHandler(const ActorID &actor_id) : actor_id_(actor_id), queue_dummy_work_(queue_service_) { - Start(); } virtual ~QueueMessageHandler() { Stop(); } @@ -71,11 +85,11 @@ class QueueMessageHandler { /// Release all queues in current queue service. void Release(); - private: + protected: /// Start asio service - void Start(); + virtual void Start(); /// Stop asio service - void Stop(); + virtual void Stop(); /// The callback function of internal thread. void QueueThreadCallback() { queue_service_.run(); } @@ -102,7 +116,10 @@ class QueueMessageHandler { class UpstreamQueueMessageHandler : public QueueMessageHandler { public: /// Construct a UpstreamQueueMessageHandler instance. - UpstreamQueueMessageHandler(const ActorID &actor_id) : QueueMessageHandler(actor_id) {} + UpstreamQueueMessageHandler(const ActorID &actor_id) : QueueMessageHandler(actor_id), + handler_service_dummy_worker_(handler_service_) { + Start(); + } /// Create a upstream queue. /// \param[in] queue_id queue id of the queue to be created. /// \param[in] peer_actor_id actor id of peer actor. @@ -120,6 +137,9 @@ class UpstreamQueueMessageHandler : public QueueMessageHandler { std::vector &failed_queues); /// Handle notify message from corresponded downstream queue. void OnNotify(std::shared_ptr notify_msg); + /// Handle pull request message from corresponded downstream queue. + void OnPullRequest(std::shared_ptr pull_msg, + std::function)> callback); /// Obtain upstream queue specified by queue_id. std::shared_ptr GetUpQueue(const ObjectID &queue_id); /// Release all upstream queues @@ -132,41 +152,58 @@ class UpstreamQueueMessageHandler : public QueueMessageHandler { static std::shared_ptr CreateService( const ActorID &actor_id); static std::shared_ptr GetService(); - + virtual void Start() override; private: bool CheckQueueSync(const ObjectID &queue_ids); + virtual void Stop() override; private: std::unordered_map> upstream_queues_; static std::shared_ptr upstream_handler_; + boost::asio::io_service handler_service_; + boost::asio::io_service::work handler_service_dummy_worker_; + std::thread handle_service_thread_; }; -/// UpstreamQueueMessageHandler holds and manages all downstream queues of current actor. +/// DownstreamQueueMessageHandler holds and manages all downstream queues of current actor. class DownstreamQueueMessageHandler : public QueueMessageHandler { public: DownstreamQueueMessageHandler(const ActorID &actor_id) - : QueueMessageHandler(actor_id) {} + : QueueMessageHandler(actor_id) { + Start(); + } + /// Create a downstream queue. + /// \param queue_id, queue id of the queue to be created. + /// \param peer_actor_id, actor id of peer actor. std::shared_ptr CreateDownstreamQueue(const ObjectID &queue_id, const ActorID &peer_actor_id); + /// Request to pull messages from corresponded upstream queue, whose message id + /// is larger than `start_msg_id`. Multiple attempts to pull until timeout. + /// \param queue_id, queue id of the queue to be pulled. + /// \param start_msg_id, the starting message id reqeust by downstream queue. + /// \param is_upstream_first_pull + /// \param timeout_ms, the maxmium timeout. + StreamingQueueStatus PullQueue(const ObjectID &queue_id, uint64_t start_msg_id, + bool &is_upstream_first_pull, + uint64_t timeout_ms = 2000); + /// Check whether the downstream queue specified by queue_id exists or not. bool DownstreamQueueExists(const ObjectID &queue_id); - - void UpdateDownActor(const ObjectID &queue_id, const ActorID &actor_id); - std::shared_ptr OnCheckQueue( std::shared_ptr check_msg); - + /// Obtain downstream queue specified by queue_id. std::shared_ptr GetDownQueue(const ObjectID &queue_id); - + /// Release all downstream queues void ReleaseAllDownQueues(); - + /// The callback function called when downstream queue receives a queue item. void OnData(std::shared_ptr msg); virtual void DispatchMessageInternal( std::shared_ptr buffer, std::function)> callback); - static std::shared_ptr CreateService( const ActorID &actor_id); static std::shared_ptr GetService(); + StreamingQueueStatus PullPeerAsync(const ObjectID &queue_id, uint64_t start_msg_id, + bool &is_upstream_first_pull, uint64_t timeout_ms); private: std::unordered_map> diff --git a/streaming/src/queue/queue_item.h b/streaming/src/queue/queue_item.h index 9265c5ef9..e01928442 100644 --- a/streaming/src/queue/queue_item.h +++ b/streaming/src/queue/queue_item.h @@ -15,6 +15,7 @@ namespace streaming { using ray::ObjectID; const uint64_t QUEUE_INVALID_SEQ_ID = std::numeric_limits::max(); +const uint64_t QUEUE_INITIAL_SEQ_ID = 1; /// QueueItem is the element stored in `Queue`. Actually, when DataWriter pushes a message /// bundle into a queue, the bundle is packed into one QueueItem, so a one-to-one @@ -31,24 +32,28 @@ class QueueItem { /// \param[in] timestamp the time when this QueueItem created. /// \param[in] raw whether the data content is raw bytes, only used in some tests. QueueItem(uint64_t seq_id, uint8_t *data, uint32_t data_size, uint64_t timestamp, - bool raw = false) - : seq_id_(seq_id), + uint64_t msg_id_start, uint64_t msg_id_end, bool raw = false) + : seq_id_(seq_id), msg_id_start_(msg_id_start), msg_id_end_(msg_id_end), timestamp_(timestamp), raw_(raw), /*COPY*/ buffer_(std::make_shared(data, data_size, true)) {} QueueItem(uint64_t seq_id, std::shared_ptr buffer, - uint64_t timestamp, bool raw = false) - : seq_id_(seq_id), timestamp_(timestamp), raw_(raw), buffer_(buffer) {} + uint64_t timestamp, uint64_t msg_id_start, uint64_t msg_id_end, bool raw = false) + : seq_id_(seq_id), msg_id_start_(msg_id_start), msg_id_end_(msg_id_end), timestamp_(timestamp), raw_(raw), buffer_(buffer) {} QueueItem(std::shared_ptr data_msg) : seq_id_(data_msg->SeqId()), + msg_id_start_(data_msg->MsgIdStart()), + msg_id_end_(data_msg->MsgIdEnd()), raw_(data_msg->IsRaw()), buffer_(data_msg->Buffer()) {} QueueItem(const QueueItem &&item) { buffer_ = item.buffer_; seq_id_ = item.seq_id_; + msg_id_start_ = item.msg_id_start_; + msg_id_end_ = item.msg_id_end_; timestamp_ = item.timestamp_; raw_ = item.raw_; } @@ -56,6 +61,8 @@ class QueueItem { QueueItem(const QueueItem &item) { buffer_ = item.buffer_; seq_id_ = item.seq_id_; + msg_id_start_ = item.msg_id_start_; + msg_id_end_ = item.msg_id_end_; timestamp_ = item.timestamp_; raw_ = item.raw_; } @@ -63,6 +70,8 @@ class QueueItem { QueueItem &operator=(const QueueItem &item) { buffer_ = item.buffer_; seq_id_ = item.seq_id_; + msg_id_start_ = item.msg_id_start_; + msg_id_end_ = item.msg_id_end_; timestamp_ = item.timestamp_; raw_ = item.raw_; return *this; @@ -70,11 +79,16 @@ class QueueItem { virtual ~QueueItem() = default; - uint64_t SeqId() { return seq_id_; } - bool IsRaw() { return raw_; } - uint64_t TimeStamp() { return timestamp_; } - size_t DataSize() { return buffer_->Size(); } - std::shared_ptr Buffer() { return buffer_; } + inline uint64_t SeqId() { return seq_id_; } + inline uint64_t MsgIdStart() { return msg_id_start_; } + inline uint64_t MsgIdEnd() { return msg_id_end_; } + inline bool InItem(uint64_t msg_id) { + return msg_id >= msg_id_start_ && msg_id <= msg_id_end_; + } + inline bool IsRaw() { return raw_; } + inline uint64_t TimeStamp() { return timestamp_; } + inline size_t DataSize() { return buffer_->Size(); } + inline std::shared_ptr Buffer() { return buffer_; } /// Get max message id in this item. /// \return max message id. @@ -88,6 +102,8 @@ class QueueItem { protected: uint64_t seq_id_; + uint64_t msg_id_start_; + uint64_t msg_id_end_; uint64_t timestamp_; bool raw_; @@ -96,7 +112,8 @@ class QueueItem { class InvalidQueueItem : public QueueItem { public: - InvalidQueueItem() : QueueItem(QUEUE_INVALID_SEQ_ID, data_, 1, 0) {} + InvalidQueueItem() : QueueItem(QUEUE_INVALID_SEQ_ID, data_, 1, 0, QUEUE_INVALID_SEQ_ID, + QUEUE_INVALID_SEQ_ID) {} private: uint8_t data_[1]; diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc index 4055bb20f..ce9b1bc86 100644 --- a/streaming/src/test/mock_actor.cc +++ b/streaming/src/test/mock_actor.cc @@ -257,6 +257,165 @@ class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite { } }; +class StreamingQueueUpStreamTestSuite : public StreamingQueueTestSuite { + public: + StreamingQueueUpStreamTestSuite(ActorID &peer_actor_id, std::vector queue_ids, + std::vector rescale_queue_ids) + : StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) { + test_func_map_ = { + {"pull_peer_async_test", + std::bind(&StreamingQueueUpStreamTestSuite::PullPeerAsyncTest, this)}, + {"get_queue_test", + std::bind(&StreamingQueueUpStreamTestSuite::GetQueueTest, this)}}; + } + + void GetQueueTest() { + // Sleep 2s, queue shoulde not exist when reader pull + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService(); + ObjectID &queue_id = queue_ids_[0]; + RayFunction async_call_func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector(ray::Language::PYTHON, {"", "", "reader_async_call_func", ""})}; + RayFunction sync_call_func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector(ray::Language::PYTHON, {"", "", "reader_sync_call_func", ""})}; + upstream_handler->SetPeerActorID(queue_id, peer_actor_id_, async_call_func, sync_call_func); + upstream_handler->CreateUpstreamQueue(queue_id, peer_actor_id_, 10240); + STREAMING_LOG(INFO) << "IsQueueExist: " + << upstream_handler->UpstreamQueueExists(queue_id); + + // Sleep 2s, No valid data when reader pull + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + + std::this_thread::sleep_for(std::chrono::milliseconds(10 * 1000)); + STREAMING_LOG(INFO) << "StreamingQueueUpStreamTestSuite::GetQueueTest done"; + status_ = true; + } + + void PullPeerAsyncTest() { + // Sleep 2s, queue should not exist when reader pull + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService(); + ObjectID &queue_id = queue_ids_[0]; + RayFunction async_call_func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector(ray::Language::PYTHON, {"", "", "reader_async_call_func", ""})}; + RayFunction sync_call_func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector(ray::Language::PYTHON, {"", "", "reader_sync_call_func", ""})}; + upstream_handler->SetPeerActorID(queue_id, peer_actor_id_, async_call_func, sync_call_func); + std::shared_ptr queue = + upstream_handler->CreateUpstreamQueue(queue_id, peer_actor_id_, 10240); + STREAMING_LOG(INFO) << "IsQueueExist: " + << upstream_handler->UpstreamQueueExists(queue_id); + + // Sleep 2s, No valid data when reader pull + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + // message id starts from 1 + for (int msg_id = 1; msg_id <= 80; msg_id++) { + uint8_t data[100]; + memset(data, msg_id, 100); + STREAMING_LOG(INFO) << "Writer User Push item msg_id: " << msg_id; + ASSERT_TRUE( + queue->Push(msg_id/*seqid*/, data, 100, current_sys_time_ms(), msg_id, msg_id, true).ok()); + queue->Send(); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + STREAMING_LOG(INFO) << "StreamingQueueUpStreamTestSuite::PullPeerAsyncTest done"; + status_ = true; + } +}; + +class StreamingQueueDownStreamTestSuite : public StreamingQueueTestSuite { + public: + StreamingQueueDownStreamTestSuite(ActorID peer_actor_id, std::vector queue_ids, + std::vector rescale_queue_ids) + : StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) { + test_func_map_ = { + {"pull_peer_async_test", + std::bind(&StreamingQueueDownStreamTestSuite::PullPeerAsyncTest, this)}, + {"get_queue_test", + std::bind(&StreamingQueueDownStreamTestSuite::GetQueueTest, this)}}; + }; + + void GetQueueTest() { + auto downstream_handler = + ray::streaming::DownstreamQueueMessageHandler::GetService(); + ObjectID &queue_id = queue_ids_[0]; + RayFunction async_call_func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector(ray::Language::PYTHON, {"", "", "writer_async_call_func", ""})}; + RayFunction sync_call_func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector(ray::Language::PYTHON, {"", "", "writer_sync_call_func", ""})}; + downstream_handler->SetPeerActorID(queue_id, peer_actor_id_, async_call_func, sync_call_func); + downstream_handler->CreateDownstreamQueue(queue_id, peer_actor_id_); + + bool is_upstream_first_pull_ = false; + downstream_handler->PullQueue(queue_id, 1, is_upstream_first_pull_, 10 * 1000); + ASSERT_TRUE(is_upstream_first_pull_); + downstream_handler->PullQueue(queue_id, 1, is_upstream_first_pull_, 10 * 1000); + ASSERT_FALSE(is_upstream_first_pull_); + STREAMING_LOG(INFO) << "StreamingQueueDownStreamTestSuite::GetQueueTest done"; + status_ = true; + } + + void PullPeerAsyncTest() { + auto downstream_handler = + ray::streaming::DownstreamQueueMessageHandler::GetService(); + ObjectID &queue_id = queue_ids_[0]; + RayFunction async_call_func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector(ray::Language::PYTHON, {"", "", "writer_async_call_func", ""})}; + RayFunction sync_call_func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector(ray::Language::PYTHON, {"", "", "writer_sync_call_func", ""})}; + downstream_handler->SetPeerActorID(queue_id, peer_actor_id_, async_call_func, sync_call_func); + std::shared_ptr queue = + downstream_handler->CreateDownstreamQueue(queue_id, peer_actor_id_); + + bool is_first_pull; + downstream_handler->PullQueue(queue_id, 1, is_first_pull, 10 * 1000); + uint64_t count = 0; + uint8_t msg_id = 1; + while (true) { + uint8_t *data = nullptr; + uint32_t data_size = 0; + uint64_t timeout_ms = 1000; + QueueItem item = queue->PopPendingBlockTimeout(timeout_ms * 1000); + if (item.SeqId() == QUEUE_INVALID_SEQ_ID) { + STREAMING_LOG(INFO) << "PopPendingBlockTimeout timeout."; + data = nullptr; + data_size = 0; + } else { + data = item.Buffer()->Data(); + data_size = item.Buffer()->Size(); + } + + STREAMING_LOG(INFO) << "[Reader] count: " << count; + if (data == nullptr) { + STREAMING_LOG(INFO) << "[Reader] data null"; + continue; + } + + for (uint32_t i = 0; i < data_size; i++) { + ASSERT_EQ(data[i], msg_id); + } + + count++; + if (count == 80) { + bool is_upstream_first_pull; + msg_id = 50; + downstream_handler->PullPeerAsync(queue_id, 50, is_upstream_first_pull, 1000); + continue; + } + + msg_id++; + STREAMING_LOG(INFO) << "[Reader] count: " << count; + if (count == 110) { + break; + } + } + + STREAMING_LOG(INFO) << "StreamingQueueDownStreamTestSuite::PullPeerAsyncTest done"; + status_ = true; + } +}; + class TestSuiteFactory { public: static std::shared_ptr CreateTestSuite( @@ -272,6 +431,9 @@ class TestSuiteFactory { if (suite_name == "StreamingWriterTest") { test_suite = std::make_shared( peer_actor_id, queue_ids, rescale_queue_ids); + } else if (suite_name == "StreamingQueueTest") { + test_suite = std::make_shared( + peer_actor_id, queue_ids, rescale_queue_ids); } else { STREAMING_CHECK(false) << "unsurported suite_name: " << suite_name; } @@ -279,6 +441,9 @@ class TestSuiteFactory { if (suite_name == "StreamingWriterTest") { test_suite = std::make_shared( peer_actor_id, queue_ids, rescale_queue_ids); + } else if (suite_name == "StreamingQueueTest") { + test_suite = std::make_shared( + peer_actor_id, queue_ids, rescale_queue_ids); } else { STREAMING_CHECK(false) << "unsupported suite_name: " << suite_name; } @@ -323,9 +488,6 @@ class StreamingWorker { -1, // metrics_agent_port }; CoreWorkerProcess::Initialize(options); - - reader_client_ = std::make_shared(); - writer_client_ = std::make_shared(); STREAMING_LOG(INFO) << "StreamingWorker constructor"; } @@ -348,7 +510,7 @@ class StreamingWorker { RAY_CHECK(function_descriptor->Type() == ray::FunctionDescriptorType::kPythonFunctionDescriptor); auto typed_descriptor = function_descriptor->As(); - STREAMING_LOG(INFO) << "StreamingWorker::ExecuteTask " + STREAMING_LOG(DEBUG) << "StreamingWorker::ExecuteTask " << typed_descriptor->ToString(); std::string func_name = typed_descriptor->FunctionName(); @@ -412,6 +574,8 @@ class StreamingWorker { private: void HandleInitTask(std::shared_ptr buffer) { + reader_client_ = std::make_shared(); + writer_client_ = std::make_shared(); uint8_t *bytes = buffer->Data(); uint8_t *p_cur = bytes; uint32_t *magic_num = (uint32_t *)p_cur; @@ -425,17 +589,12 @@ class StreamingWorker { queue::protobuf::StreamingQueueMessageType::StreamingQueueTestInitMsgType); std::shared_ptr message = TestInitMessage::FromBytes(bytes); - STREAMING_LOG(INFO) << "Init message: " << message->ToString(); std::string actor_handle_serialized = message->ActorHandleSerialized(); CoreWorkerProcess::GetCoreWorker().DeserializeAndRegisterActorHandle( actor_handle_serialized, ObjectID::Nil()); std::shared_ptr actor_handle(new ActorHandle(actor_handle_serialized)); STREAMING_CHECK(actor_handle != nullptr); - STREAMING_LOG(INFO) << " actor id from handle: " << actor_handle->GetActorID(); - - // STREAMING_LOG(INFO) << "actor_handle_serialized: " << actor_handle_serialized; - // peer_actor_handle_ = - // std::make_shared(actor_handle_serialized); + STREAMING_LOG(INFO) << "Actor id from handle: " << actor_handle->GetActorID(); STREAMING_LOG(INFO) << "HandleInitTask queues:"; for (auto qid : message->QueueIds()) { diff --git a/streaming/src/test/queue_protobuf_tests.cc b/streaming/src/test/queue_protobuf_tests.cc new file mode 100644 index 000000000..2b9bcadb0 --- /dev/null +++ b/streaming/src/test/queue_protobuf_tests.cc @@ -0,0 +1,29 @@ +#include +#include +#include "gtest/gtest.h" + +#include "queue/message.h" +using namespace ray; +using namespace ray::streaming; + +TEST(ProtoBufTest, MessageCommonTest) { + JobID job_id = JobID::FromInt(0); + TaskID task_id = TaskID::ForDriverTask(job_id); + ray::ActorID actor_id = ray::ActorID::Of(job_id, task_id, 0); + ray::ActorID peer_actor_id = ray::ActorID::Of(job_id, task_id, 1); + ObjectID queue_id = ray::ObjectID::FromRandom(); + + uint8_t data[128]; + std::shared_ptr buffer = std::make_shared(data, 128, true); + DataMessage msg(actor_id, peer_actor_id, queue_id, 100, 1000, 2000, buffer, true); + std::unique_ptr serilized_buffer = msg.ToBytes(); + std::shared_ptr msg2 = DataMessage::FromBytes(serilized_buffer->Data()); + EXPECT_EQ(msg.ActorId(), msg2->ActorId()); + EXPECT_EQ(msg.PeerActorId(), msg2->PeerActorId()); + EXPECT_EQ(msg.QueueId(), msg2->QueueId()); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/streaming_queue_tests.cc b/streaming/src/test/streaming_queue_tests.cc index 1339d45f5..6b95b1495 100644 --- a/streaming/src/test/streaming_queue_tests.cc +++ b/streaming/src/test/streaming_queue_tests.cc @@ -16,6 +16,12 @@ namespace streaming { static int node_manager_port; +class StreamingQueueTest : public StreamingQueueTestBase { + public: + StreamingQueueTest() + : StreamingQueueTestBase(1, node_manager_port) {} +}; + class StreamingWriterTest : public StreamingQueueTestBase { public: StreamingWriterTest() : StreamingQueueTestBase(1, node_manager_port) {} @@ -26,6 +32,20 @@ class StreamingExactlySameTest : public StreamingQueueTestBase { StreamingExactlySameTest() : StreamingQueueTestBase(1, node_manager_port) {} }; +TEST_P(StreamingQueueTest, PullPeerAsyncTest) { + STREAMING_LOG(INFO) << "StreamingQueueTest.pull_peer_async_test"; + + uint32_t queue_num = 1; + SubmitTest(queue_num, "StreamingQueueTest", "pull_peer_async_test", 60 * 1000); +} + +TEST_P(StreamingQueueTest, GetQueueTest) { + STREAMING_LOG(INFO) << "StreamingQueueTest.get_queue_test"; + + uint32_t queue_num = 1; + SubmitTest(queue_num, "StreamingQueueTest", "get_queue_test", 60 * 1000); +} + TEST_P(StreamingWriterTest, streaming_writer_exactly_once_test) { STREAMING_LOG(INFO) << "StreamingWriterTest.streaming_writer_exactly_once_test"; @@ -36,6 +56,8 @@ TEST_P(StreamingWriterTest, streaming_writer_exactly_once_test) { 60 * 1000); } +INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingQueueTest, testing::Values(0)); + INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingWriterTest, testing::Values(0)); INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingExactlySameTest,