diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index 7c730035d..bfe1bf655 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -650,6 +650,12 @@ class WorkerInfoAccessor { const std::unordered_map &worker_info, const StatusCallback &callback) = 0; + /// Reestablish subscription. + /// This should be called when GCS server restarts from a failure. + /// + /// \return Status + virtual Status AsyncReSubscribe() = 0; + protected: WorkerInfoAccessor() = default; }; diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index f00b10a9e..c547fefe7 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -1212,15 +1212,26 @@ Status ServiceBasedWorkerInfoAccessor::AsyncSubscribeToWorkerFailures( const StatusCallback &done) { RAY_LOG(DEBUG) << "Subscribing worker failures."; RAY_CHECK(subscribe != nullptr); - auto on_subscribe = [subscribe](const std::string &id, const std::string &data) { - rpc::WorkerFailureData worker_failure_data; - worker_failure_data.ParseFromString(data); - subscribe(WorkerID::FromBinary(id), worker_failure_data); + subscribe_operation_ = [this, subscribe](const StatusCallback &done) { + auto on_subscribe = [subscribe](const std::string &id, const std::string &data) { + rpc::WorkerFailureData worker_failure_data; + worker_failure_data.ParseFromString(data); + subscribe(WorkerID::FromBinary(id), worker_failure_data); + }; + auto status = client_impl_->GetGcsPubSub().SubscribeAll(WORKER_FAILURE_CHANNEL, + on_subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing worker failures."; + return status; }; - auto status = client_impl_->GetGcsPubSub().SubscribeAll(WORKER_FAILURE_CHANNEL, - on_subscribe, done); - RAY_LOG(DEBUG) << "Finished subscribing worker failures."; - return status; + return subscribe_operation_(done); +} + +Status ServiceBasedWorkerInfoAccessor::AsyncReSubscribe() { + RAY_LOG(INFO) << "Reestablishing subscription for worker failures."; + if (subscribe_operation_ != nullptr) { + return subscribe_operation_(nullptr); + } + return Status::OK(); } Status ServiceBasedWorkerInfoAccessor::AsyncReportWorkerFailure( diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index 0115b17c5..2c1077f23 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -357,7 +357,13 @@ class ServiceBasedWorkerInfoAccessor : public WorkerInfoAccessor { const std::unordered_map &worker_info, const StatusCallback &callback) override; + Status AsyncReSubscribe() override; + private: + /// Save the subscribe operation in this function, so we can call it again when GCS + /// restarts from a failure. + SubscribeOperation subscribe_operation_; + ServiceBasedGcsClient *client_impl_; }; diff --git a/src/ray/gcs/gcs_client/service_based_gcs_client.cc b/src/ray/gcs/gcs_client/service_based_gcs_client.cc index 7b12db7d3..005b8c376 100644 --- a/src/ray/gcs/gcs_client/service_based_gcs_client.cc +++ b/src/ray/gcs/gcs_client/service_based_gcs_client.cc @@ -52,6 +52,7 @@ Status ServiceBasedGcsClient::Connect(boost::asio::io_service &io_service) { RAY_CHECK_OK(actor_accessor_->AsyncReSubscribe()); RAY_CHECK_OK(node_accessor_->AsyncReSubscribe()); RAY_CHECK_OK(task_accessor_->AsyncReSubscribe()); + RAY_CHECK_OK(worker_accessor_->AsyncReSubscribe()); }; // Connect to gcs service. diff --git a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc index ff0bad23b..7f0a3d212 100644 --- a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc +++ b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc @@ -961,6 +961,24 @@ TEST_F(ServiceBasedGcsClientTest, TestTaskTableReSubscribe) { WaitPendingDone(task_count, 1); } +TEST_F(ServiceBasedGcsClientTest, TestWorkerTableReSubscribe) { + // Subscribe to all unexpected failure of workers from GCS. + std::atomic worker_failure_count(0); + auto on_subscribe = [&worker_failure_count](const WorkerID &worker_id, + const rpc::WorkerFailureData &result) { + ++worker_failure_count; + }; + ASSERT_TRUE(SubscribeToWorkerFailures(on_subscribe)); + + // Restart GCS + RestartGcsServer(); + + // Report a worker failure to GCS and check if resubscribe works. + auto worker_failure_data = Mocker::GenWorkerFailureData(); + ASSERT_TRUE(ReportWorkerFailure(worker_failure_data)); + WaitPendingDone(worker_failure_count, 1); +} + TEST_F(ServiceBasedGcsClientTest, TestGcsRedisFailureDetector) { // Stop redis. TestSetupUtil::ShutDownRedisServers(); diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index 6a1bbc4be..42ea0f5f7 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -452,6 +452,8 @@ class RedisWorkerInfoAccessor : public WorkerInfoAccessor { const std::unordered_map &worker_info, const StatusCallback &callback) override; + Status AsyncReSubscribe() override { return Status::NotImplemented(""); } + private: RedisGcsClient *client_impl_{nullptr};