From 79310452e7db0f0f1077233dac27148ee7b0545c Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Wed, 3 Feb 2021 13:20:12 -0500 Subject: [PATCH] Enabling the cancellation of non-actor tasks in a worker's queue 2 (#13244) * wrote code to enable cancellation of queued non-actor tasks * minor changes * bug fixes * added comments * rev1 * linting * making ActorSchedulingQueue::CancelTaskIfFound raise a fatal error * bug fix * added two unit tests * linting * iterating through pending_normal_tasks starting from end * fixup! iterating through pending_normal_tasks starting from end * fixup! fixup! iterating through pending_normal_tasks starting from end * post merge fixes * added debugging instructions, pulled Accept() out of guarded loop * removed debugging instructions, linting * first commit * lint * lint * added hack to avoid race condition in test stress * moved hack * fix test cancel * removed hack (hopefully no longer needed) * Revert "removed hack (hopefully no longer needed)" This reverts commit 99d0e7c91539f290700f50aaaed805dcde04a5ee. * added sleep in mock_worker.cc * sleep function fixup to work on windows * sleep in test_fast both for force=true and force=false * linting Co-authored-by: Ian --- python/ray/tests/test_cancel.py | 9 ++- src/ray/core_worker/core_worker.cc | 15 ++++- src/ray/core_worker/test/core_worker_test.cc | 42 ++++++++++++++ src/ray/core_worker/test/mock_worker.cc | 11 ++++ .../core_worker/test/scheduling_queue_test.cc | 27 +++++++-- .../transport/direct_actor_transport.cc | 10 +++- .../transport/direct_actor_transport.h | 56 ++++++++++++++++--- 7 files changed, 151 insertions(+), 19 deletions(-) diff --git a/python/ray/tests/test_cancel.py b/python/ray/tests/test_cancel.py index 11b4dfbd4..aefff09fa 100644 --- a/python/ray/tests/test_cancel.py +++ b/python/ray/tests/test_cancel.py @@ -175,6 +175,8 @@ def test_stress(shutdown_only, use_force): sleep_or_no = [random.randint(0, 1) for _ in range(100)] tasks = [infinite_sleep.remote(i) for i in sleep_or_no] cancelled = set() + + # Randomly kill queued tasks (infinitely sleeping or not). for t in tasks: if random.random() > 0.5: ray.cancel(t, force=use_force) @@ -186,10 +188,13 @@ def test_stress(shutdown_only, use_force): for done in cancelled: with pytest.raises(valid_exceptions(use_force)): ray.get(done, timeout=120) + + # Kill all infinitely sleeping tasks (queued or not). for indx, t in enumerate(tasks): if sleep_or_no[indx]: ray.cancel(t, force=use_force) cancelled.add(t) + for indx, t in enumerate(tasks): if t in cancelled: with pytest.raises(valid_exceptions(use_force)): ray.get(t, timeout=120) @@ -213,8 +218,8 @@ def test_fast(shutdown_only, use_force): # between a worker receiving a task and the worker executing # that task (specifically the python execution), Cancellation # can fail. - if not use_force: - time.sleep(0.1) + + time.sleep(0.1) ray.cancel(x, force=use_force) ids.append(x) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 1961406d8..b56f18cf0 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -760,6 +760,7 @@ void CoreWorker::InternalHeartbeat(const boost::system::error_code &error) { } absl::MutexLock lock(&mutex_); + while (!to_resubmit_.empty() && current_time_ms() > to_resubmit_.front().first) { auto &spec = to_resubmit_.front().second; if (spec.IsActorTask()) { @@ -2266,12 +2267,17 @@ void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request, rpc::SendReplyCallback send_reply_callback) { absl::MutexLock lock(&mutex_); TaskID task_id = TaskID::FromBinary(request.intended_task_id()); - bool success = main_thread_task_id_ == task_id; + bool requested_task_running = main_thread_task_id_ == task_id; + bool success = requested_task_running; // Try non-force kill - if (success && !request.force_kill()) { + if (requested_task_running && !request.force_kill()) { RAY_LOG(INFO) << "Interrupting a running task " << main_thread_task_id_; success = options_.kill_main(); + } else if (!requested_task_running) { + // If the task is not currently running, check if it is in the worker's queue of + // normal tasks, and remove it if found. + success = direct_task_receiver_->CancelQueuedNormalTask(task_id); } if (request.recursive()) { auto recursive_cancel = CancelChildren(task_id, request.force_kill()); @@ -2280,11 +2286,14 @@ void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request, } } + // TODO: fix race condition to avoid using this hack + requested_task_running = main_thread_task_id_ == task_id; + reply->set_attempt_succeeded(success); send_reply_callback(Status::OK(), nullptr, nullptr); // Do force kill after reply callback sent - if (success && request.force_kill()) { + if (requested_task_running && request.force_kill()) { RAY_LOG(INFO) << "Force killing a worker running " << main_thread_task_id_; Disconnect(); if (options_.enable_logging) { diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 82ea82617..cf1bab624 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -841,6 +841,48 @@ TEST_F(SingleNodeTest, TestNormalTaskLocal) { TestNormalTask(resources); } +TEST_F(SingleNodeTest, TestCancelTasks) { + auto &driver = CoreWorkerProcess::GetCoreWorker(); + + // Create two functions, each implementing a while(true) loop. + RayFunction func1(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "WhileTrueLoop", "", "", "")); + RayFunction func2(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "WhileTrueLoop", "", "", "")); + // Return IDs for the two functions that implement while(true) loops. + std::vector return_ids1; + std::vector return_ids2; + + // Create default args and options needed to submit the tasks that encapsulate func1 and + // func2. + std::vector> args; + TaskOptions options; + + // Submit func1. The function should start looping forever. + driver.SubmitTask(func1, args, options, &return_ids1, /*max_retries=*/0, + std::make_pair(PlacementGroupID::Nil(), -1), true, + /*debugger_breakpoint=*/""); + ASSERT_EQ(return_ids1.size(), 1); + + // Submit func2. The function should be queued at the worker indefinitely. + driver.SubmitTask(func2, args, options, &return_ids2, /*max_retries=*/0, + std::make_pair(PlacementGroupID::Nil(), -1), true, + /*debugger_breakpoint=*/""); + ASSERT_EQ(return_ids2.size(), 1); + + // Cancel func2 by removing it from the worker's queue + RAY_CHECK_OK(driver.CancelTask(return_ids2[0], true, false)); + + // Cancel func1, which is currently running. + RAY_CHECK_OK(driver.CancelTask(return_ids1[0], true, false)); + + // TestNormalTask will get stuck unless both func1 and func2 have been cancelled. Thus, + // if TestNormalTask succeeds, we know that func2 must have been removed from the + // worker's queue. + std::unordered_map resources; + TestNormalTask(resources); +} + TEST_F(TwoNodeTest, TestNormalTaskCrossNodes) { std::unordered_map resources; resources.emplace("resource1", 1); diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 4439519bb..03a78a198 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -79,6 +79,8 @@ class MockWorker { } else if ("MergeInputArgsAsOutput" == typed_descriptor->ModuleName()) { // Merge input args and write the merged content to each of return ids return MergeInputArgsAsOutput(args, return_ids, results); + } else if ("WhileTrueLoop" == typed_descriptor->ModuleName()) { + return WhileTrueLoop(args, return_ids, results); } else { return Status::TypeError("Unknown function descriptor: " + typed_descriptor->ModuleName()); @@ -128,6 +130,15 @@ class MockWorker { return Status::OK(); } + Status WhileTrueLoop(const std::vector> &args, + const std::vector &return_ids, + std::vector> *results) { + while (1) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + return Status::OK(); + } + int64_t prev_seq_no_ = 0; }; diff --git a/src/ray/core_worker/test/scheduling_queue_test.cc b/src/ray/core_worker/test/scheduling_queue_test.cc index 8c8e60fd5..6854c1810 100644 --- a/src/ray/core_worker/test/scheduling_queue_test.cc +++ b/src/ray/core_worker/test/scheduling_queue_test.cc @@ -66,9 +66,9 @@ TEST(SchedulingQueueTest, TestWaitForObjects) { auto fn_ok = [&n_ok]() { n_ok++; }; auto fn_rej = [&n_rej]() { n_rej++; }; queue.Add(0, -1, fn_ok, fn_rej); - queue.Add(1, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj1})); - queue.Add(2, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj2})); - queue.Add(3, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj3})); + queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1})); + queue.Add(2, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj2})); + queue.Add(3, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj3})); ASSERT_EQ(n_ok, 1); waiter.Complete(0); @@ -92,7 +92,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) { auto fn_ok = [&n_ok]() { n_ok++; }; auto fn_rej = [&n_rej]() { n_rej++; }; queue.Add(0, -1, fn_ok, fn_rej); - queue.Add(1, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj1})); + queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1})); ASSERT_EQ(n_ok, 1); io_service.run(); ASSERT_EQ(n_rej, 0); @@ -158,6 +158,25 @@ TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) { ASSERT_EQ(n_rej, 2); } +TEST(SchedulingQueueTest, TestCancelQueuedTask) { + NormalSchedulingQueue *queue = new NormalSchedulingQueue(); + ASSERT_TRUE(queue->TaskQueueEmpty()); + int n_ok = 0; + int n_rej = 0; + auto fn_ok = [&n_ok]() { n_ok++; }; + auto fn_rej = [&n_rej]() { n_rej++; }; + queue->Add(-1, -1, fn_ok, fn_rej); + queue->Add(-1, -1, fn_ok, fn_rej); + queue->Add(-1, -1, fn_ok, fn_rej); + queue->Add(-1, -1, fn_ok, fn_rej); + queue->Add(-1, -1, fn_ok, fn_rej); + ASSERT_TRUE(queue->CancelTaskIfFound(TaskID::Nil())); + ASSERT_FALSE(queue->TaskQueueEmpty()); + queue->ScheduleRequests(); + ASSERT_EQ(n_ok, 4); + ASSERT_EQ(n_rej, 0); +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index e266b0d94..bac80af4f 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -482,12 +482,12 @@ void CoreWorkerDirectTaskReceiver::HandleTask( // TODO(swang): Remove this with legacy raylet code. dependencies.pop_back(); it->second->Add(request.sequence_number(), request.client_processed_up_to(), - accept_callback, reject_callback, dependencies); + accept_callback, reject_callback, task_spec.TaskId(), dependencies); } else { // Add the normal task's callbacks to the non-actor scheduling queue. normal_scheduling_queue_->Add(request.sequence_number(), request.client_processed_up_to(), accept_callback, - reject_callback, dependencies); + reject_callback, task_spec.TaskId(), dependencies); } } @@ -501,4 +501,10 @@ void CoreWorkerDirectTaskReceiver::RunNormalTasksFromQueue() { normal_scheduling_queue_->ScheduleRequests(); } +bool CoreWorkerDirectTaskReceiver::CancelQueuedNormalTask(TaskID task_id) { + // Look up the task to be canceled in the queue of normal tasks. If it is found and + // removed successfully, return true. + return normal_scheduling_queue_->CancelTaskIfFound(task_id); +} + } // namespace ray diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index ab28dc85a..cbd0a82fc 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -254,19 +254,23 @@ class InboundRequest { public: InboundRequest(){}; InboundRequest(std::function accept_callback, - std::function reject_callback, bool has_dependencies) + std::function reject_callback, TaskID task_id, + bool has_dependencies) : accept_callback_(accept_callback), reject_callback_(reject_callback), + task_id(task_id), has_pending_dependencies_(has_dependencies) {} void Accept() { accept_callback_(); } void Cancel() { reject_callback_(); } bool CanExecute() const { return !has_pending_dependencies_; } + ray::TaskID TaskID() const { return task_id; } void MarkDependenciesSatisfied() { has_pending_dependencies_ = false; } private: std::function accept_callback_; std::function reject_callback_; + ray::TaskID task_id; bool has_pending_dependencies_; }; @@ -346,10 +350,11 @@ class SchedulingQueue { public: virtual void Add(int64_t seq_no, int64_t client_processed_up_to, std::function accept_request, - std::function reject_request, + std::function reject_request, TaskID task_id = TaskID::Nil(), const std::vector &dependencies = {}) = 0; virtual void ScheduleRequests() = 0; virtual bool TaskQueueEmpty() const = 0; + virtual bool CancelTaskIfFound(TaskID task_id) = 0; virtual ~SchedulingQueue(){}; }; @@ -371,6 +376,7 @@ class ActorSchedulingQueue : public SchedulingQueue { /// Add a new actor task's callbacks to the worker queue. void Add(int64_t seq_no, int64_t client_processed_up_to, std::function accept_request, std::function reject_request, + TaskID task_id = TaskID::Nil(), const std::vector &dependencies = {}) { // A seq_no of -1 means no ordering constraint. Actor tasks must be executed in order. RAY_CHECK(seq_no != -1); @@ -383,7 +389,7 @@ class ActorSchedulingQueue : public SchedulingQueue { } RAY_LOG(DEBUG) << "Enqueue " << seq_no << " cur seqno " << next_seq_no_; pending_actor_tasks_[seq_no] = - InboundRequest(accept_request, reject_request, dependencies.size() > 0); + InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0); if (dependencies.size() > 0) { waiter_.Wait(dependencies, [seq_no, this]() { RAY_CHECK(boost::this_thread::get_id() == main_thread_id_); @@ -397,6 +403,15 @@ class ActorSchedulingQueue : public SchedulingQueue { ScheduleRequests(); } + // We don't allow the cancellation of actor tasks, so invoking CancelTaskIfFound results + // in a fatal error. + bool CancelTaskIfFound(TaskID task_id) { + RAY_CHECK(false) << "Cannot cancel actor tasks"; + // The return instruction will never be executed, but we need to include it + // nonetheless because this is a non-void function. + return false; + } + /// Schedules as many requests as possible in sequence. void ScheduleRequests() { // Only call SetMaxActorConcurrency to configure threadpool size when the @@ -520,22 +535,45 @@ class NormalSchedulingQueue : public SchedulingQueue { /// Add a new task's callbacks to the worker queue. void Add(int64_t seq_no, int64_t client_processed_up_to, std::function accept_request, std::function reject_request, + TaskID task_id = TaskID::Nil(), const std::vector &dependencies = {}) { absl::MutexLock lock(&mu_); // Normal tasks should not have ordering constraints. RAY_CHECK(seq_no == -1); // Create a InboundRequest object for the new task, and add it to the queue. pending_normal_tasks_.push_back( - InboundRequest(accept_request, reject_request, dependencies.size() > 0)); + InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0)); + } + + // Search for an InboundRequest associated with the task that we are trying to cancel. + // If found, remove the InboundRequest from the queue and return true. Otherwise, return + // false. + bool CancelTaskIfFound(TaskID task_id) { + absl::MutexLock lock(&mu_); + for (std::deque::reverse_iterator it = pending_normal_tasks_.rbegin(); + it != pending_normal_tasks_.rend(); ++it) { + if (it->TaskID() == task_id) { + pending_normal_tasks_.erase(std::next(it).base()); + return true; + } + } + return false; } /// Schedules as many requests as possible in sequence. void ScheduleRequests() { - absl::MutexLock lock(&mu_); - while (!pending_normal_tasks_.empty()) { - auto &head = pending_normal_tasks_.front(); + while (true) { + InboundRequest head; + { + absl::MutexLock lock(&mu_); + if (!pending_normal_tasks_.empty()) { + head = pending_normal_tasks_.front(); + pending_normal_tasks_.pop_front(); + } else { + return; + } + } head.Accept(); - pending_normal_tasks_.pop_front(); } } @@ -583,6 +621,8 @@ class CoreWorkerDirectTaskReceiver { /// Pop tasks from the queue and execute them sequentially void RunNormalTasksFromQueue(); + bool CancelQueuedNormalTask(TaskID task_id); + private: // Worker context. WorkerContext &worker_context_;