diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 43051dfcc..2429b9c39 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -17,6 +17,7 @@ from ray.core.generated.gcs_pb2 import ( ResourceTableData, ObjectLocationInfo, PubSubMessage, + WorkerTableData, ) __all__ = [ @@ -39,6 +40,7 @@ __all__ = [ "construct_error_message", "ObjectLocationInfo", "PubSubMessage", + "WorkerTableData", ] FUNCTION_PREFIX = "RemoteFunction:" @@ -69,6 +71,9 @@ TablePrefix_PROFILE_string = "PROFILE" TablePrefix_JOB_string = "JOB" TablePrefix_ACTOR_string = "ACTOR" +WORKER = 0 +DRIVER = 1 + def construct_error_message(job_id, error_type, message, timestamp): """Construct a serialized ErrorTableData object. diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd index 5a3760459..4652aaedb 100644 --- a/python/ray/includes/global_state_accessor.pxd +++ b/python/ray/includes/global_state_accessor.pxd @@ -6,6 +6,7 @@ from ray.includes.unique_ids cimport ( CActorID, CClientID, CObjectID, + CWorkerID, ) cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: @@ -23,3 +24,6 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: c_vector[c_string] GetAllActorInfo() unique_ptr[c_string] GetActorInfo(const CActorID &actor_id) c_string GetNodeResourceInfo(const CClientID &node_id) + unique_ptr[c_string] GetWorkerInfo(const CWorkerID &worker_id) + c_vector[c_string] GetAllWorkerInfo() + c_bool AddWorkerInfo(const c_string &serialized_string) diff --git a/python/ray/includes/global_state_accessor.pxi b/python/ray/includes/global_state_accessor.pxi index 4a5f3fb5f..78d1f47ef 100644 --- a/python/ray/includes/global_state_accessor.pxi +++ b/python/ray/includes/global_state_accessor.pxi @@ -2,6 +2,7 @@ from ray.includes.unique_ids cimport ( CActorID, CClientID, CObjectID, + CWorkerID, ) from ray.includes.global_state_accessor cimport ( @@ -57,3 +58,15 @@ cdef class GlobalStateAccessor: def 
get_node_resource_info(self, node_id): return self.inner.get().GetNodeResourceInfo(CClientID.FromBinary(node_id.binary())) + + def get_worker_table(self): + return self.inner.get().GetAllWorkerInfo() + + def get_worker_info(self, worker_id): + worker_info = self.inner.get().GetWorkerInfo(CWorkerID.FromBinary(worker_id.binary())) + if worker_info: + return c_string(worker_info.get().data(), worker_info.get().size()) + return None + + def add_worker_info(self, serialized_string): + return self.inner.get().AddWorkerInfo(serialized_string) diff --git a/python/ray/state.py b/python/ray/state.py index a095edb47..649c668b5 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -602,26 +602,52 @@ class GlobalState: """Get a dictionary mapping worker ID to worker information.""" self._check_connected() - worker_keys = self.redis_client.keys("Worker*") + # Get all data in worker table + worker_table = self.global_state_accessor.get_worker_table() workers_data = {} + for i in range(len(worker_table)): + worker_table_data = gcs_utils.WorkerTableData.FromString( + worker_table[i]) + if worker_table_data.is_alive and \ + worker_table_data.worker_type == gcs_utils.WORKER: + worker_id = binary_to_hex( + worker_table_data.worker_address.worker_id) + worker_info = worker_table_data.worker_info - for worker_key in worker_keys: - worker_info = self.redis_client.hgetall(worker_key) - worker_id = binary_to_hex(worker_key[len("Workers:"):]) - - workers_data[worker_id] = { - "node_ip_address": decode(worker_info[b"node_ip_address"]), - "plasma_store_socket": decode( - worker_info[b"plasma_store_socket"]) - } - if b"stderr_file" in worker_info: - workers_data[worker_id]["stderr_file"] = decode( - worker_info[b"stderr_file"]) - if b"stdout_file" in worker_info: - workers_data[worker_id]["stdout_file"] = decode( - worker_info[b"stdout_file"]) + workers_data[worker_id] = { + "node_ip_address": decode(worker_info[b"node_ip_address"]), + "plasma_store_socket": decode( + 
worker_info[b"plasma_store_socket"]) + } + if b"stderr_file" in worker_info: + workers_data[worker_id]["stderr_file"] = decode( + worker_info[b"stderr_file"]) + if b"stdout_file" in worker_info: + workers_data[worker_id]["stdout_file"] = decode( + worker_info[b"stdout_file"]) return workers_data + def add_worker(self, worker_id, worker_type, worker_info): + """ Add a worker to the cluster. + + Args: + worker_id: ID of this worker. Type is bytes. + worker_type: Type of this worker. Value is ray.gcs_utils.DRIVER or + ray.gcs_utils.WORKER. + worker_info: Info of this worker. Type is dict{str: str}. + + Returns: + Is operation success + """ + worker_data = ray.gcs_utils.WorkerTableData() + worker_data.is_alive = True + worker_data.worker_address.worker_id = worker_id + worker_data.worker_type = worker_type + for k, v in worker_info.items(): + worker_data.worker_info[k] = bytes(v, encoding="utf-8") + return self.global_state_accessor.add_worker_info( + worker_data.SerializeToString()) + def _job_length(self): event_log_sets = self.redis_client.keys("event_log*") overall_smallest = sys.maxsize diff --git a/python/ray/worker.py b/python/ray/worker.py index 6d3b195dc..03db98fa5 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -872,15 +872,14 @@ normal_excepthook = sys.excepthook def custom_excepthook(type, value, tb): - # If this is a driver, push the exception to redis. + # If this is a driver, push the exception to GCS worker table. 
if global_worker.mode == SCRIPT_MODE: error_message = "".join(traceback.format_tb(tb)) - try: - global_worker.redis_client.hmset( - b"Drivers:" + global_worker.worker_id, - {"exception": error_message}) - except (ConnectionRefusedError, redis.exceptions.ConnectionError): - logger.warning("Could not push exception to redis.") + worker_id = global_worker.worker_id + worker_type = ray.gcs_utils.DRIVER + worker_info = {"exception": error_message} + + ray.state.state.add_worker(worker_id, worker_type, worker_info) # Call the normal excepthook. normal_excepthook(type, value, tb) diff --git a/src/ray/common/scheduling/cluster_resource_scheduler.cc b/src/ray/common/scheduling/cluster_resource_scheduler.cc index 241bbe838..b97a09081 100644 --- a/src/ray/common/scheduling/cluster_resource_scheduler.cc +++ b/src/ray/common/scheduling/cluster_resource_scheduler.cc @@ -485,7 +485,7 @@ int64_t ClusterResourceScheduler::IsSchedulable(const TaskRequest &task_req, } } - // No check custom resources. + // Now check custom resources. for (const auto task_req_custom_resource : task_req.custom_resources) { auto it = resources.custom_resources.find(task_req_custom_resource.id); diff --git a/src/ray/common/scheduling/cluster_resource_scheduler.h b/src/ray/common/scheduling/cluster_resource_scheduler.h index 00a5e90c6..a3128374c 100644 --- a/src/ray/common/scheduling/cluster_resource_scheduler.h +++ b/src/ray/common/scheduling/cluster_resource_scheduler.h @@ -138,7 +138,7 @@ class TaskRequest { public: /// List of predefined resources required by the task. std::vector predefined_resources; - /// List of custom resources required by the tasl. + /// List of custom resources required by the task. std::vector custom_resources; /// List of placement hints. A placement hint is a node on which /// we desire to run this task. 
This is a soft constraint in that diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 8b6a00ceb..11a7a42e9 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -646,8 +646,12 @@ void CoreWorker::RegisterToGcs() { worker_info.emplace("stderr_file", options_.stderr_file); } - RAY_CHECK_OK(gcs_client_->Workers().AsyncRegisterWorker(options_.worker_type, worker_id, - worker_info, nullptr)); + auto worker_data = std::make_shared(); + worker_data->mutable_worker_address()->set_worker_id(worker_id.Binary()); + worker_data->set_worker_type(options_.worker_type); + worker_data->mutable_worker_info()->insert(worker_info.begin(), worker_info.end()); + + RAY_CHECK_OK(gcs_client_->Workers().AsyncAdd(worker_data, nullptr)); } void CoreWorker::CheckForRayletFailure(const boost::system::error_code &error) { if (error == boost::asio::error::operation_aborted) { diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index df3bb201c..9f30606bd 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -647,7 +647,7 @@ class WorkerInfoAccessor { /// \param done Callback that will be called when subscription is complete. /// \return Status virtual Status AsyncSubscribeToWorkerFailures( - const SubscribeCallback &subscribe, + const SubscribeCallback &subscribe, const StatusCallback &done) = 0; /// Report a worker failure to GCS asynchronously. @@ -656,19 +656,31 @@ class WorkerInfoAccessor { /// \param callback Callback that will be called when report is complate. /// \param Status virtual Status AsyncReportWorkerFailure( - const std::shared_ptr &data_ptr, + const std::shared_ptr &data_ptr, const StatusCallback &callback) = 0; - /// Register a worker to GCS asynchronously. + /// Get worker specification from GCS asynchronously. /// - /// \param worker_type The type of the worker. - /// \param worker_id The ID of the worker. - /// \param worker_info The information of the worker. 
- /// \return Status. - virtual Status AsyncRegisterWorker( - rpc::WorkerType worker_type, const WorkerID &worker_id, - const std::unordered_map &worker_info, - const StatusCallback &callback) = 0; + /// \param worker_id The ID of worker to look up in the GCS. + /// \param callback Callback that will be called after lookup finishes. + /// \return Status + virtual Status AsyncGet(const WorkerID &worker_id, + const OptionalItemCallback &callback) = 0; + + /// Get all worker info from GCS asynchronously. + /// + /// \param callback Callback that will be called after lookup finishes. + /// \return Status + virtual Status AsyncGetAll(const MultiItemCallback &callback) = 0; + + /// Add worker information to GCS asynchronously. + /// + /// \param data_ptr The worker that will be added to GCS. + /// \param callback Callback that will be called after worker information has been added + /// to GCS. + /// \return Status + virtual Status AsyncAdd(const std::shared_ptr &data_ptr, + const StatusCallback &callback) = 0; /// Reestablish subscription. /// This should be called when GCS server restarts from a failure. 
diff --git a/src/ray/gcs/gcs_client/global_state_accessor.cc b/src/ray/gcs/gcs_client/global_state_accessor.cc index d9b3e3da1..f95d6febb 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.cc +++ b/src/ray/gcs/gcs_client/global_state_accessor.cc @@ -178,5 +178,38 @@ std::unique_ptr GlobalStateAccessor::GetActorCheckpointId( return actor_checkpoint_id_data; } +std::unique_ptr GlobalStateAccessor::GetWorkerInfo( + const WorkerID &worker_id) { + std::unique_ptr worker_table_data; + std::promise promise; + RAY_CHECK_OK(gcs_client_->Workers().AsyncGet( + worker_id, TransformForOptionalItemCallback(worker_table_data, + promise))); + promise.get_future().get(); + return worker_table_data; +} + +std::vector GlobalStateAccessor::GetAllWorkerInfo() { + std::vector worker_table_data; + std::promise promise; + RAY_CHECK_OK(gcs_client_->Workers().AsyncGetAll( + TransformForMultiItemCallback(worker_table_data, promise))); + promise.get_future().get(); + return worker_table_data; +} + +bool GlobalStateAccessor::AddWorkerInfo(const std::string &serialized_string) { + auto data_ptr = std::make_shared(); + data_ptr->ParseFromString(serialized_string); + std::promise promise; + RAY_CHECK_OK( + gcs_client_->Workers().AsyncAdd(data_ptr, [&promise](const Status &status) { + RAY_CHECK_OK(status); + promise.set_value(true); + })); + promise.get_future().get(); + return true; +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_client/global_state_accessor.h b/src/ray/gcs/gcs_client/global_state_accessor.h index eb63e3c56..977a87537 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.h +++ b/src/ray/gcs/gcs_client/global_state_accessor.h @@ -109,6 +109,28 @@ class GlobalStateAccessor { /// deserialized with protobuf function. std::unique_ptr GetActorCheckpointId(const ActorID &actor_id); + /// Get information of a worker from GCS Service. + /// + /// \param worker_id The ID of worker to look up in the GCS Service. + /// \return Worker info. 
To support multi-language, we serialize each WorkerTableData + /// and return the serialized string. Where used, it needs to be deserialized with + /// protobuf function. + std::unique_ptr GetWorkerInfo(const WorkerID &worker_id); + + /// Get information of all workers from GCS Service. + /// + /// \return All worker info. To support multi-language, we serialize each + /// WorkerTableData and return the serialized string. Where used, it needs to be + /// deserialized with protobuf function. + std::vector GetAllWorkerInfo(); + + /// Add information of a worker to GCS Service. + /// + /// \param serialized_string The serialized data of worker to be added in the GCS + /// Service. A string is used because it is convenient for Python callers. + /// \return True if the operation succeeds. + bool AddWorkerInfo(const std::string &serialized_string); + private: /// MultiItem transformation helper in template style. /// diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index 5997f42c7..d39963986 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -1252,17 +1252,16 @@ ServiceBasedWorkerInfoAccessor::ServiceBasedWorkerInfoAccessor( : client_impl_(client_impl) {} Status ServiceBasedWorkerInfoAccessor::AsyncSubscribeToWorkerFailures( - const SubscribeCallback &subscribe, + const SubscribeCallback &subscribe, const StatusCallback &done) { RAY_CHECK(subscribe != nullptr); subscribe_operation_ = [this, subscribe](const StatusCallback &done) { auto on_subscribe = [subscribe](const std::string &id, const std::string &data) { - rpc::WorkerFailureData worker_failure_data; + rpc::WorkerTableData worker_failure_data; worker_failure_data.ParseFromString(data); subscribe(WorkerID::FromBinary(id), worker_failure_data); }; - return client_impl_->GetGcsPubSub().SubscribeAll(WORKER_FAILURE_CHANNEL, on_subscribe, - done); + return client_impl_->GetGcsPubSub().SubscribeAll(WORKER_CHANNEL, 
on_subscribe, done); }; return subscribe_operation_(done); } @@ -1276,7 +1275,7 @@ void ServiceBasedWorkerInfoAccessor::AsyncResubscribe(bool is_pubsub_server_rest } Status ServiceBasedWorkerInfoAccessor::AsyncReportWorkerFailure( - const std::shared_ptr &data_ptr, + const std::shared_ptr &data_ptr, const StatusCallback &callback) { rpc::Address worker_address = data_ptr->worker_address(); RAY_LOG(DEBUG) << "Reporting worker failure, " << worker_address.DebugString(); @@ -1294,22 +1293,48 @@ Status ServiceBasedWorkerInfoAccessor::AsyncReportWorkerFailure( return Status::OK(); } -Status ServiceBasedWorkerInfoAccessor::AsyncRegisterWorker( - rpc::WorkerType worker_type, const WorkerID &worker_id, - const std::unordered_map &worker_info, - const StatusCallback &callback) { - RAY_LOG(DEBUG) << "Registering the worker. worker id = " << worker_id; - rpc::RegisterWorkerRequest request; - request.set_worker_type(worker_type); +Status ServiceBasedWorkerInfoAccessor::AsyncGet( + const WorkerID &worker_id, + const OptionalItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting worker info, worker id = " << worker_id; + rpc::GetWorkerInfoRequest request; request.set_worker_id(worker_id.Binary()); - request.mutable_worker_info()->insert(worker_info.begin(), worker_info.end()); - client_impl_->GetGcsRpcClient().RegisterWorker( + client_impl_->GetGcsRpcClient().GetWorkerInfo( request, - [worker_id, callback](const Status &status, const rpc::RegisterWorkerReply &reply) { + [worker_id, callback](const Status &status, const rpc::GetWorkerInfoReply &reply) { + if (reply.has_worker_table_data()) { + callback(status, reply.worker_table_data()); + } else { + callback(status, boost::none); + } + RAY_LOG(DEBUG) << "Finished getting worker info, worker id = " << worker_id; + }); + return Status::OK(); +} + +Status ServiceBasedWorkerInfoAccessor::AsyncGetAll( + const MultiItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting all worker info."; + rpc::GetAllWorkerInfoRequest request; + 
client_impl_->GetGcsRpcClient().GetAllWorkerInfo( + request, [callback](const Status &status, const rpc::GetAllWorkerInfoReply &reply) { + auto result = VectorFromProtobuf(reply.worker_table_data()); + callback(status, result); + RAY_LOG(DEBUG) << "Finished getting all worker info, status = " << status; + }); + return Status::OK(); +} + +Status ServiceBasedWorkerInfoAccessor::AsyncAdd( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + rpc::AddWorkerInfoRequest request; + request.mutable_worker_data()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().AddWorkerInfo( + request, [callback](const Status &status, const rpc::AddWorkerInfoReply &reply) { if (callback) { callback(status); } - RAY_LOG(DEBUG) << "Finished registering worker. worker id = " << worker_id; }); return Status::OK(); } diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index a38125293..3d7eda914 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -374,16 +374,19 @@ class ServiceBasedWorkerInfoAccessor : public WorkerInfoAccessor { virtual ~ServiceBasedWorkerInfoAccessor() = default; Status AsyncSubscribeToWorkerFailures( - const SubscribeCallback &subscribe, + const SubscribeCallback &subscribe, const StatusCallback &done) override; - Status AsyncReportWorkerFailure(const std::shared_ptr &data_ptr, + Status AsyncReportWorkerFailure(const std::shared_ptr &data_ptr, const StatusCallback &callback) override; - Status AsyncRegisterWorker( - rpc::WorkerType worker_type, const WorkerID &worker_id, - const std::unordered_map &worker_info, - const StatusCallback &callback) override; + Status AsyncGet(const WorkerID &worker_id, + const OptionalItemCallback &callback) override; + + Status AsyncGetAll(const MultiItemCallback &callback) override; + + Status AsyncAdd(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; void 
AsyncResubscribe(bool is_pubsub_server_restarted) override; diff --git a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc index 71309f030..38e6e894f 100644 --- a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc +++ b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc @@ -226,6 +226,26 @@ TEST_F(GlobalStateAccessorTest, TestActorTable) { } } +TEST_F(GlobalStateAccessorTest, TestWorkerTable) { + ASSERT_EQ(global_state_->GetAllWorkerInfo().size(), 0); + // Add worker info + auto worker_table_data = Mocker::GenWorkerTableData(); + worker_table_data->mutable_worker_address()->set_worker_id( + WorkerID::FromRandom().Binary()); + ASSERT_TRUE(global_state_->AddWorkerInfo(worker_table_data->SerializeAsString())); + + // Get worker info + auto worker_id = WorkerID::FromBinary(worker_table_data->worker_address().worker_id()); + ASSERT_TRUE(global_state_->GetWorkerInfo(worker_id)); + + // Add another worker info + auto another_worker_data = Mocker::GenWorkerTableData(); + another_worker_data->mutable_worker_address()->set_worker_id( + WorkerID::FromRandom().Binary()); + ASSERT_TRUE(global_state_->AddWorkerInfo(another_worker_data->SerializeAsString())); + ASSERT_EQ(global_state_->GetAllWorkerInfo().size(), 2); +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc index a625848fd..ce15ffc7e 100644 --- a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc +++ b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc @@ -455,7 +455,7 @@ class ServiceBasedGcsClientTest : public ::testing::Test { } bool SubscribeToWorkerFailures( - const gcs::SubscribeCallback &subscribe) { + const gcs::SubscribeCallback &subscribe) { std::promise promise; RAY_CHECK_OK(gcs_client_->Workers().AsyncSubscribeToWorkerFailures( subscribe, [&promise](Status status) { 
promise.set_value(status.ok()); })); @@ -463,7 +463,7 @@ class ServiceBasedGcsClientTest : public ::testing::Test { } bool ReportWorkerFailure( - const std::shared_ptr &worker_failure_data) { + const std::shared_ptr &worker_failure_data) { std::promise promise; RAY_CHECK_OK(gcs_client_->Workers().AsyncReportWorkerFailure( worker_failure_data, @@ -471,6 +471,13 @@ class ServiceBasedGcsClientTest : public ::testing::Test { return WaitReady(promise.get_future(), timeout_ms_); } + bool AddWorker(const std::shared_ptr &worker_data) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Workers().AsyncAdd( + worker_data, [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + bool WaitReady(std::future future, const std::chrono::milliseconds &timeout_ms) { auto status = future.wait_for(timeout_ms); return status == std::future_status::ready && future.get(); @@ -824,14 +831,22 @@ TEST_F(ServiceBasedGcsClientTest, TestWorkerInfo) { // Subscribe to all unexpected failure of workers from GCS. std::atomic worker_failure_count(0); auto on_subscribe = [&worker_failure_count](const WorkerID &worker_id, - const rpc::WorkerFailureData &result) { + const rpc::WorkerTableData &result) { ++worker_failure_count; }; ASSERT_TRUE(SubscribeToWorkerFailures(on_subscribe)); - // Report a worker failure to GCS. - auto worker_failure_data = Mocker::GenWorkerFailureData(); - ASSERT_TRUE(ReportWorkerFailure(worker_failure_data)); + // Report a worker failure to GCS when this worker doesn't exist. + auto worker_data = Mocker::GenWorkerTableData(); + worker_data->mutable_worker_address()->set_worker_id(WorkerID::FromRandom().Binary()); + ASSERT_TRUE(ReportWorkerFailure(worker_data)); + WaitPendingDone(worker_failure_count, 0); + + // Add a worker to GCS. + ASSERT_TRUE(AddWorker(worker_data)); + + // Report a worker failure to GCS when this worker actually exists. 
+ ASSERT_TRUE(ReportWorkerFailure(worker_data)); WaitPendingDone(worker_failure_count, 1); } @@ -1065,7 +1080,7 @@ TEST_F(ServiceBasedGcsClientTest, TestWorkerTableResubscribe) { // Subscribe to all unexpected failure of workers from GCS. std::atomic worker_failure_count(0); auto on_subscribe = [&worker_failure_count](const WorkerID &worker_id, - const rpc::WorkerFailureData &result) { + const rpc::WorkerTableData &result) { ++worker_failure_count; }; ASSERT_TRUE(SubscribeToWorkerFailures(on_subscribe)); @@ -1073,9 +1088,13 @@ TEST_F(ServiceBasedGcsClientTest, TestWorkerTableResubscribe) { // Restart GCS RestartGcsServer(); + // Add a worker before reporting a worker failure to GCS. + auto worker_data = Mocker::GenWorkerTableData(); + worker_data->mutable_worker_address()->set_worker_id(WorkerID::FromRandom().Binary()); + ASSERT_TRUE(AddWorker(worker_data)); + // Report a worker failure to GCS and check if resubscribe works. - auto worker_failure_data = Mocker::GenWorkerFailureData(); - ASSERT_TRUE(ReportWorkerFailure(worker_failure_data)); + ASSERT_TRUE(ReportWorkerFailure(worker_data)); WaitPendingDone(worker_failure_count, 1); } diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index eb4ca29a4..0617881a5 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -19,11 +19,11 @@ #include "gcs_job_manager.h" #include "gcs_node_manager.h" #include "gcs_object_manager.h" +#include "gcs_worker_manager.h" #include "ray/common/network_util.h" #include "ray/common/ray_config.h" #include "stats_handler_impl.h" #include "task_info_handler_impl.h" -#include "worker_info_handler_impl.h" namespace ray { namespace gcs { @@ -92,9 +92,9 @@ void GcsServer::Start() { new rpc::ErrorInfoGrpcService(main_service_, *error_info_handler_)); rpc_server_.RegisterService(*error_info_service_); - worker_info_handler_ = InitWorkerInfoHandler(); + gcs_worker_manager_ = InitGcsWorkerManager(); worker_info_service_.reset( 
- new rpc::WorkerInfoGrpcService(main_service_, *worker_info_handler_)); + new rpc::WorkerInfoGrpcService(main_service_, *gcs_worker_manager_)); rpc_server_.RegisterService(*worker_info_service_); auto load_completed_count = std::make_shared(0); @@ -191,7 +191,7 @@ void GcsServer::InitGcsActorManager() { }); auto on_subscribe = [this](const std::string &id, const std::string &data) { - rpc::WorkerFailureData worker_failure_data; + rpc::WorkerTableData worker_failure_data; worker_failure_data.ParseFromString(data); auto &worker_address = worker_failure_data.worker_address(); WorkerID worker_id = WorkerID::FromBinary(id); @@ -199,7 +199,7 @@ void GcsServer::InitGcsActorManager() { gcs_actor_manager_->OnWorkerDead(node_id, worker_id, worker_failure_data.intentional_disconnect()); }; - RAY_CHECK_OK(gcs_pub_sub_->SubscribeAll(WORKER_FAILURE_CHANNEL, on_subscribe, nullptr)); + RAY_CHECK_OK(gcs_pub_sub_->SubscribeAll(WORKER_CHANNEL, on_subscribe, nullptr)); } void GcsServer::InitGcsJobManager() { @@ -243,9 +243,9 @@ std::unique_ptr GcsServer::InitErrorInfoHandler() { new rpc::DefaultErrorInfoHandler(*redis_gcs_client_)); } -std::unique_ptr GcsServer::InitWorkerInfoHandler() { - return std::unique_ptr(new rpc::DefaultWorkerInfoHandler( - *redis_gcs_client_, gcs_table_storage_, gcs_pub_sub_)); +std::unique_ptr GcsServer::InitGcsWorkerManager() { + return std::unique_ptr( + new GcsWorkerManager(gcs_table_storage_, gcs_pub_sub_)); } } // namespace gcs diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index 00d2ba197..2eecc9276 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -39,6 +39,7 @@ struct GcsServerConfig { class GcsNodeManager; class GcsActorManager; class GcsJobManager; +class GcsWorkerManager; /// The GcsServer will take over all requests from ServiceBasedGcsClient and transparent /// transmit the command to the backend reliable storage for the time being. 
@@ -96,8 +97,8 @@ class GcsServer { /// The error info handler virtual std::unique_ptr InitErrorInfoHandler(); - /// The worker info handler - virtual std::unique_ptr InitWorkerInfoHandler(); + /// The worker manager + virtual std::unique_ptr InitGcsWorkerManager(); private: /// Store the address of GCS server in Redis. @@ -140,8 +141,9 @@ class GcsServer { /// Error info handler and service std::unique_ptr error_info_handler_; std::unique_ptr error_info_service_; - /// Worker info handler and service - std::unique_ptr worker_info_handler_; + /// The gcs worker manager + std::unique_ptr gcs_worker_manager_; + /// Worker info service std::unique_ptr worker_info_service_; /// Backend client std::shared_ptr redis_gcs_client_; diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.cc b/src/ray/gcs/gcs_server/gcs_table_storage.cc index 627605a6c..fd4d4e780 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.cc +++ b/src/ray/gcs/gcs_server/gcs_table_storage.cc @@ -112,7 +112,7 @@ template class GcsTable; template class GcsTable; template class GcsTable; template class GcsTable; -template class GcsTable; +template class GcsTable; template class GcsTable; template class GcsTable; template class GcsTable; diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.h b/src/ray/gcs/gcs_server/gcs_table_storage.h index 591a49d43..494b09727 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.h +++ b/src/ray/gcs/gcs_server/gcs_table_storage.h @@ -39,7 +39,7 @@ using rpc::ResourceTableData; using rpc::TaskLeaseData; using rpc::TaskReconstructionData; using rpc::TaskTableData; -using rpc::WorkerFailureData; +using rpc::WorkerTableData; /// \class GcsTable /// @@ -265,11 +265,11 @@ class GcsProfileTable : public GcsTable { } }; -class GcsWorkerFailureTable : public GcsTable { +class GcsWorkerTable : public GcsTable { public: - explicit GcsWorkerFailureTable(std::shared_ptr &store_client) + explicit GcsWorkerTable(std::shared_ptr &store_client) : GcsTable(store_client) { - table_name_ 
= TablePrefix_Name(TablePrefix::WORKER_FAILURE); + table_name_ = TablePrefix_Name(TablePrefix::WORKERS); } }; @@ -349,9 +349,9 @@ class GcsTableStorage { return *profile_table_; } - GcsWorkerFailureTable &WorkerFailureTable() { - RAY_CHECK(worker_failure_table_ != nullptr); - return *worker_failure_table_; + GcsWorkerTable &WorkerTable() { + RAY_CHECK(worker_table_ != nullptr); + return *worker_table_; } protected: @@ -370,7 +370,7 @@ class GcsTableStorage { std::unique_ptr heartbeat_batch_table_; std::unique_ptr error_info_table_; std::unique_ptr profile_table_; - std::unique_ptr worker_failure_table_; + std::unique_ptr worker_table_; }; /// \class RedisGcsTableStorage @@ -394,7 +394,7 @@ class RedisGcsTableStorage : public GcsTableStorage { heartbeat_batch_table_.reset(new GcsHeartbeatBatchTable(store_client_)); error_info_table_.reset(new GcsErrorInfoTable(store_client_)); profile_table_.reset(new GcsProfileTable(store_client_)); - worker_failure_table_.reset(new GcsWorkerFailureTable(store_client_)); + worker_table_.reset(new GcsWorkerTable(store_client_)); } }; @@ -419,7 +419,7 @@ class InMemoryGcsTableStorage : public GcsTableStorage { heartbeat_batch_table_.reset(new GcsHeartbeatBatchTable(store_client_)); error_info_table_.reset(new GcsErrorInfoTable(store_client_)); profile_table_.reset(new GcsProfileTable(store_client_)); - worker_failure_table_.reset(new GcsWorkerFailureTable(store_client_)); + worker_table_.reset(new GcsWorkerTable(store_client_)); } }; diff --git a/src/ray/gcs/gcs_server/gcs_worker_manager.cc b/src/ray/gcs/gcs_server/gcs_worker_manager.cc new file mode 100644 index 000000000..827976f00 --- /dev/null +++ b/src/ray/gcs/gcs_server/gcs_worker_manager.cc @@ -0,0 +1,133 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gcs_worker_manager.h" + +namespace ray { +namespace gcs { + +void GcsWorkerManager::HandleReportWorkerFailure( + const rpc::ReportWorkerFailureRequest &request, rpc::ReportWorkerFailureReply *reply, + rpc::SendReplyCallback send_reply_callback) { + const rpc::Address worker_address = request.worker_failure().worker_address(); + RAY_LOG(DEBUG) << "Reporting worker failure, " << worker_address.DebugString(); + auto worker_failure_data = std::make_shared(); + worker_failure_data->CopyFrom(request.worker_failure()); + worker_failure_data->set_is_alive(false); + const auto worker_id = WorkerID::FromBinary(worker_address.worker_id()); + + // Before handling a ReportWorkerFailureRequest, check whether the worker exists. + auto on_get_done = + [this, worker_address, worker_id, worker_failure_data, reply, send_reply_callback]( + const Status &status, const boost::optional &result) { + if (result) { + auto on_put_done = [this, worker_address, worker_id, worker_failure_data, reply, + send_reply_callback](const Status &status) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to report worker failure, " + << worker_address.DebugString(); + } else { + RAY_CHECK_OK(gcs_pub_sub_->Publish(WORKER_CHANNEL, worker_id.Binary(), + worker_failure_data->SerializeAsString(), + nullptr)); + } + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + }; + + // The worker exists in the worker table, so update the info of this worker.
+ Status report_status = gcs_table_storage_->WorkerTable().Put( + worker_id, *worker_failure_data, on_put_done); + if (!report_status.ok()) { + on_put_done(report_status); + } + } else { + // The worker doesn't exist in the worker table. + RAY_LOG(WARNING) << "Failed to report worker failure, the worker doesn't " + "exist, " + << worker_address.DebugString(); + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + } + }; + Status status = gcs_table_storage_->WorkerTable().Get(worker_id, on_get_done); + if (!status.ok()) { + on_get_done(status, boost::none); + } +} + +void GcsWorkerManager::HandleGetWorkerInfo(const rpc::GetWorkerInfoRequest &request, + rpc::GetWorkerInfoReply *reply, + rpc::SendReplyCallback send_reply_callback) { + WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); + RAY_LOG(DEBUG) << "Getting worker info, worker id = " << worker_id; + + auto on_done = [worker_id, reply, send_reply_callback]( + const Status &status, + const boost::optional &result) { + if (result) { + reply->mutable_worker_table_data()->CopyFrom(*result); + } + RAY_LOG(DEBUG) << "Finished getting worker info, worker id = " << worker_id; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); + }; + + Status status = gcs_table_storage_->WorkerTable().Get(worker_id, on_done); + if (!status.ok()) { + on_done(status, boost::none); + } +} + +void GcsWorkerManager::HandleGetAllWorkerInfo( + const rpc::GetAllWorkerInfoRequest &request, rpc::GetAllWorkerInfoReply *reply, + rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "Getting all worker info."; + auto on_done = [reply, send_reply_callback]( + const std::unordered_map &result) { + for (auto &data : result) { + reply->add_worker_table_data()->CopyFrom(data.second); + } + RAY_LOG(DEBUG) << "Finished getting all worker info."; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); + }; + Status status = gcs_table_storage_->WorkerTable().GetAll(on_done); + if (!status.ok()) { +
on_done(std::unordered_map()); + } +} + +void GcsWorkerManager::HandleAddWorkerInfo(const rpc::AddWorkerInfoRequest &request, + rpc::AddWorkerInfoReply *reply, + rpc::SendReplyCallback send_reply_callback) { + auto worker_data = std::make_shared(); + worker_data->CopyFrom(request.worker_data()); + auto worker_id = WorkerID::FromBinary(worker_data->worker_address().worker_id()); + RAY_LOG(DEBUG) << "Adding worker " << worker_id; + + auto on_done = [worker_id, worker_data, reply, + send_reply_callback](const Status &status) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to add worker information, " + << worker_data->DebugString(); + } + RAY_LOG(DEBUG) << "Finished adding worker " << worker_id; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + }; + + Status status = gcs_table_storage_->WorkerTable().Put(worker_id, *worker_data, on_done); + if (!status.ok()) { + on_done(status); + } +} + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_worker_manager.h b/src/ray/gcs/gcs_server/gcs_worker_manager.h new file mode 100644 index 000000000..5013dd0c2 --- /dev/null +++ b/src/ray/gcs/gcs_server/gcs_worker_manager.h @@ -0,0 +1,54 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "gcs_table_storage.h" +#include "ray/gcs/pubsub/gcs_pub_sub.h" +#include "ray/gcs/redis_gcs_client.h" +#include "ray/rpc/gcs_server/gcs_rpc_server.h" + +namespace ray { +namespace gcs { + +/// The implementation class of `WorkerInfoHandler`. +class GcsWorkerManager : public rpc::WorkerInfoHandler { + public: + explicit GcsWorkerManager(std::shared_ptr gcs_table_storage, + std::shared_ptr &gcs_pub_sub) + : gcs_table_storage_(gcs_table_storage), gcs_pub_sub_(gcs_pub_sub) {} + + void HandleReportWorkerFailure(const rpc::ReportWorkerFailureRequest &request, + rpc::ReportWorkerFailureReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleGetWorkerInfo(const rpc::GetWorkerInfoRequest &request, + rpc::GetWorkerInfoReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleGetAllWorkerInfo(const rpc::GetAllWorkerInfoRequest &request, + rpc::GetAllWorkerInfoReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleAddWorkerInfo(const rpc::AddWorkerInfoRequest &request, + rpc::AddWorkerInfoReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + private: + std::shared_ptr gcs_table_storage_; + std::shared_ptr gcs_pub_sub_; +}; + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc index 2faf95b96..898e36844 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc @@ -390,6 +390,53 @@ class GcsServerTest : public ::testing::Test { client_->ReportWorkerFailure( request, [&promise](const Status &status, const rpc::ReportWorkerFailureReply &reply) { + RAY_CHECK_OK(status); + promise.set_value(status.ok()); + }); + return WaitReady(promise.get_future(), timeout_ms_); + } + + boost::optional GetWorkerInfo(const std::string &worker_id) { + rpc::GetWorkerInfoRequest request; +
request.set_worker_id(worker_id); + boost::optional worker_table_data_opt; + std::promise promise; + client_->GetWorkerInfo( + request, [&worker_table_data_opt, &promise]( + const Status &status, const rpc::GetWorkerInfoReply &reply) { + RAY_CHECK_OK(status); + if (reply.has_worker_table_data()) { + worker_table_data_opt = reply.worker_table_data(); + } else { + worker_table_data_opt = boost::none; + } + promise.set_value(true); + }); + EXPECT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); + return worker_table_data_opt; + } + + std::vector GetAllWorkerInfo() { + std::vector worker_table_data; + rpc::GetAllWorkerInfoRequest request; + std::promise promise; + client_->GetAllWorkerInfo( + request, [&worker_table_data, &promise](const Status &status, + const rpc::GetAllWorkerInfoReply &reply) { + RAY_CHECK_OK(status); + for (int index = 0; index < reply.worker_table_data_size(); ++index) { + worker_table_data.push_back(reply.worker_table_data(index)); + } + promise.set_value(true); + }); + EXPECT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); + return worker_table_data; + } + + bool AddWorkerInfo(const rpc::AddWorkerInfoRequest &request) { + std::promise promise; + client_->AddWorkerInfo( + request, [&promise](const Status &status, const rpc::AddWorkerInfoReply &reply) { RAY_CHECK_OK(status); promise.set_value(true); }); @@ -725,12 +772,29 @@ TEST_F(GcsServerTest, TestErrorInfo) { } TEST_F(GcsServerTest, TestWorkerInfo) { - rpc::WorkerFailureData worker_failure_data; - worker_failure_data.mutable_worker_address()->set_ip_address("127.0.0.1"); - worker_failure_data.mutable_worker_address()->set_port(5566); + // Report worker failure + auto worker_failure_data = Mocker::GenWorkerTableData(); + worker_failure_data->mutable_worker_address()->set_ip_address("127.0.0.1"); + worker_failure_data->mutable_worker_address()->set_port(5566); rpc::ReportWorkerFailureRequest report_worker_failure_request; - 
report_worker_failure_request.mutable_worker_failure()->CopyFrom(worker_failure_data); + report_worker_failure_request.mutable_worker_failure()->CopyFrom(*worker_failure_data); ASSERT_TRUE(ReportWorkerFailure(report_worker_failure_request)); + std::vector worker_table_data = GetAllWorkerInfo(); + ASSERT_TRUE(worker_table_data.size() == 0); + + // Add worker info + auto worker_data = Mocker::GenWorkerTableData(); + worker_data->mutable_worker_address()->set_worker_id(WorkerID::FromRandom().Binary()); + rpc::AddWorkerInfoRequest add_worker_request; + add_worker_request.mutable_worker_data()->CopyFrom(*worker_data); + ASSERT_TRUE(AddWorkerInfo(add_worker_request)); + ASSERT_TRUE(GetAllWorkerInfo().size() == 1); + + // Get worker info + boost::optional result = + GetWorkerInfo(worker_data->worker_address().worker_id()); + ASSERT_TRUE(result->worker_address().worker_id() == + worker_data->worker_address().worker_id()); } } // namespace ray diff --git a/src/ray/gcs/gcs_server/worker_info_handler_impl.cc b/src/ray/gcs/gcs_server/worker_info_handler_impl.cc deleted file mode 100644 index 5bd7745fb..000000000 --- a/src/ray/gcs/gcs_server/worker_info_handler_impl.cc +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "worker_info_handler_impl.h" - -namespace ray { -namespace rpc { - -void DefaultWorkerInfoHandler::HandleReportWorkerFailure( - const ReportWorkerFailureRequest &request, ReportWorkerFailureReply *reply, - SendReplyCallback send_reply_callback) { - const Address worker_address = request.worker_failure().worker_address(); - RAY_LOG(DEBUG) << "Reporting worker failure, " << worker_address.DebugString(); - auto worker_failure_data = std::make_shared(); - worker_failure_data->CopyFrom(request.worker_failure()); - const auto worker_id = WorkerID::FromBinary(worker_address.worker_id()); - auto on_done = [this, worker_address, worker_id, worker_failure_data, reply, - send_reply_callback](const Status &status) { - if (!status.ok()) { - RAY_LOG(ERROR) << "Failed to report worker failure, " - << worker_address.DebugString(); - } else { - RAY_CHECK_OK(gcs_pub_sub_->Publish(WORKER_FAILURE_CHANNEL, worker_id.Binary(), - worker_failure_data->SerializeAsString(), - nullptr)); - } - GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); - }; - - Status status = gcs_table_storage_->WorkerFailureTable().Put( - worker_id, *worker_failure_data, on_done); - if (!status.ok()) { - on_done(status); - } -} - -void DefaultWorkerInfoHandler::HandleRegisterWorker( - const RegisterWorkerRequest &request, RegisterWorkerReply *reply, - SendReplyCallback send_reply_callback) { - auto worker_type = request.worker_type(); - auto worker_id = WorkerID::FromBinary(request.worker_id()); - auto worker_info = MapFromProtobuf(request.worker_info()); - - auto on_done = [worker_id, reply, send_reply_callback](const Status &status) { - if (!status.ok()) { - RAY_LOG(ERROR) << "Failed to register worker " << worker_id; - } else { - RAY_LOG(DEBUG) << "Finished registering worker " << worker_id; - } - GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); - }; - - Status status = gcs_client_.Workers().AsyncRegisterWorker(worker_type, worker_id, - worker_info, on_done); - if (!status.ok()) { - 
on_done(status); - } -} - -} // namespace rpc -} // namespace ray diff --git a/src/ray/gcs/gcs_server/worker_info_handler_impl.h b/src/ray/gcs/gcs_server/worker_info_handler_impl.h deleted file mode 100644 index 1b556b14b..000000000 --- a/src/ray/gcs/gcs_server/worker_info_handler_impl.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "gcs_table_storage.h" -#include "ray/gcs/pubsub/gcs_pub_sub.h" -#include "ray/gcs/redis_gcs_client.h" -#include "ray/rpc/gcs_server/gcs_rpc_server.h" - -namespace ray { -namespace rpc { - -/// This implementation class of `WorkerInfoHandler`. 
-class DefaultWorkerInfoHandler : public rpc::WorkerInfoHandler { - public: - explicit DefaultWorkerInfoHandler( - gcs::RedisGcsClient &gcs_client, - std::shared_ptr gcs_table_storage, - std::shared_ptr &gcs_pub_sub) - : gcs_client_(gcs_client), - gcs_table_storage_(gcs_table_storage), - gcs_pub_sub_(gcs_pub_sub) {} - - void HandleReportWorkerFailure(const ReportWorkerFailureRequest &request, - ReportWorkerFailureReply *reply, - SendReplyCallback send_reply_callback) override; - - void HandleRegisterWorker(const RegisterWorkerRequest &request, - RegisterWorkerReply *reply, - SendReplyCallback send_reply_callback) override; - - private: - gcs::RedisGcsClient &gcs_client_; - std::shared_ptr gcs_table_storage_; - std::shared_ptr gcs_pub_sub_; -}; - -} // namespace rpc -} // namespace ray diff --git a/src/ray/gcs/pb_util.h b/src/ray/gcs/pb_util.h index bc40c981e..564623fa5 100644 --- a/src/ray/gcs/pb_util.h +++ b/src/ray/gcs/pb_util.h @@ -82,11 +82,11 @@ inline std::shared_ptr CreateActorTableData( } /// Helper function to produce worker failure data. 
-inline std::shared_ptr CreateWorkerFailureData( +inline std::shared_ptr CreateWorkerFailureData( const ClientID &raylet_id, const WorkerID &worker_id, const std::string &address, int32_t port, int64_t timestamp = std::time(nullptr), bool intentional_disconnect = false) { - auto worker_failure_info_ptr = std::make_shared(); + auto worker_failure_info_ptr = std::make_shared(); worker_failure_info_ptr->mutable_worker_address()->set_raylet_id(raylet_id.Binary()); worker_failure_info_ptr->mutable_worker_address()->set_worker_id(worker_id.Binary()); worker_failure_info_ptr->mutable_worker_address()->set_ip_address(address); diff --git a/src/ray/gcs/pubsub/gcs_pub_sub.h b/src/ray/gcs/pubsub/gcs_pub_sub.h index 19cfaee6e..43e327281 100644 --- a/src/ray/gcs/pubsub/gcs_pub_sub.h +++ b/src/ray/gcs/pubsub/gcs_pub_sub.h @@ -29,7 +29,7 @@ namespace gcs { #define NODE_CHANNEL "NODE" #define NODE_RESOURCE_CHANNEL "NODE_RESOURCE" #define ACTOR_CHANNEL "ACTOR" -#define WORKER_FAILURE_CHANNEL "WORKER_FAILURE" +#define WORKER_CHANNEL "WORKER" #define OBJECT_CHANNEL "OBJECT" #define TASK_CHANNEL "TASK" #define TASK_LEASE_CHANNEL "TASK_LEASE" diff --git a/src/ray/gcs/redis_accessor.cc b/src/ray/gcs/redis_accessor.cc index 179b7d468..800f90e0f 100644 --- a/src/ray/gcs/redis_accessor.cc +++ b/src/ray/gcs/redis_accessor.cc @@ -792,50 +792,43 @@ Status RedisStatsInfoAccessor::AsyncAddProfileData( RedisWorkerInfoAccessor::RedisWorkerInfoAccessor(RedisGcsClient *client_impl) : client_impl_(client_impl), - worker_failure_sub_executor_(client_impl->worker_failure_table()) {} + worker_failure_sub_executor_(client_impl->worker_table()) {} Status RedisWorkerInfoAccessor::AsyncSubscribeToWorkerFailures( - const SubscribeCallback &subscribe, + const SubscribeCallback &subscribe, const StatusCallback &done) { RAY_CHECK(subscribe != nullptr); return worker_failure_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), subscribe, done); } Status RedisWorkerInfoAccessor::AsyncReportWorkerFailure( - const 
std::shared_ptr &data_ptr, const StatusCallback &callback) { - WorkerFailureTable::WriteCallback on_done = nullptr; + const std::shared_ptr &data_ptr, const StatusCallback &callback) { + WorkerTable::WriteCallback on_done = nullptr; if (callback != nullptr) { on_done = [callback](RedisGcsClient *client, const WorkerID &id, - const WorkerFailureData &data) { callback(Status::OK()); }; + const WorkerTableData &data) { callback(Status::OK()); }; } WorkerID worker_id = WorkerID::FromBinary(data_ptr->worker_address().worker_id()); - WorkerFailureTable &worker_failure_table = client_impl_->worker_failure_table(); + WorkerTable &worker_failure_table = client_impl_->worker_table(); return worker_failure_table.Add(JobID::Nil(), worker_id, data_ptr, on_done); } -Status RedisWorkerInfoAccessor::AsyncRegisterWorker( - rpc::WorkerType worker_type, const WorkerID &worker_id, - const std::unordered_map &worker_info, - const StatusCallback &callback) { - std::vector args; - args.emplace_back("HMSET"); - if (worker_type == rpc::WorkerType::DRIVER) { - args.emplace_back("Drivers:" + worker_id.Binary()); - } else { - args.emplace_back("Workers:" + worker_id.Binary()); - } - for (const auto &entry : worker_info) { - args.push_back(entry.first); - args.push_back(entry.second); - } +Status RedisWorkerInfoAccessor::AsyncGet( + const WorkerID &worker_id, + const OptionalItemCallback &callback) { + return Status::Invalid("Not implemented"); +} - auto status = client_impl_->primary_context()->RunArgvAsync(args); - if (callback) { - // TODO (kfstorm): Invoke the callback asynchronously. 
- callback(status); - } - return status; +Status RedisWorkerInfoAccessor::AsyncGetAll( + const MultiItemCallback &callback) { + return Status::Invalid("Not implemented"); +} + +Status RedisWorkerInfoAccessor::AsyncAdd( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + return Status::Invalid("Not implemented"); } } // namespace gcs diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index 02b061996..c5c220fcf 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -436,23 +436,26 @@ class RedisWorkerInfoAccessor : public WorkerInfoAccessor { virtual ~RedisWorkerInfoAccessor() = default; Status AsyncSubscribeToWorkerFailures( - const SubscribeCallback &subscribe, + const SubscribeCallback &subscribe, const StatusCallback &done) override; - Status AsyncReportWorkerFailure(const std::shared_ptr &data_ptr, + Status AsyncReportWorkerFailure(const std::shared_ptr &data_ptr, const StatusCallback &callback) override; - Status AsyncRegisterWorker( - rpc::WorkerType worker_type, const WorkerID &worker_id, - const std::unordered_map &worker_info, - const StatusCallback &callback) override; + Status AsyncGet(const WorkerID &worker_id, + const OptionalItemCallback &callback) override; + + Status AsyncGetAll(const MultiItemCallback &callback) override; + + Status AsyncAdd(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; void AsyncResubscribe(bool is_pubsub_server_restarted) override {} private: RedisGcsClient *client_impl_{nullptr}; - typedef SubscriptionExecutor + typedef SubscriptionExecutor WorkerFailureSubscriptionExecutor; WorkerFailureSubscriptionExecutor worker_failure_sub_executor_; }; diff --git a/src/ray/gcs/redis_gcs_client.cc b/src/ray/gcs/redis_gcs_client.cc index 46e2f7c79..07ca06745 100644 --- a/src/ray/gcs/redis_gcs_client.cc +++ b/src/ray/gcs/redis_gcs_client.cc @@ -68,13 +68,14 @@ Status RedisGcsClient::Connect(boost::asio::io_service &io_service) { 
actor_checkpoint_table_.reset(new ActorCheckpointTable(shard_contexts, this)); actor_checkpoint_id_table_.reset(new ActorCheckpointIdTable(shard_contexts, this)); resource_table_.reset(new DynamicResourceTable({primary_context}, this)); - worker_failure_table_.reset(new WorkerFailureTable(shard_contexts, this)); + worker_table_.reset(new WorkerTable(shard_contexts, this)); if (RayConfig::instance().gcs_actor_service_enabled()) { actor_accessor_.reset(new RedisActorInfoAccessor(this)); } else { actor_accessor_.reset(new RedisLogBasedActorInfoAccessor(this)); } + job_accessor_.reset(new RedisJobInfoAccessor(this)); object_accessor_.reset(new RedisObjectInfoAccessor(this)); node_accessor_.reset(new RedisNodeInfoAccessor(this)); @@ -123,9 +124,7 @@ LogBasedActorTable &RedisGcsClient::log_based_actor_table() { ActorTable &RedisGcsClient::actor_table() { return *actor_table_; } -WorkerFailureTable &RedisGcsClient::worker_failure_table() { - return *worker_failure_table_; -} +WorkerTable &RedisGcsClient::worker_table() { return *worker_table_; } TaskReconstructionLog &RedisGcsClient::task_reconstruction_log() { return *task_reconstruction_log_; diff --git a/src/ray/gcs/redis_gcs_client.h b/src/ray/gcs/redis_gcs_client.h index cccecc0aa..7b01c0ba1 100644 --- a/src/ray/gcs/redis_gcs_client.h +++ b/src/ray/gcs/redis_gcs_client.h @@ -107,7 +107,7 @@ class RAY_EXPORT RedisGcsClient : public GcsClient { /// Implements the Stats() interface. ProfileTable &profile_table(); /// Implements the Workers() interface. - WorkerFailureTable &worker_failure_table(); + WorkerTable &worker_table(); private: // GCS command type. 
If CommandType::kChain, chain-replicated versions of the tables @@ -130,7 +130,7 @@ class RAY_EXPORT RedisGcsClient : public GcsClient { std::unique_ptr actor_checkpoint_table_; std::unique_ptr actor_checkpoint_id_table_; std::unique_ptr resource_table_; - std::unique_ptr worker_failure_table_; + std::unique_ptr worker_table_; std::unique_ptr job_table_; }; diff --git a/src/ray/gcs/subscription_executor.cc b/src/ray/gcs/subscription_executor.cc index 85a294ede..2ba7f8094 100644 --- a/src/ray/gcs/subscription_executor.cc +++ b/src/ray/gcs/subscription_executor.cc @@ -210,7 +210,7 @@ template class SubscriptionExecutor; template class SubscriptionExecutor; -template class SubscriptionExecutor; +template class SubscriptionExecutor; } // namespace gcs diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 2a8079e18..8d35b1911 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -872,10 +872,10 @@ template class Log; template class Log; template class Log; template class Log; -template class Log; +template class Log; template class Table; template class Table; -template class Table; +template class Table; template class Table; template class Log; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 8433e508a..90cb45fbf 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -52,7 +52,7 @@ using rpc::TablePubsub; using rpc::TaskLeaseData; using rpc::TaskReconstructionData; using rpc::TaskTableData; -using rpc::WorkerFailureData; +using rpc::WorkerTableData; class RedisContext; @@ -754,15 +754,15 @@ class ActorTable : public Table { Status Get(const ActorID &actor_id, ActorTableData *actor_table_data); }; -class WorkerFailureTable : public Table { +class WorkerTable : public Table { public: - WorkerFailureTable(const std::vector> &contexts, - RedisGcsClient *client) + WorkerTable(const std::vector> &contexts, + RedisGcsClient *client) : Table(contexts, client) { pubsub_channel_ = TablePubsub::WORKER_FAILURE_PUBSUB; - prefix_ = 
TablePrefix::WORKER_FAILURE; + prefix_ = TablePrefix::WORKERS; } - virtual ~WorkerFailureTable() {} + virtual ~WorkerTable() {} }; class TaskReconstructionLog : public Log { diff --git a/src/ray/gcs/test/gcs_test_util.h b/src/ray/gcs/test/gcs_test_util.h index 1c5ce09bc..006815415 100644 --- a/src/ray/gcs/test/gcs_test_util.h +++ b/src/ray/gcs/test/gcs_test_util.h @@ -127,10 +127,10 @@ struct Mocker { return error_table_data; } - static std::shared_ptr GenWorkerFailureData() { - auto worker_failure_data = std::make_shared(); - worker_failure_data->set_timestamp(std::time(nullptr)); - return worker_failure_data; + static std::shared_ptr GenWorkerTableData() { + auto worker_table_data = std::make_shared(); + worker_table_data->set_timestamp(std::time(nullptr)); + return worker_table_data; } }; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 467bf895d..bb092e980 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -41,7 +41,8 @@ enum TablePrefix { ACTOR_CHECKPOINT_ID = 16; NODE_RESOURCE = 17; DIRECT_ACTOR = 18; - WORKER_FAILURE = 19; + // WORKER is already used in WorkerType, so use WORKERS here. + WORKERS = 19; TABLE_PREFIX_MAX = 20; } @@ -296,13 +297,19 @@ message ActorCheckpointIdData { repeated uint64 timestamps = 3; } -message WorkerFailureData { +message WorkerTableData { + // Is this worker alive. + bool is_alive = 1; // Address of the worker that failed. - Address worker_address = 1; - // The UNIX timestamp at which the worker failed. + Address worker_address = 2; + // The UNIX timestamp at which this worker's state was updated. int64 timestamp = 3; - // Is intentional disconnect + // Whether it's an intentional disconnect, only applies when `is_alive` is false. bool intentional_disconnect = 4; + // Type of this worker. + WorkerType worker_type = 5; + // This is for AddWorker.
+ map worker_info = 6; } message ResourceMap { diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 4ffe4ee65..ecbd3c661 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -411,23 +411,38 @@ service ErrorInfoGcsService { } message ReportWorkerFailureRequest { - WorkerFailureData worker_failure = 1; + WorkerTableData worker_failure = 1; } message ReportWorkerFailureReply { GcsStatus status = 1; } -message RegisterWorkerRequest { - /// The type of the worker. - WorkerType worker_type = 1; - /// The ID of the worker. - bytes worker_id = 2; - /// The information of the worker in a dictionary. - map worker_info = 3; +message GetWorkerInfoRequest { + // ID of this worker. + bytes worker_id = 1; } -message RegisterWorkerReply { +message GetWorkerInfoReply { + GcsStatus status = 1; + // Data of worker. + WorkerTableData worker_table_data = 2; +} + +message GetAllWorkerInfoRequest { +} + +message GetAllWorkerInfoReply { + GcsStatus status = 1; + // Data of worker + repeated WorkerTableData worker_table_data = 2; +} + +message AddWorkerInfoRequest { + WorkerTableData worker_data = 1; +} + +message AddWorkerInfoReply { GcsStatus status = 1; } @@ -435,8 +450,12 @@ message RegisterWorkerReply { service WorkerInfoGcsService { // Report a worker failure to GCS Service. rpc ReportWorkerFailure(ReportWorkerFailureRequest) returns (ReportWorkerFailureReply); - // Register a worker to GCS Service. - rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerReply); + // Get worker information from GCS Service by worker id. + rpc GetWorkerInfo(GetWorkerInfoRequest) returns (GetWorkerInfoReply); + // Get information of all workers from GCS Service. + rpc GetAllWorkerInfo(GetAllWorkerInfoRequest) returns (GetAllWorkerInfoReply); + // Add worker information to GCS Service. 
+ rpc AddWorkerInfo(AddWorkerInfoRequest) returns (AddWorkerInfoReply); } message CreateActorRequest { diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index cbbefc267..414fef441 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -254,7 +254,7 @@ ray::Status NodeManager::RegisterGcs() { // node failure. These workers can be identified by comparing the raylet_id // in their rpc::Address to the ID of a failed raylet. const auto &worker_failure_handler = - [this](const WorkerID &id, const gcs::WorkerFailureData &worker_failure_data) { + [this](const WorkerID &id, const gcs::WorkerTableData &worker_failure_data) { HandleUnexpectedWorkerFailure(worker_failure_data.worker_address()); }; RAY_CHECK_OK(gcs_client_->Workers().AsyncSubscribeToWorkerFailures( diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index b02e13a10..ffd4253ec 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -771,7 +771,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { void WaitForTaskArgsRequests(std::pair &work); // TODO(swang): Evict entries from these caches. - /// Cache for the WorkerFailureTable in the GCS. + /// Cache for the WorkerTable in the GCS. absl::flat_hash_set failed_workers_cache_; /// Cache for the ClientTable in the GCS. absl::flat_hash_set failed_nodes_cache_; diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 2507d63dc..c5f3798b8 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -227,8 +227,16 @@ class GcsRpcClient { VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, ReportWorkerFailure, worker_info_grpc_client_, ) - /// Register a worker to GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, RegisterWorker, + /// Get worker information from GCS Service. 
+ VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, GetWorkerInfo, + worker_info_grpc_client_, ) + + /// Get information of all workers from GCS Service. + VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, GetAllWorkerInfo, + worker_info_grpc_client_, ) + + /// Add worker information to GCS Service. + VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, AddWorkerInfo, worker_info_grpc_client_, ) private: diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index 1807f1981..bb0a3eaa6 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -427,9 +427,17 @@ class WorkerInfoGcsServiceHandler { ReportWorkerFailureReply *reply, SendReplyCallback send_reply_callback) = 0; - virtual void HandleRegisterWorker(const RegisterWorkerRequest &request, - RegisterWorkerReply *reply, - SendReplyCallback send_reply_callback) = 0; + virtual void HandleGetWorkerInfo(const GetWorkerInfoRequest &request, + GetWorkerInfoReply *reply, + SendReplyCallback send_reply_callback) = 0; + + virtual void HandleGetAllWorkerInfo(const GetAllWorkerInfoRequest &request, + GetAllWorkerInfoReply *reply, + SendReplyCallback send_reply_callback) = 0; + + virtual void HandleAddWorkerInfo(const AddWorkerInfoRequest &request, + AddWorkerInfoReply *reply, + SendReplyCallback send_reply_callback) = 0; }; /// The `GrpcService` for `WorkerInfoGcsService`. @@ -449,7 +457,9 @@ class WorkerInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories) override { WORKER_INFO_SERVICE_RPC_HANDLER(ReportWorkerFailure); - WORKER_INFO_SERVICE_RPC_HANDLER(RegisterWorker); + WORKER_INFO_SERVICE_RPC_HANDLER(GetWorkerInfo); + WORKER_INFO_SERVICE_RPC_HANDLER(GetAllWorkerInfo); + WORKER_INFO_SERVICE_RPC_HANDLER(AddWorkerInfo); } private: