diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index d10d6d390..9ec2399c8 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -538,227 +538,34 @@ void NodeManager::ProcessClientMessage( switch (static_cast(message_type)) { case protocol::MessageType::RegisterClientRequest: { - auto message = flatbuffers::GetRoot(message_data); - client->SetClientID(from_flatbuf(*message->client_id())); - auto worker = - std::make_shared(message->worker_pid(), message->language(), client); - if (message->is_worker()) { - // Register the new worker. - worker_pool_.RegisterWorker(std::move(worker)); - DispatchTasks(); - } else { - // Register the new driver. Note that here the driver_id in RegisterClientRequest - // message is actually the ID of the driver task, while client_id represents the - // real driver ID, which can associate all the tasks/actors for a given driver, - // which is set to the worker ID. - const JobID driver_task_id = from_flatbuf(*message->driver_id()); - worker->AssignTaskId(driver_task_id); - worker->AssignDriverId(from_flatbuf(*message->client_id())); - worker_pool_.RegisterDriver(std::move(worker)); - local_queues_.AddDriverTaskId(driver_task_id); - } + ProcessRegisterClientRequestMessage(client, message_data); } break; case protocol::MessageType::GetTask: { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); - RAY_CHECK(worker); - // If the worker was assigned a task, mark it as finished. - if (!worker->GetAssignedTaskId().is_nil()) { - FinishAssignedTask(*worker); - } - // Return the worker to the idle pool. - worker_pool_.PushWorker(std::move(worker)); - // Local resource availability changed: invoke scheduling policy for local node. - const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - cluster_resource_map_[local_client_id].SetLoadResources( - local_queues_.GetResourceLoad()); - // Call task dispatch to assign work to the new worker. - DispatchTasks(); - + ProcessGetTaskMessage(client); } break; case protocol::MessageType::DisconnectClient: { - // Remove the dead worker from the pool and stop listening for messages. - const std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); - - if (worker) { - // The client is a worker. Handle the case where the worker is killed - // while executing a task. Clean up the assigned task's resources, push - // an error to the driver. - // (See design_docs/task_states.rst for the state transition diagram.) - const TaskID &task_id = worker->GetAssignedTaskId(); - if (!task_id.is_nil() && !worker->IsDead()) { - // If the worker was killed intentionally, e.g., when the driver that created - // the task that this worker is currently executing exits, the task for this - // worker has already been removed from queue, so the following are skipped. - auto const &running_tasks = local_queues_.GetRunningTasks(); - // TODO(rkn): This is too heavyweight just to get the task's driver ID. - auto const it = std::find_if( - running_tasks.begin(), running_tasks.end(), [task_id](const Task &task) { - return task.GetTaskSpecification().TaskId() == task_id; - }); - RAY_CHECK(running_tasks.size() != 0); - RAY_CHECK(it != running_tasks.end()); - const TaskSpecification &spec = it->GetTaskSpecification(); - const JobID job_id = spec.DriverId(); - - // TODO(rkn): Define this constant somewhere else. - std::string type = "worker_died"; - std::ostringstream error_message; - error_message << "A worker died or was killed while executing task " << task_id - << "."; - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - job_id, type, error_message.str(), current_time_ms())); - - // Handle the task failure in order to raise an exception in the - // application. - TreatTaskAsFailed(spec); - task_dependency_manager_.TaskCanceled(spec.TaskId()); - local_queues_.RemoveTask(spec.TaskId()); - } - - worker_pool_.DisconnectWorker(worker); - - // If the worker was an actor, add it to the list of dead actors. - const ActorID actor_id = worker->GetActorId(); - if (!actor_id.is_nil()) { - // TODO(rkn): Consider broadcasting a message to all of the other - // node managers so that they can mark the actor as dead. - RAY_LOG(DEBUG) << "The actor with ID " << actor_id << " died."; - auto actor_entry = actor_registry_.find(actor_id); - RAY_CHECK(actor_entry != actor_registry_.end()); - actor_entry->second.MarkDead(); - // For dead actors, if there are remaining tasks for this actor, we - // should handle them. - CleanUpTasksForDeadActor(actor_id); - } - - const ClientID &client_id = gcs_client_->client_table().GetLocalClientId(); - - // Return the resources that were being used by this worker. - auto const &task_resources = worker->GetTaskResourceIds(); - local_available_resources_.Release(task_resources); - cluster_resource_map_[client_id].Release(task_resources.ToResourceSet()); - worker->ResetTaskResourceIds(); - - auto const &lifetime_resources = worker->GetLifetimeResourceIds(); - local_available_resources_.Release(lifetime_resources); - cluster_resource_map_[client_id].Release(lifetime_resources.ToResourceSet()); - worker->ResetLifetimeResourceIds(); - - RAY_LOG(DEBUG) << "Worker (pid=" << worker->Pid() << ") is disconnected. " - << "driver_id: " << worker->GetAssignedDriverId(); - - // Since some resources may have been released, we can try to dispatch more tasks. - DispatchTasks(); - } else { - // The client is a driver. - RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientID(), - /*is_dead=*/true)); - const std::shared_ptr driver = worker_pool_.GetRegisteredDriver(client); - RAY_CHECK(driver); - auto driver_id = driver->GetAssignedTaskId(); - RAY_CHECK(!driver_id.is_nil()); - local_queues_.RemoveDriverTaskId(driver_id); - worker_pool_.DisconnectDriver(driver); - - RAY_LOG(DEBUG) << "Driver (pid=" << driver->Pid() << ") is disconnected. " - << "driver_id: " << driver->GetAssignedDriverId(); - } + ProcessDisconnectClientMessage(client); + // We don't need to receive future messages from this client, + // because it's already disconnected. return; } break; case protocol::MessageType::SubmitTask: { - // Read the task submitted by the client. - auto message = flatbuffers::GetRoot(message_data); - TaskExecutionSpecification task_execution_spec( - from_flatbuf(*message->execution_dependencies())); - TaskSpecification task_spec(*message->task_spec()); - Task task(task_execution_spec, task_spec); - // Submit the task to the local scheduler. Since the task was submitted - // locally, there is no uncommitted lineage. - SubmitTask(task, Lineage()); + ProcessSubmitTaskMessage(message_data); } break; case protocol::MessageType::ReconstructObjects: { - auto message = flatbuffers::GetRoot(message_data); - std::vector required_object_ids; - for (size_t i = 0; i < message->object_ids()->size(); ++i) { - ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i)); - if (!task_dependency_manager_.CheckObjectLocal(object_id)) { - if (message->fetch_only()) { - // If only a fetch is required, then do not subscribe to the - // dependencies to the task dependency manager. - RAY_CHECK_OK(object_manager_.Pull(object_id)); - } else { - // If reconstruction is also required, then add any missing objects - // to the list to subscribe to in the task dependency manager. These - // objects will be pulled from remote node managers and reconstructed - // if necessary. - required_object_ids.push_back(object_id); - } - } - } - - if (!required_object_ids.empty()) { - HandleClientBlocked(client, required_object_ids); - } + ProcessReconstructObjectsMessage(client, message_data); } break; case protocol::MessageType::NotifyUnblocked: { HandleClientUnblocked(client); } break; case protocol::MessageType::WaitRequest: { - // Read the data. - auto message = flatbuffers::GetRoot(message_data); - std::vector object_ids = from_flatbuf(*message->object_ids()); - int64_t wait_ms = message->timeout(); - uint64_t num_required_objects = static_cast(message->num_ready_objects()); - bool wait_local = message->wait_local(); - - std::vector required_object_ids; - for (auto const &object_id : object_ids) { - if (!task_dependency_manager_.CheckObjectLocal(object_id)) { - // Add any missing objects to the list to subscribe to in the task - // dependency manager. These objects will be pulled from remote node - // managers and reconstructed if necessary. - required_object_ids.push_back(object_id); - } - } - - bool client_blocked = !required_object_ids.empty(); - if (client_blocked) { - HandleClientBlocked(client, required_object_ids); - } - - ray::Status status = object_manager_.Wait( - object_ids, wait_ms, num_required_objects, wait_local, - [this, client_blocked, client](std::vector found, - std::vector remaining) { - // Write the data. - flatbuffers::FlatBufferBuilder fbb; - flatbuffers::Offset wait_reply = protocol::CreateWaitReply( - fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining)); - fbb.Finish(wait_reply); - RAY_CHECK_OK( - client->WriteMessage(static_cast(protocol::MessageType::WaitReply), - fbb.GetSize(), fbb.GetBufferPointer())); - // The client is unblocked now because the wait call has returned. - if (client_blocked) { - HandleClientUnblocked(client); - } - }); - RAY_CHECK_OK(status); + ProcessWaitRequestMessage(client, message_data); } break; case protocol::MessageType::PushErrorRequest: { - auto message = flatbuffers::GetRoot(message_data); - - JobID job_id = from_flatbuf(*message->job_id()); - auto const &type = string_from_flatbuf(*message->type()); - auto const &error_message = string_from_flatbuf(*message->error_message()); - double timestamp = message->timestamp(); - - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message, - timestamp)); + ProcessPushErrorRequestMessage(message_data); } break; case protocol::MessageType::PushProfileEventsRequest: { auto message = flatbuffers::GetRoot(message_data); - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*message)); } break; case protocol::MessageType::FreeObjectsInObjectStoreRequest: { @@ -775,6 +582,232 @@ void NodeManager::ProcessClientMessage( client->ProcessMessages(); } +void NodeManager::ProcessRegisterClientRequestMessage( + const std::shared_ptr &client, const uint8_t *message_data) { + auto message = flatbuffers::GetRoot(message_data); + client->SetClientID(from_flatbuf(*message->client_id())); + auto worker = + std::make_shared(message->worker_pid(), message->language(), client); + if (message->is_worker()) { + // Register the new worker. + worker_pool_.RegisterWorker(std::move(worker)); + DispatchTasks(); + } else { + // Register the new driver. Note that here the driver_id in RegisterClientRequest + // message is actually the ID of the driver task, while client_id represents the + // real driver ID, which can associate all the tasks/actors for a given driver, + // which is set to the worker ID. + const JobID driver_task_id = from_flatbuf(*message->driver_id()); + worker->AssignTaskId(driver_task_id); + worker->AssignDriverId(from_flatbuf(*message->client_id())); + worker_pool_.RegisterDriver(std::move(worker)); + local_queues_.AddDriverTaskId(driver_task_id); + } +} + +void NodeManager::ProcessGetTaskMessage( + const std::shared_ptr &client) { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + RAY_CHECK(worker); + // If the worker was assigned a task, mark it as finished. + if (!worker->GetAssignedTaskId().is_nil()) { + FinishAssignedTask(*worker); + } + // Return the worker to the idle pool. + worker_pool_.PushWorker(std::move(worker)); + // Local resource availability changed: invoke scheduling policy for local node. + const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + cluster_resource_map_[local_client_id].SetLoadResources( + local_queues_.GetResourceLoad()); + // Call task dispatch to assign work to the new worker. + DispatchTasks(); +} + +void NodeManager::ProcessDisconnectClientMessage( + const std::shared_ptr &client) { + // Remove the dead worker from the pool and stop listening for messages. + const std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + + if (worker) { + // The client is a worker. Handle the case where the worker is killed + // while executing a task. Clean up the assigned task's resources, push + // an error to the driver. + // (See design_docs/task_states.rst for the state transition diagram.) + const TaskID &task_id = worker->GetAssignedTaskId(); + if (!task_id.is_nil() && !worker->IsDead()) { + // If the worker was killed intentionally, e.g., when the driver that created + // the task that this worker is currently executing exits, the task for this + // worker has already been removed from queue, so the following are skipped. + auto const &running_tasks = local_queues_.GetRunningTasks(); + // TODO(rkn): This is too heavyweight just to get the task's driver ID. + auto const it = std::find_if( + running_tasks.begin(), running_tasks.end(), [task_id](const Task &task) { + return task.GetTaskSpecification().TaskId() == task_id; + }); + RAY_CHECK(running_tasks.size() != 0); + RAY_CHECK(it != running_tasks.end()); + const TaskSpecification &spec = it->GetTaskSpecification(); + const JobID job_id = spec.DriverId(); + + // TODO(rkn): Define this constant somewhere else. + std::string type = "worker_died"; + std::ostringstream error_message; + error_message << "A worker died or was killed while executing task " << task_id + << "."; + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( + job_id, type, error_message.str(), current_time_ms())); + + // Handle the task failure in order to raise an exception in the + // application. + TreatTaskAsFailed(spec); + task_dependency_manager_.TaskCanceled(spec.TaskId()); + local_queues_.RemoveTask(spec.TaskId()); + } + + worker_pool_.DisconnectWorker(worker); + + // If the worker was an actor, add it to the list of dead actors. + const ActorID actor_id = worker->GetActorId(); + if (!actor_id.is_nil()) { + // TODO(rkn): Consider broadcasting a message to all of the other + // node managers so that they can mark the actor as dead. + RAY_LOG(DEBUG) << "The actor with ID " << actor_id << " died."; + auto actor_entry = actor_registry_.find(actor_id); + RAY_CHECK(actor_entry != actor_registry_.end()); + actor_entry->second.MarkDead(); + // For dead actors, if there are remaining tasks for this actor, we + // should handle them. + CleanUpTasksForDeadActor(actor_id); + } + + const ClientID &client_id = gcs_client_->client_table().GetLocalClientId(); + + // Return the resources that were being used by this worker. + auto const &task_resources = worker->GetTaskResourceIds(); + local_available_resources_.Release(task_resources); + cluster_resource_map_[client_id].Release(task_resources.ToResourceSet()); + worker->ResetTaskResourceIds(); + + auto const &lifetime_resources = worker->GetLifetimeResourceIds(); + local_available_resources_.Release(lifetime_resources); + cluster_resource_map_[client_id].Release(lifetime_resources.ToResourceSet()); + worker->ResetLifetimeResourceIds(); + + RAY_LOG(DEBUG) << "Worker (pid=" << worker->Pid() << ") is disconnected. " + << "driver_id: " << worker->GetAssignedDriverId(); + + // Since some resources may have been released, we can try to dispatch more tasks. + DispatchTasks(); + } else { + // The client is a driver. + RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientID(), + /*is_dead=*/true)); + const std::shared_ptr driver = worker_pool_.GetRegisteredDriver(client); + RAY_CHECK(driver); + auto driver_id = driver->GetAssignedTaskId(); + RAY_CHECK(!driver_id.is_nil()); + local_queues_.RemoveDriverTaskId(driver_id); + worker_pool_.DisconnectDriver(driver); + + RAY_LOG(DEBUG) << "Driver (pid=" << driver->Pid() << ") is disconnected. " + << "driver_id: " << driver->GetAssignedDriverId(); + } +} + +void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) { + // Read the task submitted by the client. + auto message = flatbuffers::GetRoot(message_data); + TaskExecutionSpecification task_execution_spec( + from_flatbuf(*message->execution_dependencies())); + TaskSpecification task_spec(*message->task_spec()); + Task task(task_execution_spec, task_spec); + // Submit the task to the local scheduler. Since the task was submitted + // locally, there is no uncommitted lineage. + SubmitTask(task, Lineage()); +} + +void NodeManager::ProcessReconstructObjectsMessage( + const std::shared_ptr &client, const uint8_t *message_data) { + auto message = flatbuffers::GetRoot(message_data); + std::vector required_object_ids; + for (size_t i = 0; i < message->object_ids()->size(); ++i) { + ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i)); + if (!task_dependency_manager_.CheckObjectLocal(object_id)) { + if (message->fetch_only()) { + // If only a fetch is required, then do not subscribe to the + // dependencies to the task dependency manager. + RAY_CHECK_OK(object_manager_.Pull(object_id)); + } else { + // If reconstruction is also required, then add any missing objects + // to the list to subscribe to in the task dependency manager. These + // objects will be pulled from remote node managers and reconstructed + // if necessary. + required_object_ids.push_back(object_id); + } + } + } + + if (!required_object_ids.empty()) { + HandleClientBlocked(client, required_object_ids); + } +} + +void NodeManager::ProcessWaitRequestMessage( + const std::shared_ptr &client, const uint8_t *message_data) { + // Read the data. + auto message = flatbuffers::GetRoot(message_data); + std::vector object_ids = from_flatbuf(*message->object_ids()); + int64_t wait_ms = message->timeout(); + uint64_t num_required_objects = static_cast(message->num_ready_objects()); + bool wait_local = message->wait_local(); + + std::vector required_object_ids; + for (auto const &object_id : object_ids) { + if (!task_dependency_manager_.CheckObjectLocal(object_id)) { + // Add any missing objects to the list to subscribe to in the task + // dependency manager. These objects will be pulled from remote node + // managers and reconstructed if necessary. + required_object_ids.push_back(object_id); + } + } + + bool client_blocked = !required_object_ids.empty(); + if (client_blocked) { + HandleClientBlocked(client, required_object_ids); + } + + ray::Status status = object_manager_.Wait( + object_ids, wait_ms, num_required_objects, wait_local, + [this, client_blocked, client](std::vector found, + std::vector remaining) { + // Write the data. + flatbuffers::FlatBufferBuilder fbb; + flatbuffers::Offset wait_reply = protocol::CreateWaitReply( + fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining)); + fbb.Finish(wait_reply); + RAY_CHECK_OK( + client->WriteMessage(static_cast(protocol::MessageType::WaitReply), + fbb.GetSize(), fbb.GetBufferPointer())); + // The client is unblocked now because the wait call has returned. + if (client_blocked) { + HandleClientUnblocked(client); + } + }); + RAY_CHECK_OK(status); +} + +void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) { + auto message = flatbuffers::GetRoot(message_data); + + JobID job_id = from_flatbuf(*message->job_id()); + auto const &type = string_from_flatbuf(*message->type()); + auto const &error_message = string_from_flatbuf(*message->error_message()); + double timestamp = message->timestamp(); + + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message, + timestamp)); +} + void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client) { node_manager_client.ProcessMessages(); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 7de408736..d8025a6c8 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -59,10 +59,10 @@ class NodeManager { /// /// \param client The client that sent the message. /// \param message_type The message type (e.g., a flatbuffer enum). - /// \param message A pointer to the message data. + /// \param message_data A pointer to the message data. /// \return Void. void ProcessClientMessage(const std::shared_ptr &client, - int64_t message_type, const uint8_t *message); + int64_t message_type, const uint8_t *message_data); /// Handle a new node manager connection. /// @@ -266,6 +266,55 @@ class NodeManager { /// \return True if the invariants are satisfied and false otherwise. bool CheckDependencyManagerInvariant() const; + /// Process client message of RegisterClientRequest + // + /// \param client The client that sent the message. + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessRegisterClientRequestMessage( + const std::shared_ptr &client, const uint8_t *message_data); + + /// Process client message of GetTask + // + /// \param client The client that sent the message. + /// \return Void. + void ProcessGetTaskMessage(const std::shared_ptr &client); + + /// Process client message of DisconnectClient + // + /// \param client The client that sent the message. + /// \return Void. + void ProcessDisconnectClientMessage( + const std::shared_ptr &client); + + /// Process client message of SubmitTask + // + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessSubmitTaskMessage(const uint8_t *message_data); + + /// Process client message of ReconstructObjects + // + /// \param client The client that sent the message. + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessReconstructObjectsMessage( + const std::shared_ptr &client, const uint8_t *message_data); + + /// Process client message of WaitRequest + // + /// \param client The client that sent the message. + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessWaitRequestMessage(const std::shared_ptr &client, + const uint8_t *message_data); + + /// Process client message of PushErrorRequest + // + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessPushErrorRequestMessage(const uint8_t *message_data); + boost::asio::io_service &io_service_; ObjectManager &object_manager_; /// A Plasma object store client. This is used exclusively for creating new