[GCS]GCS adapts to object table pub sub (#8180)

This commit is contained in:
fangfengbin
2020-05-03 21:44:33 +08:00
committed by GitHub
parent 166bb5d690
commit b7bbc3bc83
17 changed files with 144 additions and 61 deletions
+1 -3
View File
@@ -325,10 +325,8 @@ class ObjectInfoAccessor {
/// Cancel subscription to any update of an object's location.
///
/// \param object_id The ID of the object to be unsubscribed to.
/// \param done Callback that will be called when unsubscription is complete.
/// \return Status
virtual Status AsyncUnsubscribeToLocations(const ObjectID &object_id,
const StatusCallback &done) = 0;
virtual Status AsyncUnsubscribeToLocations(const ObjectID &object_id) = 0;
protected:
ObjectInfoAccessor() = default;
@@ -715,9 +715,7 @@ Status ServiceBasedTaskInfoAccessor::AttemptTaskReconstruction(
ServiceBasedObjectInfoAccessor::ServiceBasedObjectInfoAccessor(
ServiceBasedGcsClient *client_impl)
: client_impl_(client_impl),
subscribe_id_(ClientID::FromRandom()),
object_sub_executor_(client_impl->GetRedisGcsClient().object_table()) {}
: client_impl_(client_impl) {}
Status ServiceBasedObjectInfoAccessor::AsyncGetLocations(
const ObjectID &object_id, const MultiItemCallback<rpc::ObjectTableData> &callback) {
@@ -749,7 +747,7 @@ Status ServiceBasedObjectInfoAccessor::AsyncAddLocation(const ObjectID &object_i
request.set_node_id(node_id.Binary());
auto operation = [this, request, object_id, node_id,
callback](SequencerDoneCallback done_callback) {
callback](const SequencerDoneCallback &done_callback) {
client_impl_->GetGcsRpcClient().AddObjectLocation(
request, [object_id, node_id, callback, done_callback](
const Status &status, const rpc::AddObjectLocationReply &reply) {
@@ -776,7 +774,7 @@ Status ServiceBasedObjectInfoAccessor::AsyncRemoveLocation(
request.set_node_id(node_id.Binary());
auto operation = [this, request, object_id, node_id,
callback](SequencerDoneCallback done_callback) {
callback](const SequencerDoneCallback &done_callback) {
client_impl_->GetGcsRpcClient().RemoveObjectLocation(
request, [object_id, node_id, callback, done_callback](
const Status &status, const rpc::RemoveObjectLocationReply &reply) {
@@ -800,16 +798,46 @@ Status ServiceBasedObjectInfoAccessor::AsyncSubscribeToLocations(
RAY_LOG(DEBUG) << "Subscribing object location, object id = " << object_id;
RAY_CHECK(subscribe != nullptr)
<< "Failed to subscribe object location, object id = " << object_id;
auto status =
object_sub_executor_.AsyncSubscribe(subscribe_id_, object_id, subscribe, done);
auto on_subscribe = [object_id, subscribe](const std::string &id,
const std::string &data) {
rpc::ObjectLocationChange object_location_change;
object_location_change.ParseFromString(data);
std::vector<rpc::ObjectTableData> object_data_vector;
object_data_vector.emplace_back(object_location_change.data());
auto change_mode = object_location_change.is_add() ? rpc::GcsChangeMode::APPEND_OR_ADD
: rpc::GcsChangeMode::REMOVE;
gcs::ObjectChangeNotification notification(change_mode, object_data_vector);
subscribe(object_id, notification);
};
auto on_done = [this, object_id, subscribe, done](const Status &status) {
if (status.ok()) {
auto callback = [object_id, subscribe, done](
const Status &status,
const std::vector<rpc::ObjectTableData> &result) {
if (status.ok()) {
gcs::ObjectChangeNotification notification(rpc::GcsChangeMode::APPEND_OR_ADD,
result);
subscribe(object_id, notification);
}
if (done) {
done(status);
}
};
RAY_CHECK_OK(AsyncGetLocations(object_id, callback));
} else if (done) {
done(status);
}
};
auto status = client_impl_->GetGcsPubSub().Subscribe(OBJECT_CHANNEL, object_id.Hex(),
on_subscribe, on_done);
RAY_LOG(DEBUG) << "Finished subscribing object location, object id = " << object_id;
return status;
}
Status ServiceBasedObjectInfoAccessor::AsyncUnsubscribeToLocations(
const ObjectID &object_id, const StatusCallback &done) {
const ObjectID &object_id) {
RAY_LOG(DEBUG) << "Unsubscribing object location, object id = " << object_id;
auto status = object_sub_executor_.AsyncUnsubscribe(subscribe_id_, object_id, done);
auto status = client_impl_->GetGcsPubSub().Unsubscribe(OBJECT_CHANNEL, object_id.Hex());
RAY_LOG(DEBUG) << "Finished unsubscribing object location, object id = " << object_id;
return status;
}
@@ -267,18 +267,11 @@ class ServiceBasedObjectInfoAccessor : public ObjectInfoAccessor {
const SubscribeCallback<ObjectID, ObjectChangeNotification> &subscribe,
const StatusCallback &done) override;
Status AsyncUnsubscribeToLocations(const ObjectID &object_id,
const StatusCallback &done) override;
Status AsyncUnsubscribeToLocations(const ObjectID &object_id) override;
private:
ServiceBasedGcsClient *client_impl_;
ClientID subscribe_id_;
typedef SubscriptionExecutor<ObjectID, ObjectChangeNotification, ObjectTable>
ObjectSubscriptionExecutor;
ObjectSubscriptionExecutor object_sub_executor_;
Sequencer<ObjectID> sequencer_;
};
@@ -373,11 +373,9 @@ class ServiceBasedGcsClientTest : public RedisServiceManagerForTest {
return WaitReady(promise.get_future(), timeout_ms_);
}
bool UnsubscribeToLocations(const ObjectID &object_id) {
void UnsubscribeToLocations(const ObjectID &object_id) {
std::promise<bool> promise;
RAY_CHECK_OK(gcs_client_->Objects().AsyncUnsubscribeToLocations(
object_id, [&promise](Status status) { promise.set_value(status.ok()); }));
return WaitReady(promise.get_future(), timeout_ms_);
RAY_CHECK_OK(gcs_client_->Objects().AsyncUnsubscribeToLocations(object_id));
}
bool AddLocation(const ObjectID &object_id, const ClientID &node_id) {
@@ -770,7 +768,7 @@ TEST_F(ServiceBasedGcsClientTest, TestObjectInfo) {
ASSERT_TRUE(GetLocations(object_id).empty());
// Cancel subscription to any update of an object's location.
ASSERT_TRUE(UnsubscribeToLocations(object_id));
UnsubscribeToLocations(object_id);
// Add location of object to GCS again.
ASSERT_TRUE(AddLocation(object_id, node_id));
+1 -1
View File
@@ -149,7 +149,7 @@ std::unique_ptr<rpc::ActorInfoHandler> GcsServer::InitActorInfoHandler() {
std::unique_ptr<rpc::ObjectInfoHandler> GcsServer::InitObjectInfoHandler() {
return std::unique_ptr<rpc::DefaultObjectInfoHandler>(
new rpc::DefaultObjectInfoHandler(*redis_gcs_client_));
new rpc::DefaultObjectInfoHandler(*redis_gcs_client_, gcs_pub_sub_));
}
void GcsServer::StoreGcsServerAddressInRedis() {
@@ -13,6 +13,7 @@
// limitations under the License.
#include "object_info_handler_impl.h"
#include "ray/gcs/pb_util.h"
#include "ray/util/logging.h"
namespace ray {
@@ -26,11 +27,14 @@ void DefaultObjectInfoHandler::HandleGetObjectLocations(
<< ", object id = " << object_id;
auto on_done = [reply, object_id, send_reply_callback](
Status status, const std::vector<rpc::ObjectTableData> &result) {
const Status &status,
const std::vector<rpc::ObjectTableData> &result) {
if (status.ok()) {
for (const rpc::ObjectTableData &object_table_data : result) {
reply->add_object_table_data_list()->CopyFrom(object_table_data);
}
RAY_LOG(DEBUG) << "Finished getting object locations, job id = "
<< object_id.TaskId().JobId() << ", object id = " << object_id;
} else {
RAY_LOG(ERROR) << "Failed to get object locations: " << status.ToString()
<< ", job id = " << object_id.TaskId().JobId()
@@ -43,9 +47,6 @@ void DefaultObjectInfoHandler::HandleGetObjectLocations(
if (!status.ok()) {
on_done(status, std::vector<rpc::ObjectTableData>());
}
RAY_LOG(DEBUG) << "Finished getting object locations, job id = "
<< object_id.TaskId().JobId() << ", object id = " << object_id;
}
void DefaultObjectInfoHandler::HandleAddObjectLocation(
@@ -56,8 +57,16 @@ void DefaultObjectInfoHandler::HandleAddObjectLocation(
RAY_LOG(DEBUG) << "Adding object location, job id = " << object_id.TaskId().JobId()
<< ", object id = " << object_id << ", node id = " << node_id;
auto on_done = [object_id, node_id, reply, send_reply_callback](Status status) {
if (!status.ok()) {
auto on_done = [this, object_id, node_id, reply,
send_reply_callback](const Status &status) {
if (status.ok()) {
RAY_CHECK_OK(gcs_pub_sub_->Publish(
OBJECT_CHANNEL, object_id.Hex(),
gcs::CreateObjectLocationChange(node_id, true)->SerializeAsString(), nullptr));
RAY_LOG(DEBUG) << "Finished adding object location, job id = "
<< object_id.TaskId().JobId() << ", object id = " << object_id
<< ", node id = " << node_id << ", task id = " << object_id.TaskId();
} else {
RAY_LOG(ERROR) << "Failed to add object location: " << status.ToString()
<< ", job id = " << object_id.TaskId().JobId()
<< ", object id = " << object_id << ", node id = " << node_id;
@@ -69,10 +78,6 @@ void DefaultObjectInfoHandler::HandleAddObjectLocation(
if (!status.ok()) {
on_done(status);
}
RAY_LOG(DEBUG) << "Finished adding object location, job id = "
<< object_id.TaskId().JobId() << ", object id = " << object_id
<< ", node id = " << node_id;
}
void DefaultObjectInfoHandler::HandleRemoveObjectLocation(
@@ -83,8 +88,16 @@ void DefaultObjectInfoHandler::HandleRemoveObjectLocation(
RAY_LOG(DEBUG) << "Removing object location, job id = " << object_id.TaskId().JobId()
<< ", object id = " << object_id << ", node id = " << node_id;
auto on_done = [object_id, node_id, reply, send_reply_callback](Status status) {
if (!status.ok()) {
auto on_done = [this, object_id, node_id, reply,
send_reply_callback](const Status &status) {
if (status.ok()) {
RAY_CHECK_OK(gcs_pub_sub_->Publish(
OBJECT_CHANNEL, object_id.Hex(),
gcs::CreateObjectLocationChange(node_id, false)->SerializeAsString(), nullptr));
RAY_LOG(DEBUG) << "Finished removing object location, job id = "
<< object_id.TaskId().JobId() << ", object id = " << object_id
<< ", node id = " << node_id;
} else {
RAY_LOG(ERROR) << "Failed to remove object location: " << status.ToString()
<< ", job id = " << object_id.TaskId().JobId()
<< ", object id = " << object_id << ", node id = " << node_id;
@@ -96,10 +109,6 @@ void DefaultObjectInfoHandler::HandleRemoveObjectLocation(
if (!status.ok()) {
on_done(status);
}
RAY_LOG(DEBUG) << "Finished removing object location, job id = "
<< object_id.TaskId().JobId() << ", object id = " << object_id
<< ", node id = " << node_id;
}
} // namespace rpc
@@ -15,6 +15,7 @@
#ifndef RAY_GCS_OBJECT_INFO_HANDLER_IMPL_H
#define RAY_GCS_OBJECT_INFO_HANDLER_IMPL_H
#include "ray/gcs/pubsub/gcs_pub_sub.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/rpc/gcs_server/gcs_rpc_server.h"
@@ -24,8 +25,9 @@ namespace rpc {
/// This implementation class of `ObjectInfoHandler`.
class DefaultObjectInfoHandler : public rpc::ObjectInfoHandler {
public:
explicit DefaultObjectInfoHandler(gcs::RedisGcsClient &gcs_client)
: gcs_client_(gcs_client) {}
explicit DefaultObjectInfoHandler(gcs::RedisGcsClient &gcs_client,
std::shared_ptr<gcs::GcsPubSub> &gcs_pub_sub)
: gcs_client_(gcs_client), gcs_pub_sub_(gcs_pub_sub) {}
void HandleGetObjectLocations(const GetObjectLocationsRequest &request,
GetObjectLocationsReply *reply,
@@ -41,6 +43,7 @@ class DefaultObjectInfoHandler : public rpc::ObjectInfoHandler {
private:
gcs::RedisGcsClient &gcs_client_;
std::shared_ptr<gcs::GcsPubSub> gcs_pub_sub_;
};
} // namespace rpc
+15
View File
@@ -97,6 +97,21 @@ inline std::shared_ptr<ray::rpc::WorkerFailureData> CreateWorkerFailureData(
return worker_failure_info_ptr;
}
/// Helper function to produce object location change.
///
/// \param node_id The node ID that this object appeared on or was evicted by.
/// \param is_add Whether the object is appeared on the node.
/// \return The object location change created by this method.
inline std::shared_ptr<ray::rpc::ObjectLocationChange> CreateObjectLocationChange(
const ClientID &node_id, bool is_add) {
ray::rpc::ObjectTableData object_table_data;
object_table_data.set_manager(node_id.Binary());
auto object_location_change = std::make_shared<ray::rpc::ObjectLocationChange>();
object_location_change->set_is_add(is_add);
object_location_change->mutable_data()->CopyFrom(object_table_data);
return object_location_change;
}
} // namespace gcs
} // namespace ray
+5 -4
View File
@@ -53,13 +53,17 @@ Status GcsPubSub::SubscribeInternal(const std::string &channel, const Callback &
const StatusCallback &done,
const boost::optional<std::string> &id) {
std::string pattern = GenChannelPattern(channel, id);
auto callback = [this, pattern, subscribe](std::shared_ptr<CallbackReply> reply) {
auto callback = [this, pattern, done, subscribe](std::shared_ptr<CallbackReply> reply) {
if (!reply->IsNil()) {
if (reply->IsUnsubscribeCallback()) {
absl::MutexLock lock(&mutex_);
ray::gcs::RedisCallbackManager::instance().remove(
subscribe_callback_index_[pattern]);
subscribe_callback_index_.erase(pattern);
} else if (reply->IsSubscribeCallback()) {
if (done) {
done(Status::OK());
}
} else {
const auto reply_data = reply->ReadAsPubsubData();
if (!reply_data.empty()) {
@@ -78,9 +82,6 @@ Status GcsPubSub::SubscribeInternal(const std::string &channel, const Callback &
absl::MutexLock lock(&mutex_);
subscribe_callback_index_[pattern] = out_callback_index;
}
if (done) {
done(status);
}
return status;
}
+1
View File
@@ -27,6 +27,7 @@ namespace gcs {
#define JOB_CHANNEL "JOB"
#define WORKER_FAILURE_CHANNEL "WORKER_FAILURE"
#define OBJECT_CHANNEL "OBJECT"
/// \class GcsPubSub
///
@@ -177,6 +177,29 @@ TEST_F(GcsPubSubTest, TestMultithreading) {
}
}
TEST_F(GcsPubSubTest, TestPubSubWithTableData) {
std::string channel("channel");
std::string data("data");
std::vector<std::string> result;
int size = 1000;
for (int index = 0; index < size; ++index) {
ObjectID object_id = ObjectID::FromRandom();
std::promise<bool> promise;
auto done = [&promise](const Status &status) { promise.set_value(status.ok()); };
auto subscribe = [this, channel, &result](const std::string &id,
const std::string &data) {
result.push_back(data);
RAY_CHECK_OK(pub_sub_->Unsubscribe(channel, id));
};
RAY_CHECK_OK((pub_sub_->Subscribe(channel, object_id.Hex(), subscribe, done)));
WaitReady(promise.get_future(), timeout_ms_);
RAY_CHECK_OK((pub_sub_->Publish(channel, object_id.Hex(), data, nullptr)));
}
WaitPendingDone(result, size);
}
} // namespace ray
int main(int argc, char **argv) {
+2 -3
View File
@@ -502,9 +502,8 @@ Status RedisObjectInfoAccessor::AsyncSubscribeToLocations(
return object_sub_executor_.AsyncSubscribe(subscribe_id_, object_id, subscribe, done);
}
Status RedisObjectInfoAccessor::AsyncUnsubscribeToLocations(const ObjectID &object_id,
const StatusCallback &done) {
return object_sub_executor_.AsyncUnsubscribe(subscribe_id_, object_id, done);
Status RedisObjectInfoAccessor::AsyncUnsubscribeToLocations(const ObjectID &object_id) {
return object_sub_executor_.AsyncUnsubscribe(subscribe_id_, object_id, nullptr);
}
RedisNodeInfoAccessor::RedisNodeInfoAccessor(RedisGcsClient *client_impl)
+1 -2
View File
@@ -254,8 +254,7 @@ class RedisObjectInfoAccessor : public ObjectInfoAccessor {
const SubscribeCallback<ObjectID, ObjectChangeNotification> &subscribe,
const StatusCallback &done) override;
Status AsyncUnsubscribeToLocations(const ObjectID &object_id,
const StatusCallback &done) override;
Status AsyncUnsubscribeToLocations(const ObjectID &object_id) override;
private:
RedisGcsClient *client_impl_{nullptr};
+1
View File
@@ -91,6 +91,7 @@ CallbackReply::CallbackReply(redisReply *redis_reply) : reply_type_(redis_reply-
redisReply *message_type = redis_reply->element[0];
if (strcmp(message_type->str, "subscribe") == 0 ||
strcmp(message_type->str, "psubscribe") == 0) {
is_subscribe_callback_ = true;
// If the message is for the initial subscription call, return the empty
// string as a response to signify that subscription was successful.
} else if (strcmp(message_type->str, "punsubscribe") == 0) {
+3
View File
@@ -76,6 +76,8 @@ class CallbackReply {
/// \return size_t The next cursor for scan.
size_t ReadAsScanArray(std::vector<std::string> *array) const;
bool IsSubscribeCallback() const { return is_subscribe_callback_; }
bool IsUnsubscribeCallback() const { return is_unsubscribe_callback_; }
private:
@@ -101,6 +103,7 @@ class CallbackReply {
/// Represent the reply of StringArray or ScanArray.
std::vector<std::string> string_array_reply_;
bool is_subscribe_callback_ = false;
bool is_unsubscribe_callback_ = false;
/// Represent the reply of SCanArray, means the next scan cursor for scan request.
+14 -8
View File
@@ -29,19 +29,22 @@ using ray::rpc::ObjectTableData;
/// Process a notification of the object table entries and store the result in
/// node_ids. This assumes that node_ids already contains the result of the
/// object table entries up to but not including this notification.
void UpdateObjectLocations(bool is_added,
bool UpdateObjectLocations(bool is_added,
const std::vector<ObjectTableData> &location_updates,
std::shared_ptr<gcs::GcsClient> gcs_client,
std::unordered_set<ClientID> *node_ids) {
// location_updates contains the updates of locations of the object.
// with GcsChangeMode, we can determine whether the update mode is
// addition or deletion.
bool isUpdated = false;
for (const auto &object_table_data : location_updates) {
ClientID node_id = ClientID::FromBinary(object_table_data.manager());
if (is_added) {
if (is_added && 0 == node_ids->count(node_id)) {
node_ids->insert(node_id);
} else {
isUpdated = true;
} else if (!is_added && 1 == node_ids->count(node_id)) {
node_ids->erase(node_id);
isUpdated = true;
}
}
// Filter out the removed clients from the object locations.
@@ -52,6 +55,8 @@ void UpdateObjectLocations(bool is_added,
it++;
}
}
return isUpdated;
}
} // namespace
@@ -141,9 +146,11 @@ ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_i
it->second.subscribed = true;
// Update entries for this object.
UpdateObjectLocations(object_notification.IsAdded(),
object_notification.GetData(), gcs_client_,
&it->second.current_object_locations);
if (!UpdateObjectLocations(object_notification.IsAdded(),
object_notification.GetData(), gcs_client_,
&it->second.current_object_locations)) {
return;
}
// Copy the callbacks so that the callbacks can unsubscribe without interrupting
// looping over the callbacks.
auto callbacks = it->second.callbacks;
@@ -186,8 +193,7 @@ ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback
}
entry->second.callbacks.erase(callback_id);
if (entry->second.callbacks.empty()) {
status =
gcs_client_->Objects().AsyncUnsubscribeToLocations(object_id, /*done*/ nullptr);
status = gcs_client_->Objects().AsyncUnsubscribeToLocations(object_id);
listeners_.erase(entry);
}
return status;
+6
View File
@@ -311,6 +311,12 @@ message ObjectTableDataList {
repeated ObjectTableData items = 1;
}
// A notification message about one object's locations being changed.
message ObjectLocationChange {
bool is_add = 1;
ObjectTableData data = 2;
}
message PubSubMessage {
bytes id = 1;
bytes data = 2;