Modularize NodeManager::ProcessClientMessage (#2895)

Split NodeManager::ProcessClientMessage into a couple of smaller functions, each of which handles one type of message.
This commit is contained in:
Hao Chen
2018-09-19 05:18:34 +08:00
committed by Robert Nishihara
parent ea9d1cc887
commit 715ec1bca5
2 changed files with 286 additions and 204 deletions
+235 -202
View File
@@ -538,227 +538,34 @@ void NodeManager::ProcessClientMessage(
switch (static_cast<protocol::MessageType>(message_type)) {
case protocol::MessageType::RegisterClientRequest: {
auto message = flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
client->SetClientID(from_flatbuf(*message->client_id()));
auto worker =
std::make_shared<Worker>(message->worker_pid(), message->language(), client);
if (message->is_worker()) {
// Register the new worker.
worker_pool_.RegisterWorker(std::move(worker));
DispatchTasks();
} else {
// Register the new driver. Note that here the driver_id in RegisterClientRequest
// message is actually the ID of the driver task, while client_id represents the
// real driver ID, which can associate all the tasks/actors for a given driver,
// which is set to the worker ID.
const JobID driver_task_id = from_flatbuf(*message->driver_id());
worker->AssignTaskId(driver_task_id);
worker->AssignDriverId(from_flatbuf(*message->client_id()));
worker_pool_.RegisterDriver(std::move(worker));
local_queues_.AddDriverTaskId(driver_task_id);
}
ProcessRegisterClientRequestMessage(client, message_data);
} break;
case protocol::MessageType::GetTask: {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
RAY_CHECK(worker);
// If the worker was assigned a task, mark it as finished.
if (!worker->GetAssignedTaskId().is_nil()) {
FinishAssignedTask(*worker);
}
// Return the worker to the idle pool.
worker_pool_.PushWorker(std::move(worker));
// Local resource availability changed: invoke scheduling policy for local node.
const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId();
cluster_resource_map_[local_client_id].SetLoadResources(
local_queues_.GetResourceLoad());
// Call task dispatch to assign work to the new worker.
DispatchTasks();
ProcessGetTaskMessage(client);
} break;
case protocol::MessageType::DisconnectClient: {
// Remove the dead worker from the pool and stop listening for messages.
const std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
if (worker) {
// The client is a worker. Handle the case where the worker is killed
// while executing a task. Clean up the assigned task's resources, push
// an error to the driver.
// (See design_docs/task_states.rst for the state transition diagram.)
const TaskID &task_id = worker->GetAssignedTaskId();
if (!task_id.is_nil() && !worker->IsDead()) {
// If the worker was killed intentionally, e.g., when the driver that created
// the task that this worker is currently executing exits, the task for this
// worker has already been removed from queue, so the following are skipped.
auto const &running_tasks = local_queues_.GetRunningTasks();
// TODO(rkn): This is too heavyweight just to get the task's driver ID.
auto const it = std::find_if(
running_tasks.begin(), running_tasks.end(), [task_id](const Task &task) {
return task.GetTaskSpecification().TaskId() == task_id;
});
RAY_CHECK(running_tasks.size() != 0);
RAY_CHECK(it != running_tasks.end());
const TaskSpecification &spec = it->GetTaskSpecification();
const JobID job_id = spec.DriverId();
// TODO(rkn): Define this constant somewhere else.
std::string type = "worker_died";
std::ostringstream error_message;
error_message << "A worker died or was killed while executing task " << task_id
<< ".";
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
job_id, type, error_message.str(), current_time_ms()));
// Handle the task failure in order to raise an exception in the
// application.
TreatTaskAsFailed(spec);
task_dependency_manager_.TaskCanceled(spec.TaskId());
local_queues_.RemoveTask(spec.TaskId());
}
worker_pool_.DisconnectWorker(worker);
// If the worker was an actor, add it to the list of dead actors.
const ActorID actor_id = worker->GetActorId();
if (!actor_id.is_nil()) {
// TODO(rkn): Consider broadcasting a message to all of the other
// node managers so that they can mark the actor as dead.
RAY_LOG(DEBUG) << "The actor with ID " << actor_id << " died.";
auto actor_entry = actor_registry_.find(actor_id);
RAY_CHECK(actor_entry != actor_registry_.end());
actor_entry->second.MarkDead();
// For dead actors, if there are remaining tasks for this actor, we
// should handle them.
CleanUpTasksForDeadActor(actor_id);
}
const ClientID &client_id = gcs_client_->client_table().GetLocalClientId();
// Return the resources that were being used by this worker.
auto const &task_resources = worker->GetTaskResourceIds();
local_available_resources_.Release(task_resources);
cluster_resource_map_[client_id].Release(task_resources.ToResourceSet());
worker->ResetTaskResourceIds();
auto const &lifetime_resources = worker->GetLifetimeResourceIds();
local_available_resources_.Release(lifetime_resources);
cluster_resource_map_[client_id].Release(lifetime_resources.ToResourceSet());
worker->ResetLifetimeResourceIds();
RAY_LOG(DEBUG) << "Worker (pid=" << worker->Pid() << ") is disconnected. "
<< "driver_id: " << worker->GetAssignedDriverId();
// Since some resources may have been released, we can try to dispatch more tasks.
DispatchTasks();
} else {
// The client is a driver.
RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientID(),
/*is_dead=*/true));
const std::shared_ptr<Worker> driver = worker_pool_.GetRegisteredDriver(client);
RAY_CHECK(driver);
auto driver_id = driver->GetAssignedTaskId();
RAY_CHECK(!driver_id.is_nil());
local_queues_.RemoveDriverTaskId(driver_id);
worker_pool_.DisconnectDriver(driver);
RAY_LOG(DEBUG) << "Driver (pid=" << driver->Pid() << ") is disconnected. "
<< "driver_id: " << driver->GetAssignedDriverId();
}
ProcessDisconnectClientMessage(client);
// We don't need to receive future messages from this client,
// because it's already disconnected.
return;
} break;
case protocol::MessageType::SubmitTask: {
// Read the task submitted by the client.
auto message = flatbuffers::GetRoot<protocol::SubmitTaskRequest>(message_data);
TaskExecutionSpecification task_execution_spec(
from_flatbuf(*message->execution_dependencies()));
TaskSpecification task_spec(*message->task_spec());
Task task(task_execution_spec, task_spec);
// Submit the task to the local scheduler. Since the task was submitted
// locally, there is no uncommitted lineage.
SubmitTask(task, Lineage());
ProcessSubmitTaskMessage(message_data);
} break;
case protocol::MessageType::ReconstructObjects: {
auto message = flatbuffers::GetRoot<protocol::ReconstructObjects>(message_data);
std::vector<ObjectID> required_object_ids;
for (size_t i = 0; i < message->object_ids()->size(); ++i) {
ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i));
if (!task_dependency_manager_.CheckObjectLocal(object_id)) {
if (message->fetch_only()) {
// If only a fetch is required, then do not subscribe to the
// dependencies to the task dependency manager.
RAY_CHECK_OK(object_manager_.Pull(object_id));
} else {
// If reconstruction is also required, then add any missing objects
// to the list to subscribe to in the task dependency manager. These
// objects will be pulled from remote node managers and reconstructed
// if necessary.
required_object_ids.push_back(object_id);
}
}
}
if (!required_object_ids.empty()) {
HandleClientBlocked(client, required_object_ids);
}
ProcessReconstructObjectsMessage(client, message_data);
} break;
case protocol::MessageType::NotifyUnblocked: {
HandleClientUnblocked(client);
} break;
case protocol::MessageType::WaitRequest: {
// Read the data.
auto message = flatbuffers::GetRoot<protocol::WaitRequest>(message_data);
std::vector<ObjectID> object_ids = from_flatbuf(*message->object_ids());
int64_t wait_ms = message->timeout();
uint64_t num_required_objects = static_cast<uint64_t>(message->num_ready_objects());
bool wait_local = message->wait_local();
std::vector<ObjectID> required_object_ids;
for (auto const &object_id : object_ids) {
if (!task_dependency_manager_.CheckObjectLocal(object_id)) {
// Add any missing objects to the list to subscribe to in the task
// dependency manager. These objects will be pulled from remote node
// managers and reconstructed if necessary.
required_object_ids.push_back(object_id);
}
}
bool client_blocked = !required_object_ids.empty();
if (client_blocked) {
HandleClientBlocked(client, required_object_ids);
}
ray::Status status = object_manager_.Wait(
object_ids, wait_ms, num_required_objects, wait_local,
[this, client_blocked, client](std::vector<ObjectID> found,
std::vector<ObjectID> remaining) {
// Write the data.
flatbuffers::FlatBufferBuilder fbb;
flatbuffers::Offset<protocol::WaitReply> wait_reply = protocol::CreateWaitReply(
fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining));
fbb.Finish(wait_reply);
RAY_CHECK_OK(
client->WriteMessage(static_cast<int64_t>(protocol::MessageType::WaitReply),
fbb.GetSize(), fbb.GetBufferPointer()));
// The client is unblocked now because the wait call has returned.
if (client_blocked) {
HandleClientUnblocked(client);
}
});
RAY_CHECK_OK(status);
ProcessWaitRequestMessage(client, message_data);
} break;
case protocol::MessageType::PushErrorRequest: {
auto message = flatbuffers::GetRoot<protocol::PushErrorRequest>(message_data);
JobID job_id = from_flatbuf(*message->job_id());
auto const &type = string_from_flatbuf(*message->type());
auto const &error_message = string_from_flatbuf(*message->error_message());
double timestamp = message->timestamp();
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message,
timestamp));
ProcessPushErrorRequestMessage(message_data);
} break;
case protocol::MessageType::PushProfileEventsRequest: {
auto message = flatbuffers::GetRoot<ProfileTableData>(message_data);
RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*message));
} break;
case protocol::MessageType::FreeObjectsInObjectStoreRequest: {
@@ -775,6 +582,232 @@ void NodeManager::ProcessClientMessage(
client->ProcessMessages();
}
void NodeManager::ProcessRegisterClientRequestMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
auto message = flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
client->SetClientID(from_flatbuf(*message->client_id()));
auto worker =
std::make_shared<Worker>(message->worker_pid(), message->language(), client);
if (message->is_worker()) {
// Register the new worker.
worker_pool_.RegisterWorker(std::move(worker));
DispatchTasks();
} else {
// Register the new driver. Note that here the driver_id in RegisterClientRequest
// message is actually the ID of the driver task, while client_id represents the
// real driver ID, which can associate all the tasks/actors for a given driver,
// which is set to the worker ID.
const JobID driver_task_id = from_flatbuf(*message->driver_id());
worker->AssignTaskId(driver_task_id);
worker->AssignDriverId(from_flatbuf(*message->client_id()));
worker_pool_.RegisterDriver(std::move(worker));
local_queues_.AddDriverTaskId(driver_task_id);
}
}
void NodeManager::ProcessGetTaskMessage(
const std::shared_ptr<LocalClientConnection> &client) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
RAY_CHECK(worker);
// If the worker was assigned a task, mark it as finished.
if (!worker->GetAssignedTaskId().is_nil()) {
FinishAssignedTask(*worker);
}
// Return the worker to the idle pool.
worker_pool_.PushWorker(std::move(worker));
// Local resource availability changed: invoke scheduling policy for local node.
const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId();
cluster_resource_map_[local_client_id].SetLoadResources(
local_queues_.GetResourceLoad());
// Call task dispatch to assign work to the new worker.
DispatchTasks();
}
void NodeManager::ProcessDisconnectClientMessage(
const std::shared_ptr<LocalClientConnection> &client) {
// Remove the dead worker from the pool and stop listening for messages.
const std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
if (worker) {
// The client is a worker. Handle the case where the worker is killed
// while executing a task. Clean up the assigned task's resources, push
// an error to the driver.
// (See design_docs/task_states.rst for the state transition diagram.)
const TaskID &task_id = worker->GetAssignedTaskId();
if (!task_id.is_nil() && !worker->IsDead()) {
// If the worker was killed intentionally, e.g., when the driver that created
// the task that this worker is currently executing exits, the task for this
// worker has already been removed from queue, so the following are skipped.
auto const &running_tasks = local_queues_.GetRunningTasks();
// TODO(rkn): This is too heavyweight just to get the task's driver ID.
auto const it = std::find_if(
running_tasks.begin(), running_tasks.end(), [task_id](const Task &task) {
return task.GetTaskSpecification().TaskId() == task_id;
});
RAY_CHECK(running_tasks.size() != 0);
RAY_CHECK(it != running_tasks.end());
const TaskSpecification &spec = it->GetTaskSpecification();
const JobID job_id = spec.DriverId();
// TODO(rkn): Define this constant somewhere else.
std::string type = "worker_died";
std::ostringstream error_message;
error_message << "A worker died or was killed while executing task " << task_id
<< ".";
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
job_id, type, error_message.str(), current_time_ms()));
// Handle the task failure in order to raise an exception in the
// application.
TreatTaskAsFailed(spec);
task_dependency_manager_.TaskCanceled(spec.TaskId());
local_queues_.RemoveTask(spec.TaskId());
}
worker_pool_.DisconnectWorker(worker);
// If the worker was an actor, add it to the list of dead actors.
const ActorID actor_id = worker->GetActorId();
if (!actor_id.is_nil()) {
// TODO(rkn): Consider broadcasting a message to all of the other
// node managers so that they can mark the actor as dead.
RAY_LOG(DEBUG) << "The actor with ID " << actor_id << " died.";
auto actor_entry = actor_registry_.find(actor_id);
RAY_CHECK(actor_entry != actor_registry_.end());
actor_entry->second.MarkDead();
// For dead actors, if there are remaining tasks for this actor, we
// should handle them.
CleanUpTasksForDeadActor(actor_id);
}
const ClientID &client_id = gcs_client_->client_table().GetLocalClientId();
// Return the resources that were being used by this worker.
auto const &task_resources = worker->GetTaskResourceIds();
local_available_resources_.Release(task_resources);
cluster_resource_map_[client_id].Release(task_resources.ToResourceSet());
worker->ResetTaskResourceIds();
auto const &lifetime_resources = worker->GetLifetimeResourceIds();
local_available_resources_.Release(lifetime_resources);
cluster_resource_map_[client_id].Release(lifetime_resources.ToResourceSet());
worker->ResetLifetimeResourceIds();
RAY_LOG(DEBUG) << "Worker (pid=" << worker->Pid() << ") is disconnected. "
<< "driver_id: " << worker->GetAssignedDriverId();
// Since some resources may have been released, we can try to dispatch more tasks.
DispatchTasks();
} else {
// The client is a driver.
RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientID(),
/*is_dead=*/true));
const std::shared_ptr<Worker> driver = worker_pool_.GetRegisteredDriver(client);
RAY_CHECK(driver);
auto driver_id = driver->GetAssignedTaskId();
RAY_CHECK(!driver_id.is_nil());
local_queues_.RemoveDriverTaskId(driver_id);
worker_pool_.DisconnectDriver(driver);
RAY_LOG(DEBUG) << "Driver (pid=" << driver->Pid() << ") is disconnected. "
<< "driver_id: " << driver->GetAssignedDriverId();
}
}
void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) {
// Read the task submitted by the client.
auto message = flatbuffers::GetRoot<protocol::SubmitTaskRequest>(message_data);
TaskExecutionSpecification task_execution_spec(
from_flatbuf(*message->execution_dependencies()));
TaskSpecification task_spec(*message->task_spec());
Task task(task_execution_spec, task_spec);
// Submit the task to the local scheduler. Since the task was submitted
// locally, there is no uncommitted lineage.
SubmitTask(task, Lineage());
}
void NodeManager::ProcessReconstructObjectsMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
auto message = flatbuffers::GetRoot<protocol::ReconstructObjects>(message_data);
std::vector<ObjectID> required_object_ids;
for (size_t i = 0; i < message->object_ids()->size(); ++i) {
ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i));
if (!task_dependency_manager_.CheckObjectLocal(object_id)) {
if (message->fetch_only()) {
// If only a fetch is required, then do not subscribe to the
// dependencies to the task dependency manager.
RAY_CHECK_OK(object_manager_.Pull(object_id));
} else {
// If reconstruction is also required, then add any missing objects
// to the list to subscribe to in the task dependency manager. These
// objects will be pulled from remote node managers and reconstructed
// if necessary.
required_object_ids.push_back(object_id);
}
}
}
if (!required_object_ids.empty()) {
HandleClientBlocked(client, required_object_ids);
}
}
void NodeManager::ProcessWaitRequestMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
// Read the data.
auto message = flatbuffers::GetRoot<protocol::WaitRequest>(message_data);
std::vector<ObjectID> object_ids = from_flatbuf(*message->object_ids());
int64_t wait_ms = message->timeout();
uint64_t num_required_objects = static_cast<uint64_t>(message->num_ready_objects());
bool wait_local = message->wait_local();
std::vector<ObjectID> required_object_ids;
for (auto const &object_id : object_ids) {
if (!task_dependency_manager_.CheckObjectLocal(object_id)) {
// Add any missing objects to the list to subscribe to in the task
// dependency manager. These objects will be pulled from remote node
// managers and reconstructed if necessary.
required_object_ids.push_back(object_id);
}
}
bool client_blocked = !required_object_ids.empty();
if (client_blocked) {
HandleClientBlocked(client, required_object_ids);
}
ray::Status status = object_manager_.Wait(
object_ids, wait_ms, num_required_objects, wait_local,
[this, client_blocked, client](std::vector<ObjectID> found,
std::vector<ObjectID> remaining) {
// Write the data.
flatbuffers::FlatBufferBuilder fbb;
flatbuffers::Offset<protocol::WaitReply> wait_reply = protocol::CreateWaitReply(
fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining));
fbb.Finish(wait_reply);
RAY_CHECK_OK(
client->WriteMessage(static_cast<int64_t>(protocol::MessageType::WaitReply),
fbb.GetSize(), fbb.GetBufferPointer()));
// The client is unblocked now because the wait call has returned.
if (client_blocked) {
HandleClientUnblocked(client);
}
});
RAY_CHECK_OK(status);
}
void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) {
auto message = flatbuffers::GetRoot<protocol::PushErrorRequest>(message_data);
JobID job_id = from_flatbuf(*message->job_id());
auto const &type = string_from_flatbuf(*message->type());
auto const &error_message = string_from_flatbuf(*message->error_message());
double timestamp = message->timestamp();
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message,
timestamp));
}
/// Handle a new node manager connection by starting to process the messages
/// that arrive on it.
///
/// \param node_manager_client The connection to the node manager.
/// \return Void.
void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client) {
node_manager_client.ProcessMessages();
}
+51 -2
View File
@@ -59,10 +59,10 @@ class NodeManager {
///
/// \param client The client that sent the message.
/// \param message_type The message type (e.g., a flatbuffer enum).
/// \param message A pointer to the message data.
/// \param message_data A pointer to the message data.
/// \return Void.
void ProcessClientMessage(const std::shared_ptr<LocalClientConnection> &client,
int64_t message_type, const uint8_t *message);
int64_t message_type, const uint8_t *message_data);
/// Handle a new node manager connection.
///
@@ -266,6 +266,55 @@ class NodeManager {
/// \return True if the invariants are satisfied and false otherwise.
bool CheckDependencyManagerInvariant() const;
/// Process a client message of type RegisterClientRequest. Records the
/// client's ID and registers the client as either a worker or a driver.
///
/// \param client The client that sent the message.
/// \param message_data A pointer to the message data.
/// \return Void.
void ProcessRegisterClientRequestMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data);
/// Process a client message of type GetTask. Finishes the worker's assigned
/// task (if any), returns the worker to the idle pool and dispatches tasks.
///
/// \param client The client that sent the message.
/// \return Void.
void ProcessGetTaskMessage(const std::shared_ptr<LocalClientConnection> &client);
/// Process a client message of type DisconnectClient. Cleans up the state
/// kept for the disconnecting worker or driver.
///
/// \param client The client that sent the message.
/// \return Void.
void ProcessDisconnectClientMessage(
const std::shared_ptr<LocalClientConnection> &client);
/// Process a client message of type SubmitTask. Builds the task described by
/// the message and submits it with an empty lineage.
///
/// \param message_data A pointer to the message data.
/// \return Void.
void ProcessSubmitTaskMessage(const uint8_t *message_data);
/// Process a client message of type ReconstructObjects. Objects that are not
/// local are either pulled (fetch-only requests) or cause the client to be
/// marked as blocked on them.
///
/// \param client The client that sent the message.
/// \param message_data A pointer to the message data.
/// \return Void.
void ProcessReconstructObjectsMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data);
/// Process a client message of type WaitRequest. Starts a wait on the
/// requested objects and replies to the client with a WaitReply when the wait
/// completes.
///
/// \param client The client that sent the message.
/// \param message_data A pointer to the message data.
/// \return Void.
void ProcessWaitRequestMessage(const std::shared_ptr<LocalClientConnection> &client,
const uint8_t *message_data);
/// Process a client message of type PushErrorRequest. Pushes the error
/// described by the message to the driver through the GCS error table.
///
/// \param message_data A pointer to the message data.
/// \return Void.
void ProcessPushErrorRequestMessage(const uint8_t *message_data);
boost::asio::io_service &io_service_;
ObjectManager &object_manager_;
/// A Plasma object store client. This is used exclusively for creating new