[Issue 2809][xray] Cleanup on driver detach (#2826)

This change addresses issue #2809. Test #2797 is now enabled for raylet and passes.

The following should happen when a driver exits, whether gracefully or ungracefully:

#2797 should be enabled and pass.
Any actors created by the driver that are still running should be killed.
Any workers running tasks for the driver should be killed.
Any tasks for the driver in any node_manager queues should be removed.
Any future tasks received by a node manager for the driver should be ignored.
The driver death notification should only be received once.
This commit is contained in:
Zhijun Fu
2018-09-07 16:11:32 +08:00
committed by Hao Chen
parent 3f6ed537a4
commit 753ba76141
11 changed files with 246 additions and 49 deletions
+84 -32
View File
@@ -210,14 +210,47 @@ ray::Status NodeManager::RegisterGcs() {
return ray::Status::OK();
}
void NodeManager::KillWorker(std::shared_ptr<Worker> worker) {
// If we're just cleaning up a single worker, allow it some time to clean
// up its state before force killing. The client socket will be closed
// and the worker struct will be freed after the timeout.
kill(worker->Pid(), SIGTERM);
auto retry_timer = std::make_shared<boost::asio::deadline_timer>(io_service_);
auto retry_duration = boost::posix_time::milliseconds(
RayConfig::instance().kill_worker_timeout_milliseconds());
retry_timer->expires_from_now(retry_duration);
retry_timer->async_wait([retry_timer, worker](const boost::system::error_code &error) {
RAY_LOG(DEBUG) << "Send SIGKILL to worker, pid=" << worker->Pid();
// Force kill worker.
kill(worker->Pid(), SIGKILL);
});
}
// Handle a batch of driver table entries. For every entry flagged as dead,
// the workers assigned to that driver are marked dead and killed, and the
// driver's tasks are cleaned up.
//
// \param id Unused by this handler.
// \param driver_data The driver table entries to process.
void NodeManager::HandleDriverTableUpdate(
    const ClientID &id, const std::vector<DriverTableDataT> &driver_data) {
  for (const auto &entry : driver_data) {
    RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::from_binary(entry.driver_id)
                   << " " << entry.is_dead;
    if (!entry.is_dead) {
      continue;
    }

    const auto dead_driver_id = UniqueID::from_binary(entry.driver_id);
    // Kill every worker assigned to the dead driver. The actual cleanup for
    // these workers happens later, when their DisconnectClient message
    // arrives.
    for (const auto &worker :
         worker_pool_.GetWorkersRunningTasksForDriver(dead_driver_id)) {
      // Flag the worker first so that any further message from it is ignored
      // (except DisconnectClient), then kill its process.
      worker->MarkDead();
      KillWorker(worker);
    }

    // Remove all tasks for this driver from the scheduling queues, mark the
    // results for these tasks as not required, and cancel any attempts at
    // reconstruction. The workers are likely still alive at this point
    // because of the delay in killing them.
    CleanUpTasksForDeadDriver(dead_driver_id);
  }
}
@@ -439,32 +472,10 @@ void NodeManager::HandleActorCreation(const ActorID &actor_id,
}
}
// Collect the IDs of all tasks in `tasks` whose specification names the given
// actor.
//
// \param actor_id The actor to match against.
// \param tasks The list of tasks to scan.
// \param tasks_to_remove Receives the task ID of every matching task.
void NodeManager::GetActorTasksFromList(const ActorID &actor_id,
                                        const std::list<Task> &tasks,
                                        std::unordered_set<TaskID> &tasks_to_remove) {
  for (const auto &candidate : tasks) {
    const auto &candidate_spec = candidate.GetTaskSpecification();
    const bool belongs_to_actor = (actor_id == candidate_spec.ActorId());
    if (belongs_to_actor) {
      tasks_to_remove.insert(candidate_spec.TaskId());
    }
  }
}
void NodeManager::CleanUpTasksForDeadActor(const ActorID &actor_id) {
// TODO(rkn): The code below should be cleaned up when we improve the
// SchedulingQueue API.
std::unordered_set<TaskID> tasks_to_remove;
// (See design_docs/task_states.rst for the state transition diagram.)
GetActorTasksFromList(actor_id, local_queues_.GetMethodsWaitingForActorCreation(),
tasks_to_remove);
GetActorTasksFromList(actor_id, local_queues_.GetWaitingTasks(), tasks_to_remove);
GetActorTasksFromList(actor_id, local_queues_.GetPlaceableTasks(), tasks_to_remove);
GetActorTasksFromList(actor_id, local_queues_.GetReadyTasks(), tasks_to_remove);
GetActorTasksFromList(actor_id, local_queues_.GetRunningTasks(), tasks_to_remove);
GetActorTasksFromList(actor_id, local_queues_.GetBlockedTasks(), tasks_to_remove);
auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id);
auto removed_tasks = local_queues_.RemoveTasks(tasks_to_remove);
for (auto const &task : removed_tasks) {
const TaskSpecification &spec = task.GetTaskSpecification();
TreatTaskAsFailed(spec);
@@ -472,6 +483,13 @@ void NodeManager::CleanUpTasksForDeadActor(const ActorID &actor_id) {
}
}
// Drop every queued task that belongs to a dead driver.
//
// The matching task IDs are collected from the local scheduling queues, then
// removed from the task dependency manager and from the queues themselves.
//
// \param driver_id The driver that died.
void NodeManager::CleanUpTasksForDeadDriver(const DriverID &driver_id) {
  auto tasks_to_remove = local_queues_.GetTaskIdsForDriver(driver_id);
  // Clean up the dependency manager while the ID set is still intact. It
  // takes the set by const reference, so it cannot modify it.
  task_dependency_manager_.RemoveTasksAndRelatedObjects(tasks_to_remove);
  // NOTE(review): SchedulingQueue::RemoveTasks is handed a mutable set and its
  // definition is not visible here. If it erases IDs from the set as it
  // dequeues them, calling it first would hand the dependency manager an empty
  // set and that cleanup would be silently skipped. Calling it last is correct
  // either way.
  local_queues_.RemoveTasks(tasks_to_remove);
}
void NodeManager::ProcessNewClient(LocalClientConnection &client) {
// The new client is a worker, so begin listening for messages.
client.ProcessMessages();
@@ -506,6 +524,18 @@ void NodeManager::ProcessClientMessage(
const uint8_t *message_data) {
RAY_LOG(DEBUG) << "Message of type " << message_type;
auto registered_worker = worker_pool_.GetRegisteredWorker(client);
if (registered_worker && registered_worker->IsDead()) {
// For a worker that is marked as dead (because the driver has died already),
// all the messages are ignored except DisconnectClient.
if (static_cast<protocol::MessageType>(message_type) !=
protocol::MessageType::DisconnectClient) {
// Listen for more messages.
client->ProcessMessages();
return;
}
}
switch (static_cast<protocol::MessageType>(message_type)) {
case protocol::MessageType::RegisterClientRequest: {
auto message = flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
@@ -517,11 +547,15 @@ void NodeManager::ProcessClientMessage(
worker_pool_.RegisterWorker(std::move(worker));
DispatchTasks();
} else {
// Register the new driver.
JobID job_id = from_flatbuf(*message->driver_id());
worker->AssignTaskId(job_id);
// Register the new driver. Note that here the driver_id in RegisterClientRequest
// message is actually the ID of the driver task, while client_id represents the
// real driver ID, which can associate all the tasks/actors for a given driver,
// which is set to the worker ID.
const JobID driver_task_id = from_flatbuf(*message->driver_id());
worker->AssignTaskId(driver_task_id);
worker->AssignDriverId(from_flatbuf(*message->client_id()));
worker_pool_.RegisterDriver(std::move(worker));
local_queues_.AddDriverTaskId(job_id);
local_queues_.AddDriverTaskId(driver_task_id);
}
} break;
case protocol::MessageType::GetTask: {
@@ -551,7 +585,10 @@ void NodeManager::ProcessClientMessage(
// an error to the driver.
// (See design_docs/task_states.rst for the state transition diagram.)
const TaskID &task_id = worker->GetAssignedTaskId();
if (!task_id.is_nil()) {
if (!task_id.is_nil() && !worker->IsDead()) {
// If the worker was killed intentionally, e.g., when the driver that created
// the task that this worker is currently executing exits, the task for this
// worker has already been removed from queue, so the following are skipped.
auto const &running_tasks = local_queues_.GetRunningTasks();
// TODO(rkn): This is too heavyweight just to get the task's driver ID.
auto const it = std::find_if(
@@ -562,6 +599,7 @@ void NodeManager::ProcessClientMessage(
RAY_CHECK(it != running_tasks.end());
const TaskSpecification &spec = it->GetTaskSpecification();
const JobID job_id = spec.DriverId();
// TODO(rkn): Define this constant somewhere else.
std::string type = "worker_died";
std::ostringstream error_message;
@@ -606,6 +644,9 @@ void NodeManager::ProcessClientMessage(
cluster_resource_map_[client_id].Release(lifetime_resources.ToResourceSet());
worker->ResetLifetimeResourceIds();
RAY_LOG(DEBUG) << "Worker (pid=" << worker->Pid() << ") is disconnected. "
<< "driver_id: " << worker->GetAssignedDriverId();
// Since some resources may have been released, we can try to dispatch more tasks.
DispatchTasks();
} else {
@@ -618,6 +659,9 @@ void NodeManager::ProcessClientMessage(
RAY_CHECK(!driver_id.is_nil());
local_queues_.RemoveDriverTaskId(driver_id);
worker_pool_.DisconnectDriver(driver);
RAY_LOG(DEBUG) << "Driver (pid=" << driver->Pid() << ") is disconnected. "
<< "driver_id: " << driver->GetAssignedDriverId();
}
return;
} break;
@@ -1151,6 +1195,7 @@ void NodeManager::AssignTask(Task &task) {
if (status.ok()) {
// We successfully assigned the task to the worker.
worker->AssignTaskId(spec.TaskId());
worker->AssignDriverId(spec.DriverId());
// If the task was an actor task, then record this execution to guarantee
// consistency in the case of reconstruction.
if (spec.IsActorTask()) {
@@ -1220,7 +1265,9 @@ void NodeManager::FinishAssignedTask(Worker &worker) {
actor_notification->driver_id = JobID::nil().binary();
actor_notification->node_manager_id =
gcs_client_->client_table().GetLocalClientId().binary();
RAY_LOG(DEBUG) << "Publishing actor creation: " << actor_id;
auto driver_id = task.GetTaskSpecification().DriverId();
RAY_LOG(DEBUG) << "Publishing actor creation: " << actor_id
<< " driver_id: " << driver_id;
RAY_CHECK_OK(gcs_client_->actor_table().Append(JobID::nil(), actor_id,
actor_notification, nullptr));
@@ -1251,6 +1298,11 @@ void NodeManager::FinishAssignedTask(Worker &worker) {
// Unset the worker's assigned task.
worker.AssignTaskId(TaskID::nil());
// Unset the worker's assigned driver Id if this is not an actor.
if (!task.GetTaskSpecification().IsActorCreationTask() &&
!task.GetTaskSpecification().IsActorTask()) {
worker.AssignDriverId(DriverID::nil());
}
}
void NodeManager::HandleTaskReconstruction(const TaskID &task_id) {
+13 -15
View File
@@ -192,6 +192,12 @@ class NodeManager {
/// \return Void.
void HandleWorkerUnblocked(std::shared_ptr<Worker> worker);
/// Kill a worker process. SIGTERM is sent immediately; SIGKILL is sent once
/// the configured `kill_worker_timeout_milliseconds` has elapsed, giving the
/// worker some time to clean up its state first.
///
/// \param worker The worker to kill.
/// \return Void.
void KillWorker(std::shared_ptr<Worker> worker);
/// Methods for actor scheduling.
/// Handler for the creation of an actor, possibly on a remote node.
///
@@ -201,21 +207,6 @@ class NodeManager {
void HandleActorCreation(const ActorID &actor_id,
const std::vector<ActorTableDataT> &data);
/// TODO(rkn): This should probably be removed when we improve the
/// SchedulingQueue API. This is a helper function for
/// CleanUpTasksForDeadActor.
///
/// This essentially loops over all of the tasks in the provided list and
/// finds The IDs of the tasks that belong to the given actor.
///
/// \param actor_id The actor to get the tasks for.
/// \param tasks A list of tasks to extract from.
/// \param tasks_to_remove The task IDs of the extracted tasks are inserted in
/// this vector.
/// \return Void.
void GetActorTasksFromList(const ActorID &actor_id, const std::list<Task> &tasks,
std::unordered_set<TaskID> &tasks_to_remove);
/// When an actor dies, loop over all of the queued tasks for that actor and
/// treat them as failed.
///
@@ -223,6 +214,13 @@ class NodeManager {
/// \return Void.
void CleanUpTasksForDeadActor(const ActorID &actor_id);
/// When a driver dies, remove all of the queued tasks for that driver from
/// the local scheduling queues, and remove those tasks and the objects they
/// create from the task dependency manager. The removed tasks are discarded;
/// they are not treated as failed.
///
/// \param driver_id The driver that died.
/// \return Void.
void CleanUpTasksForDeadDriver(const DriverID &driver_id);
/// Handle an object becoming local. This updates any local accounting, but
/// does not write to any global accounting in the GCS.
///
+56
View File
@@ -39,6 +39,32 @@ inline void FilterStateFromQueue(const ray::raylet::SchedulingQueue::TaskQueue &
}
}
// Helper function that scans one task queue and records the ID of every task
// whose specification carries the given driver ID.
inline void GetDriverTasksFromQueue(const ray::raylet::SchedulingQueue::TaskQueue &queue,
                                    const ray::DriverID &driver_id,
                                    std::unordered_set<ray::TaskID> &task_ids) {
  for (const auto &queued_task : queue.GetTasks()) {
    const auto &queued_spec = queued_task.GetTaskSpecification();
    const bool owned_by_driver = (driver_id == queued_spec.DriverId());
    if (owned_by_driver) {
      task_ids.insert(queued_spec.TaskId());
    }
  }
}
// Helper function that scans one task queue and records the ID of every task
// whose specification carries the given actor ID.
inline void GetActorTasksFromQueue(const ray::raylet::SchedulingQueue::TaskQueue &queue,
                                   const ray::ActorID &actor_id,
                                   std::unordered_set<ray::TaskID> &task_ids) {
  for (const auto &queued_task : queue.GetTasks()) {
    const auto &queued_spec = queued_task.GetTaskSpecification();
    const bool owned_by_actor = (actor_id == queued_spec.ActorId());
    if (owned_by_actor) {
      task_ids.insert(queued_spec.TaskId());
    }
  }
}
} // namespace
namespace ray {
@@ -285,6 +311,36 @@ void SchedulingQueue::QueueBlockedTasks(const std::vector<Task> &tasks) {
QueueTasks(blocked_tasks_, tasks);
}
std::unordered_set<TaskID> SchedulingQueue::GetTaskIdsForDriver(
const DriverID &driver_id) const {
std::unordered_set<TaskID> task_ids;
GetDriverTasksFromQueue(methods_waiting_for_actor_creation_, driver_id, task_ids);
GetDriverTasksFromQueue(waiting_tasks_, driver_id, task_ids);
GetDriverTasksFromQueue(placeable_tasks_, driver_id, task_ids);
GetDriverTasksFromQueue(ready_tasks_, driver_id, task_ids);
GetDriverTasksFromQueue(running_tasks_, driver_id, task_ids);
GetDriverTasksFromQueue(blocked_tasks_, driver_id, task_ids);
GetDriverTasksFromQueue(infeasible_tasks_, driver_id, task_ids);
return task_ids;
}
std::unordered_set<TaskID> SchedulingQueue::GetTaskIdsForActor(
const ActorID &actor_id) const {
std::unordered_set<TaskID> task_ids;
GetActorTasksFromQueue(methods_waiting_for_actor_creation_, actor_id, task_ids);
GetActorTasksFromQueue(waiting_tasks_, actor_id, task_ids);
GetActorTasksFromQueue(placeable_tasks_, actor_id, task_ids);
GetActorTasksFromQueue(ready_tasks_, actor_id, task_ids);
GetActorTasksFromQueue(running_tasks_, actor_id, task_ids);
GetActorTasksFromQueue(blocked_tasks_, actor_id, task_ids);
GetActorTasksFromQueue(infeasible_tasks_, actor_id, task_ids);
return task_ids;
}
void SchedulingQueue::AddDriverTaskId(const TaskID &driver_id) {
auto inserted = driver_task_ids_.insert(driver_id);
RAY_CHECK(inserted.second);
+12
View File
@@ -174,6 +174,18 @@ class SchedulingQueue {
/// \param filter_state The task state to filter out.
void FilterState(std::unordered_set<TaskID> &task_ids, TaskState filter_state) const;
/// \brief Get the IDs of all queued tasks that belong to a driver.
///
/// \param driver_id The IDs of all tasks whose specification has this driver
/// ID are returned.
/// \return The set of task IDs (not the tasks themselves) for the given
/// driver ID.
std::unordered_set<TaskID> GetTaskIdsForDriver(const DriverID &driver_id) const;
/// \brief Get the IDs of all queued tasks that belong to an actor.
///
/// \param actor_id The IDs of all tasks whose specification has this actor
/// ID are returned.
/// \return The set of task IDs (not the tasks themselves) for the given
/// actor ID.
std::unordered_set<TaskID> GetTaskIdsForActor(const ActorID &actor_id) const;
/// \brief Return all resource demand associated with the ready queue.
///
/// \return Aggregate resource demand from ready tasks.
+28
View File
@@ -174,6 +174,7 @@ void TaskDependencyManager::UnsubscribeDependencies(const TaskID &task_id) {
// Remove the task from the table of subscribed tasks.
auto it = task_dependencies_.find(task_id);
RAY_CHECK(it != task_dependencies_.end());
const TaskDependencies task_entry = std::move(it->second);
task_dependencies_.erase(it);
@@ -297,6 +298,33 @@ void TaskDependencyManager::TaskCanceled(const TaskID &task_id) {
}
}
void TaskDependencyManager::RemoveTasksAndRelatedObjects(
const std::unordered_set<TaskID> &task_ids) {
if (task_ids.empty()) {
return;
}
for (auto it = task_ids.begin(); it != task_ids.end(); it++) {
task_dependencies_.erase(*it);
required_tasks_.erase(*it);
pending_tasks_.erase(*it);
}
// TODO: the size of required_objects_ could be large, consider to add
// an index if this turns out to be a perf problem.
for (auto it = required_objects_.begin(); it != required_objects_.end();) {
const auto object_id = *it;
TaskID creating_task_id = ComputeTaskId(object_id);
if (task_ids.find(creating_task_id) != task_ids.end()) {
object_manager_.CancelPull(object_id);
reconstruction_policy_.Cancel(object_id);
it = required_objects_.erase(it);
} else {
it++;
}
}
}
} // namespace raylet
} // namespace ray
+6
View File
@@ -105,6 +105,12 @@ class TaskDependencyManager {
/// \return Return a vector of TaskIDs for tasks registered as pending.
std::vector<TaskID> GetPendingTasks() const;
/// Remove all of the tasks specified, and all the objects created by
/// these tasks from task dependency manager. Any pull and any reconstruction
/// attempt for such an object is cancelled as well.
///
/// \param task_ids The collection of task IDs. If empty, this is a no-op.
void RemoveTasksAndRelatedObjects(const std::unordered_set<TaskID> &task_ids);
private:
using ObjectDependencyMap = std::unordered_map<ray::ObjectID, std::vector<ray::TaskID>>;
+11
View File
@@ -18,8 +18,13 @@ Worker::Worker(pid_t pid, const Language &language,
connection_(connection),
assigned_task_id_(TaskID::nil()),
actor_id_(ActorID::nil()),
dead_(false),
blocked_(false) {}
// Flag this worker as dead. Nothing in this file clears the flag again.
void Worker::MarkDead() {
  dead_ = true;
}
// Report whether this worker has been flagged as dead.
bool Worker::IsDead() const {
  return dead_;
}
// Flag this worker as blocked.
void Worker::MarkBlocked() {
  blocked_ = true;
}
// Clear this worker's blocked flag.
void Worker::MarkUnblocked() {
  blocked_ = false;
}
@@ -34,6 +39,12 @@ void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id;
// Return the ID of the task currently assigned to this worker.
const TaskID &Worker::GetAssignedTaskId() const {
  return assigned_task_id_;
}
// Record the driver ID associated with this worker, replacing any previously
// stored value.
void Worker::AssignDriverId(const DriverID &driver_id) { assigned_driver_id_ = driver_id; }
// Return the driver ID most recently stored by AssignDriverId.
const DriverID &Worker::GetAssignedDriverId() const {
  return assigned_driver_id_;
}
void Worker::AssignActorId(const ActorID &actor_id) {
RAY_CHECK(actor_id_.is_nil())
<< "A worker that is already an actor cannot be assigned an actor ID again.";
+8
View File
@@ -21,6 +21,8 @@ class Worker {
std::shared_ptr<LocalClientConnection> connection);
/// A destructor responsible for freeing all worker state.
~Worker() {}
void MarkDead();
bool IsDead() const;
void MarkBlocked();
void MarkUnblocked();
bool IsBlocked() const;
@@ -29,6 +31,8 @@ class Worker {
Language GetLanguage() const;
void AssignTaskId(const TaskID &task_id);
const TaskID &GetAssignedTaskId() const;
void AssignDriverId(const DriverID &driver_id);
const DriverID &GetAssignedDriverId() const;
void AssignActorId(const ActorID &actor_id);
const ActorID &GetActorId() const;
/// Return the worker's connection.
@@ -53,8 +57,12 @@ class Worker {
std::shared_ptr<LocalClientConnection> connection_;
/// The worker's currently assigned task.
TaskID assigned_task_id_;
/// Driver ID for the worker's current assigned task.
DriverID assigned_driver_id_;
/// The worker's actor ID. If this is nil, then the worker is not an actor.
ActorID actor_id_;
/// Whether the worker is dead.
bool dead_;
/// Whether the worker is blocked. Workers become blocked in a `ray.get`, if
/// they require a data dependency while executing a task.
bool blocked_;
+17
View File
@@ -232,6 +232,23 @@ inline WorkerPool::State &WorkerPool::GetStateForLanguage(const Language &langua
return state->second;
}
std::vector<std::shared_ptr<Worker>> WorkerPool::GetWorkersRunningTasksForDriver(
const DriverID &driver_id) const {
std::vector<std::shared_ptr<Worker>> workers;
for (const auto &entry : states_by_lang_) {
for (const auto &worker : entry.second.registered_workers) {
RAY_LOG(DEBUG) << "worker: pid : " << worker->Pid()
<< " driver_id: " << worker->GetAssignedDriverId();
if (worker->GetAssignedDriverId() == driver_id) {
workers.push_back(worker);
}
}
}
return workers;
}
} // namespace raylet
} // namespace ray
+7
View File
@@ -111,6 +111,13 @@ class WorkerPool {
/// \return The total count of all workers (actor and non-actor) in the pool.
uint32_t Size(const Language &language) const;
/// Get all the workers which are running tasks for a given driver. A worker
/// is selected when it is registered with the pool and its assigned driver ID
/// equals the given driver ID.
///
/// \param driver_id The driver ID.
/// \return A list containing all the workers which are running tasks for the driver.
std::vector<std::shared_ptr<Worker>> GetWorkersRunningTasksForDriver(
const DriverID &driver_id) const;
protected:
/// A map from the pids of starting worker processes
/// to the number of their unregistered workers.
+4 -2
View File
@@ -239,7 +239,9 @@ def ray_start_head_with_resources():
subprocess.Popen(["ray", "stop"]).wait()
@pytest.mark.skip(reason="This test does not work yet.")
@pytest.mark.skipif(
os.environ.get("RAY_USE_XRAY") != "1",
reason="This test only works with xray.")
def test_drivers_release_resources(ray_start_head_with_resources):
redis_address = ray_start_head_with_resources
@@ -278,7 +280,7 @@ print("success")
driver_script2 = (driver_script1 +
"import sys\nsys.stdout.flush()\ntime.sleep(10 ** 6)\n")
def wait_for_success_output(process_handle, timeout=100):
def wait_for_success_output(process_handle, timeout=10):
# Wait until the process prints "success" and then return.
start_time = time.time()
while time.time() - start_time < timeout: