diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 57a37afd6..a9e9a144e 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -6,6 +6,7 @@ #include "object_info_handler_impl.h" #include "stats_handler_impl.h" #include "task_info_handler_impl.h" +#include "worker_info_handler_impl.h" namespace ray { namespace gcs { @@ -55,6 +56,11 @@ void GcsServer::Start() { new rpc::ErrorInfoGrpcService(main_service_, *error_info_handler_)); rpc_server_.RegisterService(*error_info_service_); + worker_info_handler_ = InitWorkerInfoHandler(); + worker_info_service_.reset( + new rpc::WorkerInfoGrpcService(main_service_, *worker_info_handler_)); + rpc_server_.RegisterService(*worker_info_service_); + // Run rpc server. rpc_server_.Run(); @@ -116,5 +122,10 @@ std::unique_ptr GcsServer::InitErrorInfoHandler() { new rpc::DefaultErrorInfoHandler(*redis_gcs_client_)); } +std::unique_ptr GcsServer::InitWorkerInfoHandler() { + return std::unique_ptr( + new rpc::DefaultWorkerInfoHandler(*redis_gcs_client_)); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index 1aa1d05a0..3a0887401 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -65,6 +65,9 @@ class GcsServer { /// The error info handler virtual std::unique_ptr InitErrorInfoHandler(); + /// The worker info handler + virtual std::unique_ptr InitWorkerInfoHandler(); + private: /// Gcs server configuration GcsServerConfig config_; @@ -93,6 +96,9 @@ class GcsServer { /// Error info handler and service std::unique_ptr error_info_handler_; std::unique_ptr error_info_service_; + /// Worker info handler and service + std::unique_ptr worker_info_handler_; + std::unique_ptr worker_info_service_; /// Backend client std::shared_ptr redis_gcs_client_; }; diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc 
b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc index 6911f7cd9..0b15c2749 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc @@ -365,6 +365,17 @@ class GcsServerTest : public RedisServiceManagerForTest { return WaitReady(promise.get_future(), timeout_ms_); } + bool ReportWorkerFailure(const rpc::ReportWorkerFailureRequest &request) { + std::promise promise; + client_->ReportWorkerFailure( + request, + [&promise](const Status &status, const rpc::ReportWorkerFailureReply &reply) { + RAY_CHECK_OK(status); + promise.set_value(true); + }); + return WaitReady(promise.get_future(), timeout_ms_); + } + bool WaitReady(const std::future &future, uint64_t timeout_ms) { auto status = future.wait_for(std::chrono::milliseconds(timeout_ms)); return status == std::future_status::ready; @@ -633,6 +644,15 @@ TEST_F(GcsServerTest, TestErrorInfo) { ASSERT_TRUE(ReportJobError(report_error_request)); } +TEST_F(GcsServerTest, TestWorkerInfo) { + rpc::WorkerFailureData worker_failure_data; + worker_failure_data.mutable_worker_address()->set_ip_address("127.0.0.1"); + worker_failure_data.mutable_worker_address()->set_port(5566); + rpc::ReportWorkerFailureRequest report_worker_failure_request; + report_worker_failure_request.mutable_worker_failure()->CopyFrom(worker_failure_data); + ASSERT_TRUE(ReportWorkerFailure(report_worker_failure_request)); +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/gcs/gcs_server/worker_info_handler_impl.cc b/src/ray/gcs/gcs_server/worker_info_handler_impl.cc new file mode 100644 index 000000000..6dd63d58f --- /dev/null +++ b/src/ray/gcs/gcs_server/worker_info_handler_impl.cc @@ -0,0 +1,30 @@ +#include "worker_info_handler_impl.h" + +namespace ray { +namespace rpc { + +void DefaultWorkerInfoHandler::HandleReportWorkerFailure( + const ReportWorkerFailureRequest &request, ReportWorkerFailureReply *reply, + SendReplyCallback send_reply_callback) { + Address 
worker_address = request.worker_failure().worker_address(); + RAY_LOG(DEBUG) << "Reporting worker failure, " << worker_address.DebugString(); + auto worker_failure_data = std::make_shared(); + worker_failure_data->CopyFrom(request.worker_failure()); + auto on_done = [worker_address, send_reply_callback](Status status) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to report worker failure, " + << worker_address.DebugString(); + } + send_reply_callback(status, nullptr, nullptr); + }; + + Status status = + gcs_client_.Workers().AsyncReportWorkerFailure(worker_failure_data, on_done); + if (!status.ok()) { + on_done(status); + } + RAY_LOG(DEBUG) << "Finished reporting worker failure, " << worker_address.DebugString(); +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/gcs/gcs_server/worker_info_handler_impl.h b/src/ray/gcs/gcs_server/worker_info_handler_impl.h new file mode 100644 index 000000000..56ec5a473 --- /dev/null +++ b/src/ray/gcs/gcs_server/worker_info_handler_impl.h @@ -0,0 +1,27 @@ +#ifndef RAY_GCS_WORKER_INFO_HANDLER_IMPL_H +#define RAY_GCS_WORKER_INFO_HANDLER_IMPL_H + +#include "ray/gcs/redis_gcs_client.h" +#include "ray/rpc/gcs_server/gcs_rpc_server.h" + +namespace ray { +namespace rpc { + +/// The implementation class of `WorkerInfoHandler`. 
+class DefaultWorkerInfoHandler : public rpc::WorkerInfoHandler { + public: + explicit DefaultWorkerInfoHandler(gcs::RedisGcsClient &gcs_client) + : gcs_client_(gcs_client) {} + + void HandleReportWorkerFailure(const ReportWorkerFailureRequest &request, + ReportWorkerFailureReply *reply, + SendReplyCallback send_reply_callback) override; + + private: + gcs::RedisGcsClient &gcs_client_; +}; + +} // namespace rpc +} // namespace ray + +#endif // RAY_GCS_WORKER_INFO_HANDLER_IMPL_H diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 4adcca742..a64fce67f 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -295,3 +295,16 @@ service ErrorInfoGcsService { // Report a job error to GCS Service. rpc ReportJobError(ReportJobErrorRequest) returns (ReportJobErrorReply); } + +message ReportWorkerFailureRequest { + WorkerFailureData worker_failure = 1; +} + +message ReportWorkerFailureReply { +} + +// Service for worker info access. +service WorkerInfoGcsService { + // Report a worker failure to GCS Service. + rpc ReportWorkerFailure(ReportWorkerFailureRequest) returns (ReportWorkerFailureReply); +} diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index b72a32835..551e868cc 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -30,6 +30,8 @@ class GcsRpcClient { new GrpcClient(address, port, client_call_manager)); error_info_grpc_client_ = std::unique_ptr>( new GrpcClient(address, port, client_call_manager)); + worker_info_grpc_client_ = std::unique_ptr>( + new GrpcClient(address, port, client_call_manager)); }; /// Add job info to gcs server. @@ -119,6 +121,10 @@ class GcsRpcClient { /// Report a job error to GCS Service. VOID_RPC_CLIENT_METHOD(ErrorInfoGcsService, ReportJobError, error_info_grpc_client_, ) + /// Report a worker failure to GCS Service. 
+ VOID_RPC_CLIENT_METHOD(WorkerInfoGcsService, ReportWorkerFailure, + worker_info_grpc_client_, ) + private: /// The gRPC-generated stub. std::unique_ptr> job_info_grpc_client_; @@ -128,6 +134,7 @@ class GcsRpcClient { std::unique_ptr> task_info_grpc_client_; std::unique_ptr> stats_grpc_client_; std::unique_ptr> error_info_grpc_client_; + std::unique_ptr> worker_info_grpc_client_; }; } // namespace rpc diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index b735c7acd..c92260c0d 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -30,6 +30,9 @@ namespace rpc { #define ERROR_INFO_SERVICE_RPC_HANDLER(HANDLER, CONCURRENCY) \ RPC_SERVICE_HANDLER(ErrorInfoGcsService, HANDLER, CONCURRENCY) +#define WORKER_INFO_SERVICE_RPC_HANDLER(HANDLER, CONCURRENCY) \ + RPC_SERVICE_HANDLER(WorkerInfoGcsService, HANDLER, CONCURRENCY) + class JobInfoGcsServiceHandler { public: virtual ~JobInfoGcsServiceHandler() = default; @@ -374,6 +377,42 @@ class ErrorInfoGrpcService : public GrpcService { ErrorInfoGcsServiceHandler &service_handler_; }; +class WorkerInfoGcsServiceHandler { + public: + virtual ~WorkerInfoGcsServiceHandler() = default; + + virtual void HandleReportWorkerFailure(const ReportWorkerFailureRequest &request, + ReportWorkerFailureReply *reply, + SendReplyCallback send_reply_callback) = 0; +}; + +/// The `GrpcService` for `WorkerInfoGcsService`. +class WorkerInfoGrpcService : public GrpcService { + public: + /// Constructor. + /// + /// \param[in] handler The service handler that actually handles the requests. 
+ explicit WorkerInfoGrpcService(boost::asio::io_service &io_service, + WorkerInfoGcsServiceHandler &handler) + : GrpcService(io_service), service_handler_(handler){}; + + protected: + grpc::Service &GetGrpcService() override { return service_; } + + void InitServerCallFactories( + const std::unique_ptr &cq, + std::vector, int>> + *server_call_factories_and_concurrencies) override { + WORKER_INFO_SERVICE_RPC_HANDLER(ReportWorkerFailure, 1); + } + + private: + /// The grpc async service object. + WorkerInfoGcsService::AsyncService service_; + /// The service handler that actually handles the requests. + WorkerInfoGcsServiceHandler &service_handler_; +}; + using JobInfoHandler = JobInfoGcsServiceHandler; using ActorInfoHandler = ActorInfoGcsServiceHandler; using NodeInfoHandler = NodeInfoGcsServiceHandler; @@ -381,6 +420,7 @@ using ObjectInfoHandler = ObjectInfoGcsServiceHandler; using TaskInfoHandler = TaskInfoGcsServiceHandler; using StatsHandler = StatsGcsServiceHandler; using ErrorInfoHandler = ErrorInfoGcsServiceHandler; +using WorkerInfoHandler = WorkerInfoGcsServiceHandler; } // namespace rpc } // namespace ray