diff --git a/src/ray/protobuf/raylet.proto b/src/ray/protobuf/raylet.proto index 0cff32d1f..ee71ebcec 100644 --- a/src/ray/protobuf/raylet.proto +++ b/src/ray/protobuf/raylet.proto @@ -5,12 +5,15 @@ package ray.rpc; import "src/ray/protobuf/common.proto"; import "src/ray/protobuf/gcs.proto"; +/// NOTE(Joey Jiang) Every request defined in this file should have a `worker_id` field, +/// which will be used in `NodeManager::PreprocessRequest`. + /// Service request and reply messages. message RegisterClientRequest { - // Indicates the client is a worker or a driver. - bool is_worker = 1; // The worker id. - bytes worker_id = 2; + bytes worker_id = 1; + // Indicates the client is a worker or a driver. + bool is_worker = 2; // The process ID of this worker. uint32 worker_pid = 3; // The job ID. @@ -27,7 +30,8 @@ message RegisterClientReply { } message SubmitTaskRequest { - TaskSpec task_spec = 1; + bytes worker_id = 1; + TaskSpec task_spec = 2; } message SubmitTaskReply { } @@ -55,14 +59,14 @@ message TaskDoneReply { } message FetchOrReconstructRequest { - // List of object IDs of the objects that we want to reconstruct or fetch. - repeated bytes object_ids = 1; - // Indicates that we only want to fetch objects, not reconstruct them. - bool fetch_only = 2; - // The current task ID. If fetch_only is false, then this task is blocked. - bytes task_id = 3; // The worker ID. - bytes worker_id = 4; + bytes worker_id = 1; + // List of object IDs of the objects that we want to reconstruct or fetch. + repeated bytes object_ids = 2; + // Indicates that we only want to fetch objects, not reconstruct them. + bool fetch_only = 3; + // The current task ID. If fetch_only is false, then this task is blocked. + bytes task_id = 4; } message FetchOrReconstructReply { } @@ -76,19 +80,19 @@ message NotifyUnblockedReply { } message WaitRequest { + // The worker ID. + bytes worker_id = 1; // List of object ids we'll be waiting on. - repeated bytes object_ids = 1; + repeated bytes object_ids = 2; // Number of objects expected to be returned, if available. - uint64 num_ready_objects = 2; + uint64 num_ready_objects = 3; // Timeout in milliseconds. - int64 timeout = 3; + int64 timeout = 4; // Whether to wait until objects appear locally. - bool wait_local = 4; + bool wait_local = 5; // The current task ID. If there are less than num_ready_objects local, then // this task is blocked. - bytes task_id = 5; - // The worker ID. - bytes worker_id = 6; + bytes task_id = 6; } message WaitReply { // List of object ids found. @@ -98,60 +102,68 @@ message WaitReply { } message PushErrorRequest { + // The worker ID. + bytes worker_id = 1; // The job id that the error is for. - bytes job_id = 1; + bytes job_id = 2; // The type of the error. - bytes type = 2; + bytes type = 3; // The error message. - bytes error_message = 3; + bytes error_message = 4; // The timestamp of the error message. - double timestamp = 4; + double timestamp = 5; } message PushErrorReply { } message PushProfileEventsRequest { - ProfileTableData profile_table_data = 1; + bytes worker_id = 1; + ProfileTableData profile_table_data = 2; } message PushProfileEventsReply { } message FreeObjectsInStoreRequest { + // The worker ID. + bytes worker_id = 1; // Whether keep this request within the local object store // or send it to all of the object stores. - bool local_only = 1; + bool local_only = 2; // Whether also delete objects' creating tasks from GCS. - bool delete_creating_tasks = 2; + bool delete_creating_tasks = 3; // List of object ids to delete from the object store. - repeated bytes object_ids = 3; + repeated bytes object_ids = 4; } message FreeObjectsInStoreReply { } message PrepareActorCheckpointRequest { - bytes actor_id = 1; - bytes worker_id = 2; + bytes worker_id = 1; + bytes actor_id = 2; } message PrepareActorCheckpointReply { - bytes checkpoint_id = 1; + bytes worker_id = 1; + bytes checkpoint_id = 2; } message NotifyActorResumedFromCheckpointRequest { + bytes worker_id = 1; // ID of the actor that resumed. - bytes actor_id = 1; + bytes actor_id = 2; // ID of the checkpoint from which the actor was resumed. - bytes checkpoint_id = 2; + bytes checkpoint_id = 3; } message NotifyActorResumedFromCheckpointReply { } message SetResourceRequest { + bytes worker_id = 1; // Name of the resource to be set. - bytes resource_name = 1; + bytes resource_name = 2; // Capacity of the resource to be set. - double capacity = 2; + double capacity = 3; // Client ID where this resource will be set. - bytes client_id = 3; + bytes client_id = 4; } message SetResourceReply { } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 20efb4985..d7615621b 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1,6 +1,7 @@ #include "ray/raylet/node_manager.h" #include +#include #include "ray/common/status.h" @@ -13,6 +14,20 @@ namespace { #define RAY_CHECK_ENUM(x, y) \ static_assert(static_cast(x) == static_cast(y), "protocol mismatch") +/// Macro to handle early return for preprocessing. +/// An early return will take place if the worker is being killed due to the exiting of +/// driver, or the worker is not registered yet. +#define PREPROCESS_WORKER_REQUEST(REQUEST_TYPE, REQUEST, SEND_REPLY) \ + do { \ + WorkerID worker_id = WorkerID::FromBinary(REQUEST.worker_id()); \ + if (!PreprocessRequest(worker_id, #REQUEST_TYPE)) { \ + SEND_REPLY( \ + Status::Invalid("Discard this request due to failure of preprocessing."), \ + nullptr, nullptr); \ + return; \ + } \ + } while (0) + /// A helper function to return the expected actor counter for a given actor /// and actor handle, according to the given actor registry. If a task's /// counter is less than the returned value, then the task is a duplicate. If @@ -75,7 +90,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, gcs_client_(std::move(gcs_client)), object_directory_(std::move(object_directory)), heartbeat_timer_(io_service), - heartbeat_period_(std::chrono::milliseconds(config.heartbeat_period_ms)), + heartbeat_period_(config.heartbeat_period_ms), debug_dump_period_(config.debug_dump_period_ms), temp_dir_(config.temp_dir), object_manager_profile_timer_(io_service), @@ -352,6 +367,9 @@ void NodeManager::Heartbeat() { // Reset the timer. heartbeat_timer_.expires_from_now(heartbeat_period_); heartbeat_timer_.async_wait([this](const boost::system::error_code &error) { + if (error == boost::asio::error::operation_aborted) { + return; + } RAY_CHECK(!error); Heartbeat(); }); @@ -735,20 +753,51 @@ void NodeManager::DispatchTasks( local_queues_.MoveTasks(assigned_task_ids, TaskState::READY, TaskState::RUNNING); } +bool NodeManager::PreprocessRequest(const WorkerID &worker_id, + const std::string &request_name) { + std::ostringstream oss; + if (RAY_LOG_ENABLED(DEBUG)) { + oss << "Received a " << request_name << " request. Worker id " << worker_id << "."; + } + + auto worker = worker_pool_.GetWorker(worker_id); + // Worker process has been killed, we should discard this request. + if (!worker) { + RAY_LOG(WARNING) << "Worker " << worker_id << " is not found in worker pool, request " + << request_name << " will be discarded."; + return false; + } + if (RAY_LOG_ENABLED(DEBUG)) { + oss << " Is worker: " << (worker->IsWorker() ? "true" : "false") << ". Worker pid " + << std::to_string(worker->Pid()) << "."; + RAY_LOG(DEBUG) << oss.str(); + } + + // The worker process is being killing, we should discard this request. + if (worker->IsBeingKilled()) { + RAY_LOG(INFO) << "Worker " << worker_id << " is being killed, request " + << request_name << " will be discarded."; + return false; + } + + return true; +} + void NodeManager::HandleRegisterClientRequest( const rpc::RegisterClientRequest &request, rpc::RegisterClientReply *reply, rpc::SendReplyCallback send_reply_callback) { // Client id in register client is treated as worker id. const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); + bool is_worker = request.is_worker(); auto worker = std::make_shared(worker_id, request.worker_pid(), request.language(), - request.port(), client_call_manager_); + request.port(), client_call_manager_, is_worker); - RAY_LOG(DEBUG) << "Received a RegisterClientRequest, worker id: " << worker_id - << ", is worker: " << request.is_worker() - << ", pid: " << request.worker_pid(); + RAY_LOG(DEBUG) << "Received a RegisterClientRequest. Worker id: " << worker_id + << ". Is worker: " << is_worker << ". Worker pid " + << request.worker_pid(); - if (request.is_worker()) { + if (is_worker) { // Register the new worker. bool use_push_task = worker->UsePush(); worker_pool_.RegisterWorker(worker_id, std::move(worker)); @@ -818,16 +867,11 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca void NodeManager::HandleGetTaskRequest(const rpc::GetTaskRequest &request, rpc::GetTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { - const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); + PREPROCESS_WORKER_REQUEST(GetTaskRequest, request, send_reply_callback); + WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); - RAY_LOG(DEBUG) << "Received a GetTaskRequest, worker id " << worker_id << " pid " - << worker->Pid(); - if (!worker || worker->IsBeingKilled()) { - send_reply_callback(Status::Invalid("WorkerBeingKilled"), nullptr, nullptr); - return; - } - RAY_CHECK(!worker->UsePush()); + RAY_CHECK(!worker->UsePush()); // Reply would be sent when assigned a task to the worker successfully. worker->SetGetTaskReplyAndCallback(reply, std::move(send_reply_callback)); HandleWorkerAvailable(worker_id); @@ -836,9 +880,8 @@ void NodeManager::HandleGetTaskRequest(const rpc::GetTaskRequest &request, void NodeManager::HandleTaskDoneRequest(const rpc::TaskDoneRequest &request, rpc::TaskDoneReply *reply, rpc::SendReplyCallback send_reply_callback) { - const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); - RAY_LOG(DEBUG) << "Received a TaskDoneRequest from worker " << worker_id; - + PREPROCESS_WORKER_REQUEST(TaskDoneRequest, request, send_reply_callback); + WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); auto worker = worker_pool_.GetRegisteredWorker(worker_id); RAY_CHECK(worker && worker->UsePush()); HandleWorkerAvailable(worker_id); @@ -849,50 +892,36 @@ void NodeManager::HandleDisconnectClientRequest( const rpc::DisconnectClientRequest &request, rpc::DisconnectClientReply *reply, rpc::SendReplyCallback send_reply_callback) { const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); - RAY_LOG(DEBUG) << "Received a DisconnectClientRequest from worker " << worker_id; - ProcessDisconnectClientMessage(worker_id, true); send_reply_callback(Status::OK(), nullptr, nullptr); } void NodeManager::ProcessDisconnectClientMessage(const WorkerID &worker_id, bool intentional_disconnect) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); - bool is_worker = false, is_driver = false; - if (worker) { - // The client is a worker. - is_worker = true; - } else { - worker = worker_pool_.GetRegisteredDriver(worker_id); - if (worker) { - // The client is a driver. - is_driver = true; - } else { - RAY_LOG(INFO) << "Ignoring client disconnect because the client has already " - << "been disconnected."; - return; - } + auto worker = worker_pool_.GetWorker(worker_id); + if (!worker) { + RAY_LOG(INFO) << "Ignoring client disconnect because the client has already " + << "been disconnected."; + return; } - RAY_CHECK(!(is_worker && is_driver)); + bool is_worker = worker->IsWorker(); // If the client has any blocked tasks, mark them as unblocked. In // particular, we are no longer waiting for their dependencies. - if (worker) { - if (is_worker && worker->IsBeingKilled()) { - // Don't need to unblock the client if it's a worker and have sent kill signal to - // it. Because in this case, its task is already cleaned up. - RAY_LOG(DEBUG) << "Skip unblocking worker because it's already dead."; - } else { - // Clean up any open ray.get calls that the worker made. - while (!worker->GetBlockedTaskIds().empty()) { - // NOTE(swang): HandleTaskUnblocked will modify the worker, so it is - // not safe to pass in the iterator directly. - const TaskID task_id = *worker->GetBlockedTaskIds().begin(); - HandleTaskUnblocked(worker_id, task_id); - } - // Clean up any open ray.wait calls that the worker made. - task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); + if (is_worker && worker->IsBeingKilled()) { + // Don't need to unblock the client if it's a worker and have sent kill signal to + // it. Because in this case, its task is already cleaned up. + RAY_LOG(DEBUG) << "Skip unblocking worker because it's already dead."; + } else { + // Clean up any open ray.get calls that the worker made. + while (!worker->GetBlockedTaskIds().empty()) { + // NOTE(swang): HandleTaskUnblocked will modify the worker, so it is + // not safe to pass in the iterator directly. + const TaskID task_id = *worker->GetBlockedTaskIds().begin(); + HandleTaskUnblocked(worker_id, task_id); } + // Clean up any open ray.wait calls that the worker made. + task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); } if (is_worker) { @@ -952,7 +981,7 @@ void NodeManager::ProcessDisconnectClientMessage(const WorkerID &worker_id, // Since some resources may have been released, we can try to dispatch more tasks. DispatchTasks(local_queues_.GetReadyTasksWithResources()); - } else if (is_driver) { + } else { // The client is a driver. const auto job_id = worker->GetAssignedJobId(); const auto driver_id = ComputeDriverIdFromJob(job_id); @@ -975,8 +1004,7 @@ void NodeManager::ProcessDisconnectClientMessage(const WorkerID &worker_id, void NodeManager::HandleSubmitTaskRequest(const rpc::SubmitTaskRequest &request, rpc::SubmitTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { - RAY_LOG(DEBUG) << "Received a SubmitTaskRequest."; - + PREPROCESS_WORKER_REQUEST(SubmitTaskRequest, request, send_reply_callback); rpc::Task task; task.mutable_task_spec()->CopyFrom(request.task_spec()); @@ -989,7 +1017,8 @@ void NodeManager::HandleSubmitTaskRequest(const rpc::SubmitTaskRequest &request, void NodeManager::HandleFetchOrReconstructRequest( const rpc::FetchOrReconstructRequest &request, rpc::FetchOrReconstructReply *reply, rpc::SendReplyCallback send_reply_callback) { - RAY_LOG(DEBUG) << "Received a FetchOrReconstructRequest."; + PREPROCESS_WORKER_REQUEST(FetchOrReconstructRequest, request, send_reply_callback); + WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); const auto &object_ids = request.object_ids(); std::vector required_object_ids; for (size_t i = 0; i < object_ids.size(); ++i) { @@ -1012,7 +1041,6 @@ void NodeManager::HandleFetchOrReconstructRequest( if (!required_object_ids.empty()) { const TaskID task_id = TaskID::FromBinary(request.task_id()); - const WorkerID &worker_id = WorkerID::FromBinary(request.worker_id()); HandleTaskBlocked(worker_id, required_object_ids, task_id, /*ray_get=*/true); } send_reply_callback(Status::OK(), nullptr, nullptr); @@ -1021,7 +1049,8 @@ void NodeManager::HandleFetchOrReconstructRequest( void NodeManager::HandleWaitRequest(const rpc::WaitRequest &request, rpc::WaitReply *reply, rpc::SendReplyCallback send_reply_callback) { - RAY_LOG(DEBUG) << "Received a WaitRequest."; + PREPROCESS_WORKER_REQUEST(WaitRequest, request, send_reply_callback); + WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); // Read the data. std::vector object_ids = IdVectorFromProtobuf(request.object_ids()); int64_t wait_ms = request.timeout(); @@ -1039,7 +1068,6 @@ void NodeManager::HandleWaitRequest(const rpc::WaitRequest &request, } const TaskID ¤t_task_id = TaskID::FromBinary(request.task_id()); - const WorkerID &worker_id = WorkerID::FromBinary(request.worker_id()); bool client_blocked = !required_object_ids.empty(); if (client_blocked) { HandleTaskBlocked(worker_id, required_object_ids, current_task_id, /*ray_get=*/false); @@ -1066,6 +1094,7 @@ void NodeManager::HandleWaitRequest(const rpc::WaitRequest &request, void NodeManager::HandlePushErrorRequest(const rpc::PushErrorRequest &request, rpc::PushErrorReply *reply, rpc::SendReplyCallback send_reply_callback) { + PREPROCESS_WORKER_REQUEST(PushErrorRequest, request, send_reply_callback); JobID job_id = JobID::FromBinary(request.job_id()); const auto &type = request.type(); const auto &error_message = request.error_message(); @@ -1081,12 +1110,13 @@ void NodeManager::HandlePushErrorRequest(const rpc::PushErrorRequest &request, void NodeManager::HandlePrepareActorCheckpointRequest( const rpc::PrepareActorCheckpointRequest &request, rpc::PrepareActorCheckpointReply *reply, rpc::SendReplyCallback send_reply_callback) { + PREPROCESS_WORKER_REQUEST(PrepareActorCheckpointRequest, request, send_reply_callback); + WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); ActorID actor_id = ActorID::FromBinary(request.actor_id()); RAY_LOG(DEBUG) << "Preparing checkpoint for actor " << actor_id; const auto &actor_entry = actor_registry_.find(actor_id); RAY_CHECK(actor_entry != actor_registry_.end()); - WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); RAY_CHECK(worker && worker->GetActorId() == actor_id); @@ -1122,6 +1152,8 @@ void NodeManager::HandleNotifyActorResumedFromCheckpointRequest( const rpc::NotifyActorResumedFromCheckpointRequest &request, rpc::NotifyActorResumedFromCheckpointReply *reply, rpc::SendReplyCallback send_reply_callback) { + PREPROCESS_WORKER_REQUEST(NotifyActorResumedFromCheckpointRequest, request, + send_reply_callback); ActorID actor_id = ActorID::FromBinary(request.actor_id()); ActorCheckpointID checkpoint_id = ActorCheckpointID::FromBinary(request.checkpoint_id()); @@ -1152,6 +1184,7 @@ void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, void NodeManager::HandleSetResourceRequest(const rpc::SetResourceRequest &request, rpc::SetResourceReply *reply, rpc::SendReplyCallback send_reply_callback) { + PREPROCESS_WORKER_REQUEST(SetResourceRequest, request, send_reply_callback); auto const &resource_name = request.resource_name(); double const capacity = request.capacity(); bool is_deletion = capacity <= 0; @@ -1192,9 +1225,9 @@ void NodeManager::HandleSetResourceRequest(const rpc::SetResourceRequest &reques void NodeManager::HandleNotifyUnblockedRequest( const rpc::NotifyUnblockedRequest &request, rpc::NotifyUnblockedReply *reply, rpc::SendReplyCallback send_reply_callback) { - RAY_LOG(DEBUG) << "Received a NotifyUnblockedRequest."; + PREPROCESS_WORKER_REQUEST(NotifyUnblockedRequest, request, send_reply_callback); + WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); const TaskID current_task_id = TaskID::FromBinary(request.task_id()); - const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); HandleTaskUnblocked(worker_id, current_task_id); send_reply_callback(Status::OK(), nullptr, nullptr); @@ -1203,7 +1236,7 @@ void NodeManager::HandleNotifyUnblockedRequest( void NodeManager::HandlePushProfileEventsRequest( const rpc::PushProfileEventsRequest &request, rpc::PushProfileEventsReply *reply, rpc::SendReplyCallback send_reply_callback) { - RAY_LOG(DEBUG) << "Received a PushProfileEventsRequest."; + PREPROCESS_WORKER_REQUEST(PushProfileEventsRequest, request, send_reply_callback); const auto &profile_table_data = request.profile_table_data(); RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); send_reply_callback(Status::OK(), nullptr, nullptr); @@ -1212,7 +1245,7 @@ void NodeManager::HandlePushProfileEventsRequest( void NodeManager::HandleFreeObjectsInStoreRequest( const rpc::FreeObjectsInStoreRequest &request, rpc::FreeObjectsInStoreReply *reply, rpc::SendReplyCallback send_reply_callback) { - RAY_LOG(DEBUG) << "Received a FreeObjectsInStoreRequest."; + PREPROCESS_WORKER_REQUEST(FreeObjectsInStoreRequest, request, send_reply_callback); std::vector object_ids = IdVectorFromProtobuf(request.object_ids()); object_manager_.FreeObjects(object_ids, request.local_only()); if (request.delete_creating_tasks()) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 224a701ba..a6093eadd 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -95,6 +95,14 @@ class NodeManager : public rpc::NodeManagerServiceHandler, /// Get the port of the node manager rpc server. int GetServerPort() const { return node_manager_server_.GetPort(); } + /// Preprocess request from raylet client. We will check whether the worker is being + /// killed due to job finishing. + /// + /// \param worker_id The worker id. + /// \param request_name The request name. + /// \return False if there is no need to process this request. + bool PreprocessRequest(const WorkerID &worker_id, const std::string &request_name); + /// Implementation of node manager grpc service. /// Handle a `ForwardTask` request. diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index cf92b21dd..d1d215db6 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -10,7 +10,7 @@ namespace raylet { /// A constructor responsible for initializing the state of a worker. Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port, - rpc::ClientCallManager &client_call_manager) + rpc::ClientCallManager &client_call_manager, bool is_worker) : worker_id_(worker_id), pid_(pid), port_(port), @@ -18,7 +18,8 @@ Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, i blocked_(false), num_missed_heartbeats_(0), is_being_killed_(false), - client_call_manager_(client_call_manager) { + client_call_manager_(client_call_manager), + is_worker_(is_worker) { if (port_ > 0) { rpc_client_ = std::unique_ptr( new rpc::WorkerTaskClient("127.0.0.1", port_, client_call_manager_)); @@ -29,6 +30,8 @@ void Worker::MarkAsBeingKilled() { is_being_killed_ = true; } bool Worker::IsBeingKilled() const { return is_being_killed_; } +bool Worker::IsWorker() const { return is_worker_; } + void Worker::MarkBlocked() { blocked_ = true; } void Worker::MarkUnblocked() { blocked_ = false; } diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index d46209c9f..3d77da15b 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -24,11 +24,12 @@ class Worker { public: /// A constructor that initializes a worker object. Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port, - rpc::ClientCallManager &client_call_manager); + rpc::ClientCallManager &client_call_manager, bool is_worker = true); /// A destructor responsible for freeing all worker state. ~Worker() {} void MarkAsBeingKilled(); bool IsBeingKilled() const; + bool IsWorker() const; void MarkBlocked(); void MarkUnblocked(); bool IsBlocked() const; @@ -105,6 +106,8 @@ class Worker { /// The `ClientCallManager` object that is shared by `WorkerTaskClient` from all /// workers. rpc::ClientCallManager &client_call_manager_; + /// Indicates whether this is a worker or a driver. + bool is_worker_; /// The rpc client to send tasks to this worker. std::unique_ptr rpc_client_; /// Reply of the `GetTask` request. diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 6b5422a3e..3d6e4cec8 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -416,6 +416,17 @@ void WorkerPool::TickHeartbeatTimer(int max_missed_heartbeats, } } +std::shared_ptr WorkerPool::GetWorker(const WorkerID &worker_id) { + auto worker = GetRegisteredWorker(worker_id); + if (!worker) { + worker = GetRegisteredDriver(worker_id); + if (!worker) { + return nullptr; + } + } + return worker; +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index fc42f0545..0d5d094a9 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -138,6 +138,11 @@ class WorkerPool { void TickHeartbeatTimer(int max_missed_heartbeats, std::vector> *dead_workers); + /// Return the pointer to the worker according to the worker id. + /// + /// \param worker_id The worker id. + std::shared_ptr GetWorker(const WorkerID &worker_id); + protected: /// Asynchronously start a new worker process. Once the worker process has /// registered with an external server, the process should create and diff --git a/src/ray/rpc/raylet/raylet_client.cc b/src/ray/rpc/raylet/raylet_client.cc index 83272fae1..7f43bf9fa 100644 --- a/src/ray/rpc/raylet/raylet_client.cc +++ b/src/ray/rpc/raylet/raylet_client.cc @@ -48,7 +48,8 @@ void RayletClient::TryRegisterClient(int retry_times) { } std::this_thread::sleep_for(std::chrono::milliseconds(500)); } - RAY_LOG(FATAL) << "Failed to register to raylet server, worker id: " << worker_id_ + RAY_LOG(FATAL) << "Worker " << worker_id_ + << " failed to register to raylet server, worker id: " << worker_id_ << ", pid: " << static_cast(getpid()) << ", is worker: " << is_worker_; } @@ -67,7 +68,9 @@ ray::Status RayletClient::Disconnect() { grpc::ClientContext context; auto status = stub_->DisconnectClient(&context, disconnect_client_request, &reply); if (!status.ok()) { - RAY_LOG(ERROR) << "Failed to disconnect from raylet, msg: " << status.error_message(); + RAY_LOG(ERROR) << "Worker " << worker_id_ + << " failed to disconnect from raylet, msg: " + << status.error_message(); } return GrpcStatusToRayStatus(status); } @@ -75,12 +78,14 @@ ray::Status RayletClient::Disconnect() { ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) { RETURN_IF_DISCONNECTED(is_connected_); SubmitTaskRequest submit_task_request; + submit_task_request.set_worker_id(worker_id_.Binary()); submit_task_request.mutable_task_spec()->CopyFrom(task_spec.GetMessage()); auto callback = [this](const Status &status, const SubmitTaskReply &reply) { if (!status.ok() && is_connected_) { is_connected_ = false; - RAY_LOG(INFO) << "Failed to send SubmitTaskRequest, msg: " << status.message(); + RAY_LOG(INFO) << "Worker " << worker_id_ + << " failed to send SubmitTaskRequest, msg: " << status.message(); } }; @@ -128,7 +133,8 @@ ray::Status RayletClient::GetTask(std::unique_ptr *task_ task_spec->reset(new ray::TaskSpecification(reply.task_spec())); } else { *task_spec = nullptr; - RAY_LOG(INFO) << "Failed to get task, msg: " << status.error_message(); + RAY_LOG(INFO) << "Worker " << worker_id_ + << " failed to get task, msg: " << status.error_message(); } return GrpcStatusToRayStatus(status); } @@ -141,7 +147,8 @@ ray::Status RayletClient::TaskDone() { auto callback = [this](const Status &status, const TaskDoneReply &reply) { if (!status.ok() && is_connected_) { is_connected_ = false; - RAY_LOG(INFO) << "Failed to send TaskDoneRequest, msg: " << status.message(); + RAY_LOG(INFO) << "Worker " << worker_id_ + << " failed to send TaskDoneRequest, msg: " << status.message(); } }; @@ -168,7 +175,8 @@ ray::Status RayletClient::FetchOrReconstruct(const std::vector &object auto callback = [this](const Status &status, const FetchOrReconstructReply &reply) { if (!status.ok() && is_connected_) { is_connected_ = false; - RAY_LOG(INFO) << "Failed to send FetchOrReconstructRequest, msg: " + RAY_LOG(INFO) << "Worker " << worker_id_ + << " failed to send FetchOrReconstructRequest, msg: " << status.message(); } }; @@ -190,7 +198,9 @@ ray::Status RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) { auto callback = [this](const Status &status, const NotifyUnblockedReply &reply) { if (!status.ok() && is_connected_) { is_connected_ = false; - RAY_LOG(INFO) << "Failed to send NotifyUnblockedRequest, msg: " << status.message(); + RAY_LOG(INFO) << "Worker " << worker_id_ + << " failed to send NotifyUnblockedRequest, msg: " + << status.message(); } }; @@ -223,7 +233,8 @@ ray::Status RayletClient::Wait(const std::vector &object_ids, int num_ result->first = IdVectorFromProtobuf(reply.found()); result->second = IdVectorFromProtobuf(reply.remaining()); } else { - RAY_LOG(INFO) << "Failed to send WaitRequest, msg: " << status.error_message(); + RAY_LOG(INFO) << "Worker " << worker_id_ + << " failed to send WaitRequest, msg: " << status.error_message(); } return GrpcStatusToRayStatus(status); @@ -237,11 +248,13 @@ ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string push_error_request.set_type(type); push_error_request.set_error_message(error_message); push_error_request.set_timestamp(timestamp); + push_error_request.set_worker_id(worker_id_.Binary()); auto callback = [this](const Status &status, const PushErrorReply &reply) { if (!status.ok() && is_connected_) { is_connected_ = false; - RAY_LOG(INFO) << "Failed to send PushErrorRequest, msg: " << status.message(); + RAY_LOG(INFO) << "Worker " << worker_id_ + << " failed to send PushErrorRequest, msg: " << status.message(); } }; @@ -256,11 +269,13 @@ ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_even RETURN_IF_DISCONNECTED(is_connected_); PushProfileEventsRequest push_profile_events_request; push_profile_events_request.mutable_profile_table_data()->CopyFrom(profile_events); + push_profile_events_request.set_worker_id(worker_id_.Binary()); auto callback = [this](const Status &status, const PushProfileEventsReply &reply) { if (!status.ok() && is_connected_) { is_connected_ = false; - RAY_LOG(INFO) << "Failed to send PushProfileEventsRequest, msg: " + RAY_LOG(INFO) << "Worker " << worker_id_ + << " failed to send PushProfileEventsRequest, msg: " << status.message(); } }; @@ -279,13 +294,15 @@ ray::Status RayletClient::FreeObjects(const std::vector &object_i FreeObjectsInStoreRequest free_objects_request; free_objects_request.set_local_only(local_only); free_objects_request.set_delete_creating_tasks(delete_creating_tasks); + free_objects_request.set_worker_id(worker_id_.Binary()); IdVectorToProtobuf( object_ids, free_objects_request, &FreeObjectsInStoreRequest::add_object_ids); auto callback = [this](const Status &status, const FreeObjectsInStoreReply &reply) { if (!status.ok() && is_connected_) { is_connected_ = false; - RAY_LOG(INFO) << "Failed to send FreeObjectsInStoreRequest, msg: " + RAY_LOG(INFO) << "Worker " << worker_id_ + << " failed to send FreeObjectsInStoreRequest, msg: " << status.message(); } }; @@ -313,7 +330,8 @@ ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, if (status.ok()) { checkpoint_id = ActorCheckpointID::FromBinary(reply.checkpoint_id()); } else { - RAY_LOG(INFO) << "Failed to send PrepareActorCheckpointRequest, msg: " + RAY_LOG(INFO) << "Worker " << worker_id_ + << " failed to send PrepareActorCheckpointRequest, msg: " << status.error_message(); } @@ -326,6 +344,7 @@ ray::Status RayletClient::NotifyActorResumedFromCheckpoint( NotifyActorResumedFromCheckpointRequest notify_actor_resumed_from_checkpoint_request; notify_actor_resumed_from_checkpoint_request.set_actor_id(actor_id.Binary()); notify_actor_resumed_from_checkpoint_request.set_checkpoint_id(checkpoint_id.Binary()); + notify_actor_resumed_from_checkpoint_request.set_worker_id(worker_id_.Binary()); auto callback = [this](const Status &status, const NotifyActorResumedFromCheckpointReply &reply) { @@ -353,6 +372,7 @@ ray::Status RayletClient::SetResource(const std::string &resource_name, set_resource_request.set_resource_name(resource_name); set_resource_request.set_capacity(capacity); set_resource_request.set_client_id(client_id.Binary()); + set_resource_request.set_worker_id(worker_id_.Binary()); auto callback = [this](const Status &status, const SetResourceReply &reply) { if (!status.ok() && is_connected_) { @@ -382,7 +402,8 @@ ray::Status RayletClient::RegisterClient() { auto status = stub_->RegisterClient(&context, register_client_request, &reply); if (!status.ok()) { - RAY_LOG(DEBUG) << "Failed to register client, msg: " << status.error_message(); + RAY_LOG(DEBUG) << "Worker " << worker_id_ + << " failed to register client, msg: " << status.error_message(); } return GrpcStatusToRayStatus(status);