From 2790818c53b9b63d4e9a87b9c8d27ab00cd39def Mon Sep 17 00:00:00 2001 From: fangfengbin <869218239a@zju.edu.cn> Date: Mon, 27 Jul 2020 13:58:39 +0800 Subject: [PATCH] [GCS]GCS client support multi-thread subscribe&resubscribe&unsubscribe (#9718) --- src/ray/core_worker/actor_manager.cc | 7 +-- src/ray/core_worker/actor_manager.h | 5 +- .../gcs/gcs_client/service_based_accessor.cc | 36 ++++++++++++--- .../gcs/gcs_client/service_based_accessor.h | 19 ++++++-- .../test/service_based_gcs_client_test.cc | 46 +++++++++++++++++++ 5 files changed, 94 insertions(+), 19 deletions(-) diff --git a/src/ray/core_worker/actor_manager.cc b/src/ray/core_worker/actor_manager.cc index 12876968d..3ca57895a 100644 --- a/src/ray/core_worker/actor_manager.cc +++ b/src/ray/core_worker/actor_manager.cc @@ -89,11 +89,8 @@ bool ActorManager::AddActorHandle(std::unique_ptr actor_handle, auto actor_notification_callback = std::bind(&ActorManager::HandleActorStateNotification, this, std::placeholders::_1, std::placeholders::_2); - { - absl::MutexLock lock(&gcs_client_mutex_); - RAY_CHECK_OK(gcs_client_->Actors().AsyncSubscribe( - actor_id, actor_notification_callback, nullptr)); - } + RAY_CHECK_OK(gcs_client_->Actors().AsyncSubscribe( + actor_id, actor_notification_callback, nullptr)); if (!RayConfig::instance().gcs_actor_service_enabled()) { RAY_CHECK(reference_counter_->SetDeleteCallback( diff --git a/src/ray/core_worker/actor_manager.h b/src/ray/core_worker/actor_manager.h index 37ef89590..24f66d8f5 100644 --- a/src/ray/core_worker/actor_manager.h +++ b/src/ray/core_worker/actor_manager.h @@ -174,11 +174,8 @@ class ActorManager { void HandleActorStateNotification(const ActorID &actor_id, const gcs::ActorTableData &actor_data); - /// Mutex to protect the gcs_client_ field. - /// NOTE: Now gcs client is not thread safe, so we add lock protection. - mutable absl::Mutex gcs_client_mutex_; /// GCS client. - std::shared_ptr gcs_client_ GUARDED_BY(gcs_client_mutex_); + std::shared_ptr gcs_client_; /// Interface to submit tasks directly to other actors. std::shared_ptr direct_actor_submitter_; diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index 289a2dcee..3ad8dc78c 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -326,8 +326,11 @@ Status ServiceBasedActorInfoAccessor::AsyncSubscribe( on_subscribe, subscribe_done); }; - subscribe_operations_[actor_id] = subscribe_operation; - fetch_data_operations_[actor_id] = fetch_data_operation; + { + absl::MutexLock lock(&mutex_); + subscribe_operations_[actor_id] = subscribe_operation; + fetch_data_operations_[actor_id] = fetch_data_operation; + } return subscribe_operation( [fetch_data_operation, done](const Status &status) { fetch_data_operation(done); }); } @@ -335,6 +338,7 @@ Status ServiceBasedActorInfoAccessor::AsyncSubscribe( Status ServiceBasedActorInfoAccessor::AsyncUnsubscribe(const ActorID &actor_id) { RAY_LOG(DEBUG) << "Cancelling subscription to an actor, actor id = " << actor_id; auto status = client_impl_->GetGcsPubSub().Unsubscribe(ACTOR_CHANNEL, actor_id.Hex()); + absl::MutexLock lock(&mutex_); subscribe_operations_.erase(actor_id); fetch_data_operations_.erase(actor_id); RAY_LOG(DEBUG) << "Finished cancelling subscription to an actor, actor id = " @@ -418,6 +422,7 @@ void ServiceBasedActorInfoAccessor::AsyncResubscribe(bool is_pubsub_server_resta // If only the GCS sever has restarted, we only need to fetch data from the GCS server. // If the pub-sub server has also restarted, we need to resubscribe to the pub-sub // server first, then fetch data from the GCS server. + absl::MutexLock lock(&mutex_); if (is_pubsub_server_restarted) { if (subscribe_all_operation_ != nullptr) { RAY_CHECK_OK(subscribe_all_operation_( @@ -426,7 +431,14 @@ void ServiceBasedActorInfoAccessor::AsyncResubscribe(bool is_pubsub_server_resta for (auto &item : subscribe_operations_) { auto &actor_id = item.first; RAY_CHECK_OK(item.second([this, actor_id](const Status &status) { - fetch_data_operations_[actor_id](nullptr); + absl::MutexLock lock(&mutex_); + auto fetch_data_operation = fetch_data_operations_[actor_id]; + // `fetch_data_operation` is called in the callback function of subscribe. + // Before that, if the user calls `AsyncUnsubscribe` function, the corresponding + // fetch function will be deleted, so we need to check if it's null. + if (fetch_data_operation != nullptr) { + fetch_data_operation(nullptr); + } })); } } else { @@ -1218,8 +1230,11 @@ Status ServiceBasedObjectInfoAccessor::AsyncSubscribeToLocations( on_subscribe, subscribe_done); }; - subscribe_object_operations_[object_id] = subscribe_operation; - fetch_object_data_operations_[object_id] = fetch_data_operation; + { + absl::MutexLock lock(&mutex_); + subscribe_object_operations_[object_id] = subscribe_operation; + fetch_object_data_operations_[object_id] = fetch_data_operation; + } return subscribe_operation( [fetch_data_operation, done](const Status &status) { fetch_data_operation(done); }); } @@ -1229,10 +1244,18 @@ void ServiceBasedObjectInfoAccessor::AsyncResubscribe(bool is_pubsub_server_rest // If only the GCS sever has restarted, we only need to fetch data from the GCS server. // If the pub-sub server has also restarted, we need to resubscribe to the pub-sub // server first, then fetch data from the GCS server. + absl::MutexLock lock(&mutex_); if (is_pubsub_server_restarted) { for (auto &item : subscribe_object_operations_) { RAY_CHECK_OK(item.second([this, item](const Status &status) { - fetch_object_data_operations_[item.first](nullptr); + absl::MutexLock lock(&mutex_); + auto fetch_object_data_operation = fetch_object_data_operations_[item.first]; + // `fetch_object_data_operation` is called in the callback function of subscribe. + // Before that, if the user calls `AsyncUnsubscribeToLocations` function, the + // corresponding fetch function will be deleted, so we need to check if it's null. + if (fetch_object_data_operation != nullptr) { + fetch_object_data_operation(nullptr); + } })); } } else { @@ -1246,6 +1269,7 @@ Status ServiceBasedObjectInfoAccessor::AsyncUnsubscribeToLocations( const ObjectID &object_id) { RAY_LOG(DEBUG) << "Unsubscribing object location, object id = " << object_id; auto status = client_impl_->GetGcsPubSub().Unsubscribe(OBJECT_CHANNEL, object_id.Hex()); + absl::MutexLock lock(&mutex_); subscribe_object_operations_.erase(object_id); fetch_object_data_operations_.erase(object_id); RAY_LOG(DEBUG) << "Finished unsubscribing object location, object id = " << object_id; diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index d820aee88..32568ad3c 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -126,11 +126,16 @@ class ServiceBasedActorInfoAccessor : public ActorInfoAccessor { /// server restarts from a failure. FetchDataOperation fetch_all_data_operation_; + // Mutex to protect the subscribe_operations_ field and fetch_data_operations_ field. + absl::Mutex mutex_; + /// Save the subscribe operation of actors. - std::unordered_map subscribe_operations_; + std::unordered_map subscribe_operations_ + GUARDED_BY(mutex_); /// Save the fetch data operation of actors. - std::unordered_map fetch_data_operations_; + std::unordered_map fetch_data_operations_ + GUARDED_BY(mutex_); ServiceBasedGcsClient *client_impl_; @@ -330,13 +335,19 @@ class ServiceBasedObjectInfoAccessor : public ObjectInfoAccessor { void AsyncResubscribe(bool is_pubsub_server_restarted) override; private: + // Mutex to protect the subscribe_object_operations_ field and + // fetch_object_data_operations_ field. + absl::Mutex mutex_; + /// Save the subscribe operations, so we can call them again when PubSub /// server restarts from a failure. - std::unordered_map subscribe_object_operations_; + std::unordered_map subscribe_object_operations_ + GUARDED_BY(mutex_); /// Save the fetch data operation in this function, so we can call it again when GCS /// server restarts from a failure. - std::unordered_map fetch_object_data_operations_; + std::unordered_map fetch_object_data_operations_ + GUARDED_BY(mutex_); ServiceBasedGcsClient *client_impl_; diff --git a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc index d0901b597..d1aa16154 100644 --- a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc +++ b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc @@ -1133,6 +1133,52 @@ TEST_F(ServiceBasedGcsClientTest, TestGcsRedisFailureDetector) { RAY_CHECK(gcs_server_->IsStopped()); } +TEST_F(ServiceBasedGcsClientTest, TestMultiThreadSubAndUnsub) { + auto sub_finished_count = std::make_shared>(0); + int size = 5; + std::vector> threads; + threads.resize(size); + + // The number of times each thread executes subscribe & resubscribe & unsubscribe. + const int sub_and_unsub_loop_count = 20; + + // Multithreading subscribe/resubscribe/unsubscribe actors. + auto job_id = JobID::FromInt(1); + for (int index = 0; index < size; ++index) { + threads[index].reset(new std::thread([this, job_id] { + for (int index = 0; index < sub_and_unsub_loop_count; ++index) { + auto actor_id = ActorID::Of(job_id, RandomTaskId(), 0); + ASSERT_TRUE(SubscribeActor( + actor_id, [](const ActorID &id, const rpc::ActorTableData &result) {})); + gcs_client_->Actors().AsyncResubscribe(false); + UnsubscribeActor(actor_id); + } + })); + } + for (auto &thread : threads) { + thread->join(); + thread.reset(); + } + + // Multithreading subscribe/resubscribe/unsubscribe objects. + for (int index = 0; index < size; ++index) { + threads[index].reset(new std::thread([this] { + for (int index = 0; index < sub_and_unsub_loop_count; ++index) { + auto object_id = ObjectID::FromRandom(); + ASSERT_TRUE(SubscribeToLocations( + object_id, + [](const ObjectID &id, const gcs::ObjectChangeNotification &result) {})); + gcs_client_->Objects().AsyncResubscribe(false); + UnsubscribeToLocations(object_id); + } + })); + } + for (auto &thread : threads) { + thread->join(); + thread.reset(); + } +} + } // namespace ray int main(int argc, char **argv) {