From 753ba76141bc8c426702e720f69ebfa0738b8bd6 Mon Sep 17 00:00:00 2001 From: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com> Date: Fri, 7 Sep 2018 16:11:32 +0800 Subject: [PATCH] [Issue 2809][xray] Cleanup on driver detach (#2826) This change addresses issue #2809. Test #2797 has been enabled for raylet and now passes. The following should happen when a driver exits (either gracefully or ungracefully): #2797 should be enabled and pass; any actors created by the driver that are still running should be killed; any workers running tasks for the driver should be killed; any tasks for the driver in any node_manager queues should be removed; any future tasks received by a node manager for the driver should be ignored; and the driver death notification should only be received once. --- src/ray/raylet/node_manager.cc | 116 ++++++++++++++++------ src/ray/raylet/node_manager.h | 28 +++--- src/ray/raylet/scheduling_queue.cc | 56 +++++++++++ src/ray/raylet/scheduling_queue.h | 12 +++ src/ray/raylet/task_dependency_manager.cc | 28 ++++++ src/ray/raylet/task_dependency_manager.h | 6 ++ src/ray/raylet/worker.cc | 11 ++ src/ray/raylet/worker.h | 8 ++ src/ray/raylet/worker_pool.cc | 17 ++++ src/ray/raylet/worker_pool.h | 7 ++ test/multi_node_test.py | 6 +- 11 files changed, 246 insertions(+), 49 deletions(-) diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index d850cbc09..7240fb1d7 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -210,14 +210,47 @@ ray::Status NodeManager::RegisterGcs() { return ray::Status::OK(); } +void NodeManager::KillWorker(std::shared_ptr worker) { + // If we're just cleaning up a single worker, allow it some time to clean + // up its state before force killing. The client socket will be closed + // and the worker struct will be freed after the timeout. 
+ kill(worker->Pid(), SIGTERM); + + auto retry_timer = std::make_shared(io_service_); + auto retry_duration = boost::posix_time::milliseconds( + RayConfig::instance().kill_worker_timeout_milliseconds()); + retry_timer->expires_from_now(retry_duration); + retry_timer->async_wait([retry_timer, worker](const boost::system::error_code &error) { + RAY_LOG(DEBUG) << "Send SIGKILL to worker, pid=" << worker->Pid(); + // Force kill worker. + kill(worker->Pid(), SIGKILL); + }); +} + void NodeManager::HandleDriverTableUpdate( const ClientID &id, const std::vector &driver_data) { for (const auto &entry : driver_data) { RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::from_binary(entry.driver_id) << " " << entry.is_dead; if (entry.is_dead) { - // TODO: Implement cleanup on driver death. For reference, - // see handle_driver_removed_callback in local_scheduler.cc + auto driver_id = UniqueID::from_binary(entry.driver_id); + auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); + + // Kill all the workers. The actual cleanup for these workers is done + // later when we receive the DisconnectClient message from them. + for (const auto &worker : workers) { + // Mark the worker as dead so further messages from it are ignored + // (except DisconnectClient). + worker->MarkDead(); + // Then kill the worker process. + KillWorker(worker); + } + + // Remove all tasks for this driver from the scheduling queues, mark + // the results for these tasks as not required, cancel any attempts + // at reconstruction. Note that at this time the workers are likely + // alive because of the delay in killing workers. 
+ CleanUpTasksForDeadDriver(driver_id); } } } @@ -439,32 +472,10 @@ void NodeManager::HandleActorCreation(const ActorID &actor_id, } } -void NodeManager::GetActorTasksFromList(const ActorID &actor_id, - const std::list &tasks, - std::unordered_set &tasks_to_remove) { - for (auto const &task : tasks) { - auto const &spec = task.GetTaskSpecification(); - if (actor_id == spec.ActorId()) { - tasks_to_remove.insert(spec.TaskId()); - } - } -} - void NodeManager::CleanUpTasksForDeadActor(const ActorID &actor_id) { - // TODO(rkn): The code below should be cleaned up when we improve the - // SchedulingQueue API. - std::unordered_set tasks_to_remove; - - // (See design_docs/task_states.rst for the state transition diagram.) - GetActorTasksFromList(actor_id, local_queues_.GetMethodsWaitingForActorCreation(), - tasks_to_remove); - GetActorTasksFromList(actor_id, local_queues_.GetWaitingTasks(), tasks_to_remove); - GetActorTasksFromList(actor_id, local_queues_.GetPlaceableTasks(), tasks_to_remove); - GetActorTasksFromList(actor_id, local_queues_.GetReadyTasks(), tasks_to_remove); - GetActorTasksFromList(actor_id, local_queues_.GetRunningTasks(), tasks_to_remove); - GetActorTasksFromList(actor_id, local_queues_.GetBlockedTasks(), tasks_to_remove); - + auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id); auto removed_tasks = local_queues_.RemoveTasks(tasks_to_remove); + for (auto const &task : removed_tasks) { const TaskSpecification &spec = task.GetTaskSpecification(); TreatTaskAsFailed(spec); @@ -472,6 +483,13 @@ void NodeManager::CleanUpTasksForDeadActor(const ActorID &actor_id) { } } +void NodeManager::CleanUpTasksForDeadDriver(const DriverID &driver_id) { + auto tasks_to_remove = local_queues_.GetTaskIdsForDriver(driver_id); + local_queues_.RemoveTasks(tasks_to_remove); + + task_dependency_manager_.RemoveTasksAndRelatedObjects(tasks_to_remove); +} + void NodeManager::ProcessNewClient(LocalClientConnection &client) { // The new client is a worker, so begin 
listening for messages. client.ProcessMessages(); @@ -506,6 +524,18 @@ void NodeManager::ProcessClientMessage( const uint8_t *message_data) { RAY_LOG(DEBUG) << "Message of type " << message_type; + auto registered_worker = worker_pool_.GetRegisteredWorker(client); + if (registered_worker && registered_worker->IsDead()) { + // For a worker that is marked as dead (because the driver has died already), + // all the messages are ignored except DisconnectClient. + if (static_cast(message_type) != + protocol::MessageType::DisconnectClient) { + // Listen for more messages. + client->ProcessMessages(); + return; + } + } + switch (static_cast(message_type)) { case protocol::MessageType::RegisterClientRequest: { auto message = flatbuffers::GetRoot(message_data); @@ -517,11 +547,15 @@ void NodeManager::ProcessClientMessage( worker_pool_.RegisterWorker(std::move(worker)); DispatchTasks(); } else { - // Register the new driver. - JobID job_id = from_flatbuf(*message->driver_id()); - worker->AssignTaskId(job_id); + // Register the new driver. Note that the driver_id field of the + // RegisterClientRequest message is actually the ID of the driver task, while + // client_id (which is set to the worker ID) is the real driver ID that + // associates all the tasks/actors of a given driver. + const JobID driver_task_id = from_flatbuf(*message->driver_id()); + worker->AssignTaskId(driver_task_id); + worker->AssignDriverId(from_flatbuf(*message->client_id())); worker_pool_.RegisterDriver(std::move(worker)); - local_queues_.AddDriverTaskId(job_id); + local_queues_.AddDriverTaskId(driver_task_id); } } break; case protocol::MessageType::GetTask: { @@ -551,7 +585,10 @@ void NodeManager::ProcessClientMessage( // an error to the driver. // (See design_docs/task_states.rst for the state transition diagram.) 
const TaskID &task_id = worker->GetAssignedTaskId(); - if (!task_id.is_nil()) { + if (!task_id.is_nil() && !worker->IsDead()) { + // If the worker was killed intentionally, e.g., when the driver that created + // the task that this worker is currently executing exits, the task for this + // worker has already been removed from queue, so the following are skipped. auto const &running_tasks = local_queues_.GetRunningTasks(); // TODO(rkn): This is too heavyweight just to get the task's driver ID. auto const it = std::find_if( @@ -562,6 +599,7 @@ void NodeManager::ProcessClientMessage( RAY_CHECK(it != running_tasks.end()); const TaskSpecification &spec = it->GetTaskSpecification(); const JobID job_id = spec.DriverId(); + // TODO(rkn): Define this constant somewhere else. std::string type = "worker_died"; std::ostringstream error_message; @@ -606,6 +644,9 @@ void NodeManager::ProcessClientMessage( cluster_resource_map_[client_id].Release(lifetime_resources.ToResourceSet()); worker->ResetLifetimeResourceIds(); + RAY_LOG(DEBUG) << "Worker (pid=" << worker->Pid() << ") is disconnected. " + << "driver_id: " << worker->GetAssignedDriverId(); + // Since some resources may have been released, we can try to dispatch more tasks. DispatchTasks(); } else { @@ -618,6 +659,9 @@ void NodeManager::ProcessClientMessage( RAY_CHECK(!driver_id.is_nil()); local_queues_.RemoveDriverTaskId(driver_id); worker_pool_.DisconnectDriver(driver); + + RAY_LOG(DEBUG) << "Driver (pid=" << driver->Pid() << ") is disconnected. " + << "driver_id: " << driver->GetAssignedDriverId(); } return; } break; @@ -1151,6 +1195,7 @@ void NodeManager::AssignTask(Task &task) { if (status.ok()) { // We successfully assigned the task to the worker. worker->AssignTaskId(spec.TaskId()); + worker->AssignDriverId(spec.DriverId()); // If the task was an actor task, then record this execution to guarantee // consistency in the case of reconstruction. 
if (spec.IsActorTask()) { @@ -1220,7 +1265,9 @@ void NodeManager::FinishAssignedTask(Worker &worker) { actor_notification->driver_id = JobID::nil().binary(); actor_notification->node_manager_id = gcs_client_->client_table().GetLocalClientId().binary(); - RAY_LOG(DEBUG) << "Publishing actor creation: " << actor_id; + auto driver_id = task.GetTaskSpecification().DriverId(); + RAY_LOG(DEBUG) << "Publishing actor creation: " << actor_id + << " driver_id: " << driver_id; RAY_CHECK_OK(gcs_client_->actor_table().Append(JobID::nil(), actor_id, actor_notification, nullptr)); @@ -1251,6 +1298,11 @@ void NodeManager::FinishAssignedTask(Worker &worker) { // Unset the worker's assigned task. worker.AssignTaskId(TaskID::nil()); + // Unset the worker's assigned driver Id if this is not an actor. + if (!task.GetTaskSpecification().IsActorCreationTask() && + !task.GetTaskSpecification().IsActorTask()) { + worker.AssignDriverId(DriverID::nil()); + } } void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 76d95c293..07e5877de 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -192,6 +192,12 @@ class NodeManager { /// \return Void. void HandleWorkerUnblocked(std::shared_ptr worker); + /// Kill a worker. + /// + /// \param worker The worker to kill. + /// \return Void. + void KillWorker(std::shared_ptr worker); + /// Methods for actor scheduling. /// Handler for the creation of an actor, possibly on a remote node. /// @@ -201,21 +207,6 @@ class NodeManager { void HandleActorCreation(const ActorID &actor_id, const std::vector &data); - /// TODO(rkn): This should probably be removed when we improve the - /// SchedulingQueue API. This is a helper function for - /// CleanUpTasksForDeadActor. - /// - /// This essentially loops over all of the tasks in the provided list and - /// finds The IDs of the tasks that belong to the given actor. 
- /// - /// \param actor_id The actor to get the tasks for. - /// \param tasks A list of tasks to extract from. - /// \param tasks_to_remove The task IDs of the extracted tasks are inserted in - /// this vector. - /// \return Void. - void GetActorTasksFromList(const ActorID &actor_id, const std::list &tasks, - std::unordered_set &tasks_to_remove); - /// When an actor dies, loop over all of the queued tasks for that actor and /// treat them as failed. /// @@ -223,6 +214,13 @@ class NodeManager { /// \return Void. void CleanUpTasksForDeadActor(const ActorID &actor_id); + /// When a driver dies, remove all of the queued tasks for that driver and + /// stop tracking them and the objects they create in the dependency manager. + /// + /// \param driver_id The driver that died. + /// \return Void. + void CleanUpTasksForDeadDriver(const DriverID &driver_id); + /// Handle an object becoming local. This updates any local accounting, but /// does not write to any global accounting in the GCS. /// diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index 290857611..7943fc2e3 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -39,6 +39,32 @@ inline void FilterStateFromQueue(const ray::raylet::SchedulingQueue::TaskQueue & } } +// Helper function to get tasks for a driver from a given state. +inline void GetDriverTasksFromQueue(const ray::raylet::SchedulingQueue::TaskQueue &queue, + const ray::DriverID &driver_id, + std::unordered_set &task_ids) { + const auto &tasks = queue.GetTasks(); + for (const auto &task : tasks) { + auto const &spec = task.GetTaskSpecification(); + if (driver_id == spec.DriverId()) { + task_ids.insert(spec.TaskId()); + } + } +} + +// Helper function to get tasks for an actor from a given state. 
+inline void GetActorTasksFromQueue(const ray::raylet::SchedulingQueue::TaskQueue &queue, + const ray::ActorID &actor_id, + std::unordered_set &task_ids) { + const auto &tasks = queue.GetTasks(); + for (const auto &task : tasks) { + auto const &spec = task.GetTaskSpecification(); + if (actor_id == spec.ActorId()) { + task_ids.insert(spec.TaskId()); + } + } +} + } // namespace namespace ray { @@ -285,6 +311,36 @@ void SchedulingQueue::QueueBlockedTasks(const std::vector &tasks) { QueueTasks(blocked_tasks_, tasks); } +std::unordered_set SchedulingQueue::GetTaskIdsForDriver( + const DriverID &driver_id) const { + std::unordered_set task_ids; + + GetDriverTasksFromQueue(methods_waiting_for_actor_creation_, driver_id, task_ids); + GetDriverTasksFromQueue(waiting_tasks_, driver_id, task_ids); + GetDriverTasksFromQueue(placeable_tasks_, driver_id, task_ids); + GetDriverTasksFromQueue(ready_tasks_, driver_id, task_ids); + GetDriverTasksFromQueue(running_tasks_, driver_id, task_ids); + GetDriverTasksFromQueue(blocked_tasks_, driver_id, task_ids); + GetDriverTasksFromQueue(infeasible_tasks_, driver_id, task_ids); + + return task_ids; +} + +std::unordered_set SchedulingQueue::GetTaskIdsForActor( + const ActorID &actor_id) const { + std::unordered_set task_ids; + + GetActorTasksFromQueue(methods_waiting_for_actor_creation_, actor_id, task_ids); + GetActorTasksFromQueue(waiting_tasks_, actor_id, task_ids); + GetActorTasksFromQueue(placeable_tasks_, actor_id, task_ids); + GetActorTasksFromQueue(ready_tasks_, actor_id, task_ids); + GetActorTasksFromQueue(running_tasks_, actor_id, task_ids); + GetActorTasksFromQueue(blocked_tasks_, actor_id, task_ids); + GetActorTasksFromQueue(infeasible_tasks_, actor_id, task_ids); + + return task_ids; +} + void SchedulingQueue::AddDriverTaskId(const TaskID &driver_id) { auto inserted = driver_task_ids_.insert(driver_id); RAY_CHECK(inserted.second); diff --git a/src/ray/raylet/scheduling_queue.h b/src/ray/raylet/scheduling_queue.h index 
6b9635586..d8ecf6ae9 100644 --- a/src/ray/raylet/scheduling_queue.h +++ b/src/ray/raylet/scheduling_queue.h @@ -174,6 +174,18 @@ class SchedulingQueue { /// \param filter_state The task state to filter out. void FilterState(std::unordered_set &task_ids, TaskState filter_state) const; + /// \brief Get all the task IDs for a driver. + /// + /// \param driver_id All the tasks that have the given driver_id are returned. + /// \return All the tasks that have the given driver ID. + std::unordered_set GetTaskIdsForDriver(const DriverID &driver_id) const; + + /// \brief Get all the task IDs for an actor. + /// + /// \param actor_id All the tasks that have the given actor_id are returned. + /// \return All the tasks that have the given actor ID. + std::unordered_set GetTaskIdsForActor(const ActorID &actor_id) const; + /// \brief Return all resource demand associated with the ready queue. /// /// \return Aggregate resource demand from ready tasks. diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index 247a83fcb..df4546efa 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -174,6 +174,7 @@ void TaskDependencyManager::UnsubscribeDependencies(const TaskID &task_id) { // Remove the task from the table of subscribed tasks. 
auto it = task_dependencies_.find(task_id); RAY_CHECK(it != task_dependencies_.end()); + const TaskDependencies task_entry = std::move(it->second); task_dependencies_.erase(it); @@ -297,6 +298,33 @@ void TaskDependencyManager::TaskCanceled(const TaskID &task_id) { } } +void TaskDependencyManager::RemoveTasksAndRelatedObjects( + const std::unordered_set &task_ids) { + if (task_ids.empty()) { + return; + } + + for (auto it = task_ids.begin(); it != task_ids.end(); it++) { + task_dependencies_.erase(*it); + required_tasks_.erase(*it); + pending_tasks_.erase(*it); + } + + // TODO: the size of required_objects_ could be large; consider adding + // an index if this turns out to be a perf problem. + for (auto it = required_objects_.begin(); it != required_objects_.end();) { + const auto object_id = *it; + TaskID creating_task_id = ComputeTaskId(object_id); + if (task_ids.find(creating_task_id) != task_ids.end()) { + object_manager_.CancelPull(object_id); + reconstruction_policy_.Cancel(object_id); + it = required_objects_.erase(it); + } else { + it++; + } + } +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index ea795ad43..84c47bd16 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -105,6 +105,12 @@ class TaskDependencyManager { /// \return Return a vector of TaskIDs for tasks registered as pending. std::vector GetPendingTasks() const; + /// Remove all of the tasks specified, and all the objects created by + /// these tasks, from the task dependency manager. + /// + /// \param task_ids The collection of task IDs. 
+ void RemoveTasksAndRelatedObjects(const std::unordered_set &task_ids); + private: using ObjectDependencyMap = std::unordered_map>; diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index b7a70297f..4cbe17104 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -18,8 +18,13 @@ Worker::Worker(pid_t pid, const Language &language, connection_(connection), assigned_task_id_(TaskID::nil()), actor_id_(ActorID::nil()), + dead_(false), blocked_(false) {} +void Worker::MarkDead() { dead_ = true; } + +bool Worker::IsDead() const { return dead_; } + void Worker::MarkBlocked() { blocked_ = true; } void Worker::MarkUnblocked() { blocked_ = false; } @@ -34,6 +39,12 @@ void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; const TaskID &Worker::GetAssignedTaskId() const { return assigned_task_id_; } +void Worker::AssignDriverId(const DriverID &driver_id) { + assigned_driver_id_ = driver_id; +} + +const DriverID &Worker::GetAssignedDriverId() const { return assigned_driver_id_; } + void Worker::AssignActorId(const ActorID &actor_id) { RAY_CHECK(actor_id_.is_nil()) << "A worker that is already an actor cannot be assigned an actor ID again."; diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index fe200a2ca..c6ec7bac8 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -21,6 +21,8 @@ class Worker { std::shared_ptr connection); /// A destructor responsible for freeing all worker state. ~Worker() {} + void MarkDead(); + bool IsDead() const; void MarkBlocked(); void MarkUnblocked(); bool IsBlocked() const; @@ -29,6 +31,8 @@ class Worker { Language GetLanguage() const; void AssignTaskId(const TaskID &task_id); const TaskID &GetAssignedTaskId() const; + void AssignDriverId(const DriverID &driver_id); + const DriverID &GetAssignedDriverId() const; void AssignActorId(const ActorID &actor_id); const ActorID &GetActorId() const; /// Return the worker's connection. 
@@ -53,8 +57,12 @@ class Worker { std::shared_ptr connection_; /// The worker's currently assigned task. TaskID assigned_task_id_; + /// Driver ID for the worker's current assigned task. + DriverID assigned_driver_id_; /// The worker's actor ID. If this is nil, then the worker is not an actor. ActorID actor_id_; + /// Whether the worker is dead. + bool dead_; /// Whether the worker is blocked. Workers become blocked in a `ray.get`, if /// they require a data dependency while executing a task. bool blocked_; diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index b31318302..e06743b5c 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -232,6 +232,23 @@ inline WorkerPool::State &WorkerPool::GetStateForLanguage(const Language &langua return state->second; } +std::vector> WorkerPool::GetWorkersRunningTasksForDriver( + const DriverID &driver_id) const { + std::vector> workers; + + for (const auto &entry : states_by_lang_) { + for (const auto &worker : entry.second.registered_workers) { + RAY_LOG(DEBUG) << "worker: pid : " << worker->Pid() + << " driver_id: " << worker->GetAssignedDriverId(); + if (worker->GetAssignedDriverId() == driver_id) { + workers.push_back(worker); + } + } + } + + return workers; +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 368f0e4f3..528cc917d 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -111,6 +111,13 @@ class WorkerPool { /// \return The total count of all workers (actor and non-actor) in the pool. uint32_t Size(const Language &language) const; + /// Get all the workers which are running tasks for a given driver. + /// + /// \param driver_id The driver ID. + /// \return A list containing all the workers which are running tasks for the driver. 
+ std::vector> GetWorkersRunningTasksForDriver( + const DriverID &driver_id) const; + protected: /// A map from the pids of starting worker processes /// to the number of their unregistered workers. diff --git a/test/multi_node_test.py b/test/multi_node_test.py index c26aa8d11..657c03710 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -239,7 +239,9 @@ def ray_start_head_with_resources(): subprocess.Popen(["ray", "stop"]).wait() -@pytest.mark.skip(reason="This test does not work yet.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") def test_drivers_release_resources(ray_start_head_with_resources): redis_address = ray_start_head_with_resources @@ -278,7 +280,7 @@ print("success") driver_script2 = (driver_script1 + "import sys\nsys.stdout.flush()\ntime.sleep(10 ** 6)\n") - def wait_for_success_output(process_handle, timeout=100): + def wait_for_success_output(process_handle, timeout=10): # Wait until the process prints "success" and then return. start_time = time.time() while time.time() - start_time < timeout: