Add common preprocessing for each request in node manager. (#5296)

This commit is contained in:
Joey Jiang
2019-08-06 20:48:58 +08:00
committed by Hao Chen
parent 0a3ff489fa
commit 02c5d2be20
8 changed files with 208 additions and 112 deletions
+46 -34
View File
@@ -5,12 +5,15 @@ package ray.rpc;
import "src/ray/protobuf/common.proto";
import "src/ray/protobuf/gcs.proto";
/// NOTE(Joey Jiang) Every request defined in this file should have a `worker_id` field,
/// which will be used in `NodeManager::PreprocessRequest`.
/// Service request and reply messages.
message RegisterClientRequest {
// Indicates the client is a worker or a driver.
bool is_worker = 1;
// The worker id.
bytes worker_id = 2;
bytes worker_id = 1;
// Indicates the client is a worker or a driver.
bool is_worker = 2;
// The process ID of this worker.
uint32 worker_pid = 3;
// The job ID.
@@ -27,7 +30,8 @@ message RegisterClientReply {
}
message SubmitTaskRequest {
TaskSpec task_spec = 1;
bytes worker_id = 1;
TaskSpec task_spec = 2;
}
message SubmitTaskReply {
}
@@ -55,14 +59,14 @@ message TaskDoneReply {
}
message FetchOrReconstructRequest {
// List of object IDs of the objects that we want to reconstruct or fetch.
repeated bytes object_ids = 1;
// Indicates that we only want to fetch objects, not reconstruct them.
bool fetch_only = 2;
// The current task ID. If fetch_only is false, then this task is blocked.
bytes task_id = 3;
// The worker ID.
bytes worker_id = 4;
bytes worker_id = 1;
// List of object IDs of the objects that we want to reconstruct or fetch.
repeated bytes object_ids = 2;
// Indicates that we only want to fetch objects, not reconstruct them.
bool fetch_only = 3;
// The current task ID. If fetch_only is false, then this task is blocked.
bytes task_id = 4;
}
message FetchOrReconstructReply {
}
@@ -76,19 +80,19 @@ message NotifyUnblockedReply {
}
message WaitRequest {
// The worker ID.
bytes worker_id = 1;
// List of object ids we'll be waiting on.
repeated bytes object_ids = 1;
repeated bytes object_ids = 2;
// Number of objects expected to be returned, if available.
uint64 num_ready_objects = 2;
uint64 num_ready_objects = 3;
// Timeout in milliseconds.
int64 timeout = 3;
int64 timeout = 4;
// Whether to wait until objects appear locally.
bool wait_local = 4;
bool wait_local = 5;
// The current task ID. If there are less than num_ready_objects local, then
// this task is blocked.
bytes task_id = 5;
// The worker ID.
bytes worker_id = 6;
bytes task_id = 6;
}
message WaitReply {
// List of object ids found.
@@ -98,60 +102,68 @@ message WaitReply {
}
message PushErrorRequest {
// The worker ID.
bytes worker_id = 1;
// The job id that the error is for.
bytes job_id = 1;
bytes job_id = 2;
// The type of the error.
bytes type = 2;
bytes type = 3;
// The error message.
bytes error_message = 3;
bytes error_message = 4;
// The timestamp of the error message.
double timestamp = 4;
double timestamp = 5;
}
message PushErrorReply {
}
message PushProfileEventsRequest {
ProfileTableData profile_table_data = 1;
bytes worker_id = 1;
ProfileTableData profile_table_data = 2;
}
message PushProfileEventsReply {
}
message FreeObjectsInStoreRequest {
// The worker ID.
bytes worker_id = 1;
// Whether keep this request within the local object store
// or send it to all of the object stores.
bool local_only = 1;
bool local_only = 2;
// Whether also delete objects' creating tasks from GCS.
bool delete_creating_tasks = 2;
bool delete_creating_tasks = 3;
// List of object ids to delete from the object store.
repeated bytes object_ids = 3;
repeated bytes object_ids = 4;
}
message FreeObjectsInStoreReply {
}
message PrepareActorCheckpointRequest {
bytes actor_id = 1;
bytes worker_id = 2;
bytes worker_id = 1;
bytes actor_id = 2;
}
message PrepareActorCheckpointReply {
bytes checkpoint_id = 1;
bytes worker_id = 1;
bytes checkpoint_id = 2;
}
message NotifyActorResumedFromCheckpointRequest {
bytes worker_id = 1;
// ID of the actor that resumed.
bytes actor_id = 1;
bytes actor_id = 2;
// ID of the checkpoint from which the actor was resumed.
bytes checkpoint_id = 2;
bytes checkpoint_id = 3;
}
message NotifyActorResumedFromCheckpointReply {
}
message SetResourceRequest {
bytes worker_id = 1;
// Name of the resource to be set.
bytes resource_name = 1;
bytes resource_name = 2;
// Capacity of the resource to be set.
double capacity = 2;
double capacity = 3;
// Client ID where this resource will be set.
bytes client_id = 3;
bytes client_id = 4;
}
message SetResourceReply {
}
+95 -62
View File
@@ -1,6 +1,7 @@
#include "ray/raylet/node_manager.h"
#include <fstream>
#include <sstream>
#include "ray/common/status.h"
@@ -13,6 +14,20 @@ namespace {
#define RAY_CHECK_ENUM(x, y) \
static_assert(static_cast<int>(x) == static_cast<int>(y), "protocol mismatch")
/// Macro to handle early return for preprocessing.
/// An early return will take place if the worker is being killed due to the exiting of
/// driver, or the worker is not registered yet.
#define PREPROCESS_WORKER_REQUEST(REQUEST_TYPE, REQUEST, SEND_REPLY) \
do { \
WorkerID worker_id = WorkerID::FromBinary(REQUEST.worker_id()); \
if (!PreprocessRequest(worker_id, #REQUEST_TYPE)) { \
SEND_REPLY( \
Status::Invalid("Discard this request due to failure of preprocessing."), \
nullptr, nullptr); \
return; \
} \
} while (0)
/// A helper function to return the expected actor counter for a given actor
/// and actor handle, according to the given actor registry. If a task's
/// counter is less than the returned value, then the task is a duplicate. If
@@ -75,7 +90,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
gcs_client_(std::move(gcs_client)),
object_directory_(std::move(object_directory)),
heartbeat_timer_(io_service),
heartbeat_period_(std::chrono::milliseconds(config.heartbeat_period_ms)),
heartbeat_period_(config.heartbeat_period_ms),
debug_dump_period_(config.debug_dump_period_ms),
temp_dir_(config.temp_dir),
object_manager_profile_timer_(io_service),
@@ -352,6 +367,9 @@ void NodeManager::Heartbeat() {
// Reset the timer.
heartbeat_timer_.expires_from_now(heartbeat_period_);
heartbeat_timer_.async_wait([this](const boost::system::error_code &error) {
if (error == boost::asio::error::operation_aborted) {
return;
}
RAY_CHECK(!error);
Heartbeat();
});
@@ -735,20 +753,51 @@ void NodeManager::DispatchTasks(
local_queues_.MoveTasks(assigned_task_ids, TaskState::READY, TaskState::RUNNING);
}
bool NodeManager::PreprocessRequest(const WorkerID &worker_id,
const std::string &request_name) {
std::ostringstream oss;
if (RAY_LOG_ENABLED(DEBUG)) {
oss << "Received a " << request_name << " request. Worker id " << worker_id << ".";
}
auto worker = worker_pool_.GetWorker(worker_id);
// Worker process has been killed, we should discard this request.
if (!worker) {
RAY_LOG(WARNING) << "Worker " << worker_id << " is not found in worker pool, request "
<< request_name << " will be discarded.";
return false;
}
if (RAY_LOG_ENABLED(DEBUG)) {
oss << " Is worker: " << (worker->IsWorker() ? "true" : "false") << ". Worker pid "
<< std::to_string(worker->Pid()) << ".";
RAY_LOG(DEBUG) << oss.str();
}
// The worker process is being killing, we should discard this request.
if (worker->IsBeingKilled()) {
RAY_LOG(INFO) << "Worker " << worker_id << " is being killed, request "
<< request_name << " will be discarded.";
return false;
}
return true;
}
void NodeManager::HandleRegisterClientRequest(
const rpc::RegisterClientRequest &request, rpc::RegisterClientReply *reply,
rpc::SendReplyCallback send_reply_callback) {
// Client id in register client is treated as worker id.
const WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
bool is_worker = request.is_worker();
auto worker =
std::make_shared<Worker>(worker_id, request.worker_pid(), request.language(),
request.port(), client_call_manager_);
request.port(), client_call_manager_, is_worker);
RAY_LOG(DEBUG) << "Received a RegisterClientRequest, worker id: " << worker_id
<< ", is worker: " << request.is_worker()
<< ", pid: " << request.worker_pid();
RAY_LOG(DEBUG) << "Received a RegisterClientRequest. Worker id: " << worker_id
<< ". Is worker: " << is_worker << ". Worker pid "
<< request.worker_pid();
if (request.is_worker()) {
if (is_worker) {
// Register the new worker.
bool use_push_task = worker->UsePush();
worker_pool_.RegisterWorker(worker_id, std::move(worker));
@@ -818,16 +867,11 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca
void NodeManager::HandleGetTaskRequest(const rpc::GetTaskRequest &request,
rpc::GetTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
const WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
PREPROCESS_WORKER_REQUEST(GetTaskRequest, request, send_reply_callback);
WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(worker_id);
RAY_LOG(DEBUG) << "Received a GetTaskRequest, worker id " << worker_id << " pid "
<< worker->Pid();
if (!worker || worker->IsBeingKilled()) {
send_reply_callback(Status::Invalid("WorkerBeingKilled"), nullptr, nullptr);
return;
}
RAY_CHECK(!worker->UsePush());
RAY_CHECK(!worker->UsePush());
// Reply would be sent when assigned a task to the worker successfully.
worker->SetGetTaskReplyAndCallback(reply, std::move(send_reply_callback));
HandleWorkerAvailable(worker_id);
@@ -836,9 +880,8 @@ void NodeManager::HandleGetTaskRequest(const rpc::GetTaskRequest &request,
void NodeManager::HandleTaskDoneRequest(const rpc::TaskDoneRequest &request,
rpc::TaskDoneReply *reply,
rpc::SendReplyCallback send_reply_callback) {
const WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
RAY_LOG(DEBUG) << "Received a TaskDoneRequest from worker " << worker_id;
PREPROCESS_WORKER_REQUEST(TaskDoneRequest, request, send_reply_callback);
WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
auto worker = worker_pool_.GetRegisteredWorker(worker_id);
RAY_CHECK(worker && worker->UsePush());
HandleWorkerAvailable(worker_id);
@@ -849,50 +892,36 @@ void NodeManager::HandleDisconnectClientRequest(
const rpc::DisconnectClientRequest &request, rpc::DisconnectClientReply *reply,
rpc::SendReplyCallback send_reply_callback) {
const WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
RAY_LOG(DEBUG) << "Received a DisconnectClientRequest from worker " << worker_id;
ProcessDisconnectClientMessage(worker_id, true);
send_reply_callback(Status::OK(), nullptr, nullptr);
}
void NodeManager::ProcessDisconnectClientMessage(const WorkerID &worker_id,
bool intentional_disconnect) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(worker_id);
bool is_worker = false, is_driver = false;
if (worker) {
// The client is a worker.
is_worker = true;
} else {
worker = worker_pool_.GetRegisteredDriver(worker_id);
if (worker) {
// The client is a driver.
is_driver = true;
} else {
RAY_LOG(INFO) << "Ignoring client disconnect because the client has already "
<< "been disconnected.";
return;
}
auto worker = worker_pool_.GetWorker(worker_id);
if (!worker) {
RAY_LOG(INFO) << "Ignoring client disconnect because the client has already "
<< "been disconnected.";
return;
}
RAY_CHECK(!(is_worker && is_driver));
bool is_worker = worker->IsWorker();
// If the client has any blocked tasks, mark them as unblocked. In
// particular, we are no longer waiting for their dependencies.
if (worker) {
if (is_worker && worker->IsBeingKilled()) {
// Don't need to unblock the client if it's a worker and have sent kill signal to
// it. Because in this case, its task is already cleaned up.
RAY_LOG(DEBUG) << "Skip unblocking worker because it's already dead.";
} else {
// Clean up any open ray.get calls that the worker made.
while (!worker->GetBlockedTaskIds().empty()) {
// NOTE(swang): HandleTaskUnblocked will modify the worker, so it is
// not safe to pass in the iterator directly.
const TaskID task_id = *worker->GetBlockedTaskIds().begin();
HandleTaskUnblocked(worker_id, task_id);
}
// Clean up any open ray.wait calls that the worker made.
task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId());
if (is_worker && worker->IsBeingKilled()) {
// Don't need to unblock the client if it's a worker and have sent kill signal to
// it. Because in this case, its task is already cleaned up.
RAY_LOG(DEBUG) << "Skip unblocking worker because it's already dead.";
} else {
// Clean up any open ray.get calls that the worker made.
while (!worker->GetBlockedTaskIds().empty()) {
// NOTE(swang): HandleTaskUnblocked will modify the worker, so it is
// not safe to pass in the iterator directly.
const TaskID task_id = *worker->GetBlockedTaskIds().begin();
HandleTaskUnblocked(worker_id, task_id);
}
// Clean up any open ray.wait calls that the worker made.
task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId());
}
if (is_worker) {
@@ -952,7 +981,7 @@ void NodeManager::ProcessDisconnectClientMessage(const WorkerID &worker_id,
// Since some resources may have been released, we can try to dispatch more tasks.
DispatchTasks(local_queues_.GetReadyTasksWithResources());
} else if (is_driver) {
} else {
// The client is a driver.
const auto job_id = worker->GetAssignedJobId();
const auto driver_id = ComputeDriverIdFromJob(job_id);
@@ -975,8 +1004,7 @@ void NodeManager::ProcessDisconnectClientMessage(const WorkerID &worker_id,
void NodeManager::HandleSubmitTaskRequest(const rpc::SubmitTaskRequest &request,
rpc::SubmitTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Received a SubmitTaskRequest.";
PREPROCESS_WORKER_REQUEST(SubmitTaskRequest, request, send_reply_callback);
rpc::Task task;
task.mutable_task_spec()->CopyFrom(request.task_spec());
@@ -989,7 +1017,8 @@ void NodeManager::HandleSubmitTaskRequest(const rpc::SubmitTaskRequest &request,
void NodeManager::HandleFetchOrReconstructRequest(
const rpc::FetchOrReconstructRequest &request, rpc::FetchOrReconstructReply *reply,
rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Received a FetchOrReconstructRequest.";
PREPROCESS_WORKER_REQUEST(FetchOrReconstructRequest, request, send_reply_callback);
WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
const auto &object_ids = request.object_ids();
std::vector<ObjectID> required_object_ids;
for (size_t i = 0; i < object_ids.size(); ++i) {
@@ -1012,7 +1041,6 @@ void NodeManager::HandleFetchOrReconstructRequest(
if (!required_object_ids.empty()) {
const TaskID task_id = TaskID::FromBinary(request.task_id());
const WorkerID &worker_id = WorkerID::FromBinary(request.worker_id());
HandleTaskBlocked(worker_id, required_object_ids, task_id, /*ray_get=*/true);
}
send_reply_callback(Status::OK(), nullptr, nullptr);
@@ -1021,7 +1049,8 @@ void NodeManager::HandleFetchOrReconstructRequest(
void NodeManager::HandleWaitRequest(const rpc::WaitRequest &request,
rpc::WaitReply *reply,
rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Received a WaitRequest.";
PREPROCESS_WORKER_REQUEST(WaitRequest, request, send_reply_callback);
WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
// Read the data.
std::vector<ObjectID> object_ids = IdVectorFromProtobuf<ObjectID>(request.object_ids());
int64_t wait_ms = request.timeout();
@@ -1039,7 +1068,6 @@ void NodeManager::HandleWaitRequest(const rpc::WaitRequest &request,
}
const TaskID &current_task_id = TaskID::FromBinary(request.task_id());
const WorkerID &worker_id = WorkerID::FromBinary(request.worker_id());
bool client_blocked = !required_object_ids.empty();
if (client_blocked) {
HandleTaskBlocked(worker_id, required_object_ids, current_task_id, /*ray_get=*/false);
@@ -1066,6 +1094,7 @@ void NodeManager::HandleWaitRequest(const rpc::WaitRequest &request,
void NodeManager::HandlePushErrorRequest(const rpc::PushErrorRequest &request,
rpc::PushErrorReply *reply,
rpc::SendReplyCallback send_reply_callback) {
PREPROCESS_WORKER_REQUEST(PushErrorRequest, request, send_reply_callback);
JobID job_id = JobID::FromBinary(request.job_id());
const auto &type = request.type();
const auto &error_message = request.error_message();
@@ -1081,12 +1110,13 @@ void NodeManager::HandlePushErrorRequest(const rpc::PushErrorRequest &request,
void NodeManager::HandlePrepareActorCheckpointRequest(
const rpc::PrepareActorCheckpointRequest &request,
rpc::PrepareActorCheckpointReply *reply, rpc::SendReplyCallback send_reply_callback) {
PREPROCESS_WORKER_REQUEST(PrepareActorCheckpointRequest, request, send_reply_callback);
WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
ActorID actor_id = ActorID::FromBinary(request.actor_id());
RAY_LOG(DEBUG) << "Preparing checkpoint for actor " << actor_id;
const auto &actor_entry = actor_registry_.find(actor_id);
RAY_CHECK(actor_entry != actor_registry_.end());
WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(worker_id);
RAY_CHECK(worker && worker->GetActorId() == actor_id);
@@ -1122,6 +1152,8 @@ void NodeManager::HandleNotifyActorResumedFromCheckpointRequest(
const rpc::NotifyActorResumedFromCheckpointRequest &request,
rpc::NotifyActorResumedFromCheckpointReply *reply,
rpc::SendReplyCallback send_reply_callback) {
PREPROCESS_WORKER_REQUEST(NotifyActorResumedFromCheckpointRequest, request,
send_reply_callback);
ActorID actor_id = ActorID::FromBinary(request.actor_id());
ActorCheckpointID checkpoint_id =
ActorCheckpointID::FromBinary(request.checkpoint_id());
@@ -1152,6 +1184,7 @@ void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request,
void NodeManager::HandleSetResourceRequest(const rpc::SetResourceRequest &request,
rpc::SetResourceReply *reply,
rpc::SendReplyCallback send_reply_callback) {
PREPROCESS_WORKER_REQUEST(SetResourceRequest, request, send_reply_callback);
auto const &resource_name = request.resource_name();
double const capacity = request.capacity();
bool is_deletion = capacity <= 0;
@@ -1192,9 +1225,9 @@ void NodeManager::HandleSetResourceRequest(const rpc::SetResourceRequest &reques
void NodeManager::HandleNotifyUnblockedRequest(
const rpc::NotifyUnblockedRequest &request, rpc::NotifyUnblockedReply *reply,
rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Received a NotifyUnblockedRequest.";
PREPROCESS_WORKER_REQUEST(NotifyUnblockedRequest, request, send_reply_callback);
WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
const TaskID current_task_id = TaskID::FromBinary(request.task_id());
const WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
HandleTaskUnblocked(worker_id, current_task_id);
send_reply_callback(Status::OK(), nullptr, nullptr);
@@ -1203,7 +1236,7 @@ void NodeManager::HandleNotifyUnblockedRequest(
void NodeManager::HandlePushProfileEventsRequest(
const rpc::PushProfileEventsRequest &request, rpc::PushProfileEventsReply *reply,
rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Received a PushProfileEventsRequest.";
PREPROCESS_WORKER_REQUEST(PushProfileEventsRequest, request, send_reply_callback);
const auto &profile_table_data = request.profile_table_data();
RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data));
send_reply_callback(Status::OK(), nullptr, nullptr);
@@ -1212,7 +1245,7 @@ void NodeManager::HandlePushProfileEventsRequest(
void NodeManager::HandleFreeObjectsInStoreRequest(
const rpc::FreeObjectsInStoreRequest &request, rpc::FreeObjectsInStoreReply *reply,
rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Received a FreeObjectsInStoreRequest.";
PREPROCESS_WORKER_REQUEST(FreeObjectsInStoreRequest, request, send_reply_callback);
std::vector<ObjectID> object_ids = IdVectorFromProtobuf<ObjectID>(request.object_ids());
object_manager_.FreeObjects(object_ids, request.local_only());
if (request.delete_creating_tasks()) {
+8
View File
@@ -95,6 +95,14 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
/// Get the port of the node manager rpc server.
int GetServerPort() const { return node_manager_server_.GetPort(); }
/// Preprocess request from raylet client. We will check whether the worker is being
/// killed due to job finishing.
///
/// \param worker_id The worker id.
/// \param request_name The request name.
/// \return False if there is no need to process this request.
bool PreprocessRequest(const WorkerID &worker_id, const std::string &request_name);
/// Implementation of node manager grpc service.
/// Handle a `ForwardTask` request.
+5 -2
View File
@@ -10,7 +10,7 @@ namespace raylet {
/// A constructor responsible for initializing the state of a worker.
Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port,
rpc::ClientCallManager &client_call_manager)
rpc::ClientCallManager &client_call_manager, bool is_worker)
: worker_id_(worker_id),
pid_(pid),
port_(port),
@@ -18,7 +18,8 @@ Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, i
blocked_(false),
num_missed_heartbeats_(0),
is_being_killed_(false),
client_call_manager_(client_call_manager) {
client_call_manager_(client_call_manager),
is_worker_(is_worker) {
if (port_ > 0) {
rpc_client_ = std::unique_ptr<rpc::WorkerTaskClient>(
new rpc::WorkerTaskClient("127.0.0.1", port_, client_call_manager_));
@@ -29,6 +30,8 @@ void Worker::MarkAsBeingKilled() { is_being_killed_ = true; }
bool Worker::IsBeingKilled() const { return is_being_killed_; }
bool Worker::IsWorker() const { return is_worker_; }
void Worker::MarkBlocked() { blocked_ = true; }
void Worker::MarkUnblocked() { blocked_ = false; }
+4 -1
View File
@@ -24,11 +24,12 @@ class Worker {
public:
/// A constructor that initializes a worker object.
Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port,
rpc::ClientCallManager &client_call_manager);
rpc::ClientCallManager &client_call_manager, bool is_worker = true);
/// A destructor responsible for freeing all worker state.
~Worker() {}
void MarkAsBeingKilled();
bool IsBeingKilled() const;
bool IsWorker() const;
void MarkBlocked();
void MarkUnblocked();
bool IsBlocked() const;
@@ -105,6 +106,8 @@ class Worker {
/// The `ClientCallManager` object that is shared by `WorkerTaskClient` from all
/// workers.
rpc::ClientCallManager &client_call_manager_;
/// Indicates whether this is a worker or a driver.
bool is_worker_;
/// The rpc client to send tasks to this worker.
std::unique_ptr<rpc::WorkerTaskClient> rpc_client_;
/// Reply of the `GetTask` request.
+11
View File
@@ -416,6 +416,17 @@ void WorkerPool::TickHeartbeatTimer(int max_missed_heartbeats,
}
}
std::shared_ptr<Worker> WorkerPool::GetWorker(const WorkerID &worker_id) {
auto worker = GetRegisteredWorker(worker_id);
if (!worker) {
worker = GetRegisteredDriver(worker_id);
if (!worker) {
return nullptr;
}
}
return worker;
}
} // namespace raylet
} // namespace ray
+5
View File
@@ -138,6 +138,11 @@ class WorkerPool {
void TickHeartbeatTimer(int max_missed_heartbeats,
std::vector<std::shared_ptr<Worker>> *dead_workers);
/// Return the pointer to the worker according to the worker id.
///
/// \param worker_id The worker id.
std::shared_ptr<Worker> GetWorker(const WorkerID &worker_id);
protected:
/// Asynchronously start a new worker process. Once the worker process has
/// registered with an external server, the process should create and
+34 -13
View File
@@ -48,7 +48,8 @@ void RayletClient::TryRegisterClient(int retry_times) {
}
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
RAY_LOG(FATAL) << "Failed to register to raylet server, worker id: " << worker_id_
RAY_LOG(FATAL) << "Worker " << worker_id_
<< " failed to register to raylet server, worker id: " << worker_id_
<< ", pid: " << static_cast<int>(getpid())
<< ", is worker: " << is_worker_;
}
@@ -67,7 +68,9 @@ ray::Status RayletClient::Disconnect() {
grpc::ClientContext context;
auto status = stub_->DisconnectClient(&context, disconnect_client_request, &reply);
if (!status.ok()) {
RAY_LOG(ERROR) << "Failed to disconnect from raylet, msg: " << status.error_message();
RAY_LOG(ERROR) << "Worker " << worker_id_
<< " failed to disconnect from raylet, msg: "
<< status.error_message();
}
return GrpcStatusToRayStatus(status);
}
@@ -75,12 +78,14 @@ ray::Status RayletClient::Disconnect() {
ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) {
RETURN_IF_DISCONNECTED(is_connected_);
SubmitTaskRequest submit_task_request;
submit_task_request.set_worker_id(worker_id_.Binary());
submit_task_request.mutable_task_spec()->CopyFrom(task_spec.GetMessage());
auto callback = [this](const Status &status, const SubmitTaskReply &reply) {
if (!status.ok() && is_connected_) {
is_connected_ = false;
RAY_LOG(INFO) << "Failed to send SubmitTaskRequest, msg: " << status.message();
RAY_LOG(INFO) << "Worker " << worker_id_
<< " failed to send SubmitTaskRequest, msg: " << status.message();
}
};
@@ -128,7 +133,8 @@ ray::Status RayletClient::GetTask(std::unique_ptr<ray::TaskSpecification> *task_
task_spec->reset(new ray::TaskSpecification(reply.task_spec()));
} else {
*task_spec = nullptr;
RAY_LOG(INFO) << "Failed to get task, msg: " << status.error_message();
RAY_LOG(INFO) << "Worker " << worker_id_
<< " failed to get task, msg: " << status.error_message();
}
return GrpcStatusToRayStatus(status);
}
@@ -141,7 +147,8 @@ ray::Status RayletClient::TaskDone() {
auto callback = [this](const Status &status, const TaskDoneReply &reply) {
if (!status.ok() && is_connected_) {
is_connected_ = false;
RAY_LOG(INFO) << "Failed to send TaskDoneRequest, msg: " << status.message();
RAY_LOG(INFO) << "Worker " << worker_id_
<< " failed to send TaskDoneRequest, msg: " << status.message();
}
};
@@ -168,7 +175,8 @@ ray::Status RayletClient::FetchOrReconstruct(const std::vector<ObjectID> &object
auto callback = [this](const Status &status, const FetchOrReconstructReply &reply) {
if (!status.ok() && is_connected_) {
is_connected_ = false;
RAY_LOG(INFO) << "Failed to send FetchOrReconstructRequest, msg: "
RAY_LOG(INFO) << "Worker " << worker_id_
<< " failed to send FetchOrReconstructRequest, msg: "
<< status.message();
}
};
@@ -190,7 +198,9 @@ ray::Status RayletClient::NotifyUnblocked(const TaskID &current_task_id) {
auto callback = [this](const Status &status, const NotifyUnblockedReply &reply) {
if (!status.ok() && is_connected_) {
is_connected_ = false;
RAY_LOG(INFO) << "Failed to send NotifyUnblockedRequest, msg: " << status.message();
RAY_LOG(INFO) << "Worker " << worker_id_
<< " failed to send NotifyUnblockedRequest, msg: "
<< status.message();
}
};
@@ -223,7 +233,8 @@ ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_
result->first = IdVectorFromProtobuf<ObjectID>(reply.found());
result->second = IdVectorFromProtobuf<ObjectID>(reply.remaining());
} else {
RAY_LOG(INFO) << "Failed to send WaitRequest, msg: " << status.error_message();
RAY_LOG(INFO) << "Worker " << worker_id_
<< " failed to send WaitRequest, msg: " << status.error_message();
}
return GrpcStatusToRayStatus(status);
@@ -237,11 +248,13 @@ ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string
push_error_request.set_type(type);
push_error_request.set_error_message(error_message);
push_error_request.set_timestamp(timestamp);
push_error_request.set_worker_id(worker_id_.Binary());
auto callback = [this](const Status &status, const PushErrorReply &reply) {
if (!status.ok() && is_connected_) {
is_connected_ = false;
RAY_LOG(INFO) << "Failed to send PushErrorRequest, msg: " << status.message();
RAY_LOG(INFO) << "Worker " << worker_id_
<< " failed to send PushErrorRequest, msg: " << status.message();
}
};
@@ -256,11 +269,13 @@ ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_even
RETURN_IF_DISCONNECTED(is_connected_);
PushProfileEventsRequest push_profile_events_request;
push_profile_events_request.mutable_profile_table_data()->CopyFrom(profile_events);
push_profile_events_request.set_worker_id(worker_id_.Binary());
auto callback = [this](const Status &status, const PushProfileEventsReply &reply) {
if (!status.ok() && is_connected_) {
is_connected_ = false;
RAY_LOG(INFO) << "Failed to send PushProfileEventsRequest, msg: "
RAY_LOG(INFO) << "Worker " << worker_id_
<< " failed to send PushProfileEventsRequest, msg: "
<< status.message();
}
};
@@ -279,13 +294,15 @@ ray::Status RayletClient::FreeObjects(const std::vector<ray::ObjectID> &object_i
FreeObjectsInStoreRequest free_objects_request;
free_objects_request.set_local_only(local_only);
free_objects_request.set_delete_creating_tasks(delete_creating_tasks);
free_objects_request.set_worker_id(worker_id_.Binary());
IdVectorToProtobuf<ray::ObjectID, FreeObjectsInStoreRequest>(
object_ids, free_objects_request, &FreeObjectsInStoreRequest::add_object_ids);
auto callback = [this](const Status &status, const FreeObjectsInStoreReply &reply) {
if (!status.ok() && is_connected_) {
is_connected_ = false;
RAY_LOG(INFO) << "Failed to send FreeObjectsInStoreRequest, msg: "
RAY_LOG(INFO) << "Worker " << worker_id_
<< " failed to send FreeObjectsInStoreRequest, msg: "
<< status.message();
}
};
@@ -313,7 +330,8 @@ ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id,
if (status.ok()) {
checkpoint_id = ActorCheckpointID::FromBinary(reply.checkpoint_id());
} else {
RAY_LOG(INFO) << "Failed to send PrepareActorCheckpointRequest, msg: "
RAY_LOG(INFO) << "Worker " << worker_id_
<< " failed to send PrepareActorCheckpointRequest, msg: "
<< status.error_message();
}
@@ -326,6 +344,7 @@ ray::Status RayletClient::NotifyActorResumedFromCheckpoint(
NotifyActorResumedFromCheckpointRequest notify_actor_resumed_from_checkpoint_request;
notify_actor_resumed_from_checkpoint_request.set_actor_id(actor_id.Binary());
notify_actor_resumed_from_checkpoint_request.set_checkpoint_id(checkpoint_id.Binary());
notify_actor_resumed_from_checkpoint_request.set_worker_id(worker_id_.Binary());
auto callback = [this](const Status &status,
const NotifyActorResumedFromCheckpointReply &reply) {
@@ -353,6 +372,7 @@ ray::Status RayletClient::SetResource(const std::string &resource_name,
set_resource_request.set_resource_name(resource_name);
set_resource_request.set_capacity(capacity);
set_resource_request.set_client_id(client_id.Binary());
set_resource_request.set_worker_id(worker_id_.Binary());
auto callback = [this](const Status &status, const SetResourceReply &reply) {
if (!status.ok() && is_connected_) {
@@ -382,7 +402,8 @@ ray::Status RayletClient::RegisterClient() {
auto status = stub_->RegisterClient(&context, register_client_request, &reply);
if (!status.ok()) {
RAY_LOG(DEBUG) << "Failed to register client, msg: " << status.error_message();
RAY_LOG(DEBUG) << "Worker " << worker_id_
<< " failed to register client, msg: " << status.error_message();
}
return GrpcStatusToRayStatus(status);