From e086ddc18f789d3fa9bcd9d3a31fcef1b4e0b7e1 Mon Sep 17 00:00:00 2001 From: Ian Rodney Date: Wed, 18 Nov 2020 15:18:40 -0800 Subject: [PATCH] [core] Add Recursive task cancelation (#11923) --- python/ray/_raylet.pyx | 5 ++- python/ray/includes/libcoreworker.pxd | 3 +- python/ray/tests/test_cancel.py | 32 ++++++++++++++++ python/ray/worker.py | 6 ++- src/ray/core_worker/core_worker.cc | 37 ++++++++++++++++--- src/ray/core_worker/core_worker.h | 10 ++++- src/ray/core_worker/task_manager.cc | 15 ++++++++ src/ray/core_worker/task_manager.h | 3 ++ .../test/direct_task_transport_test.cc | 8 ++-- .../transport/direct_task_transport.cc | 13 ++++--- .../transport/direct_task_transport.h | 5 +-- src/ray/protobuf/core_worker.proto | 4 ++ 12 files changed, 118 insertions(+), 23 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 65d6ed6ff..31e97798e 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1170,13 +1170,14 @@ cdef class CoreWorker: check_status(CCoreWorkerProcess.GetCoreWorker().KillActor( c_actor_id, True, no_restart)) - def cancel_task(self, ObjectRef object_ref, c_bool force_kill): + def cancel_task(self, ObjectRef object_ref, c_bool force_kill, + c_bool recursive): cdef: CObjectID c_object_id = object_ref.native() CRayStatus status = CRayStatus.OK() status = CCoreWorkerProcess.GetCoreWorker().CancelTask( - c_object_id, force_kill) + c_object_id, force_kill, recursive) if not status.ok(): raise TypeError(status.message().decode()) diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 351385c06..195882938 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -110,7 +110,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CRayStatus KillActor( const CActorID &actor_id, c_bool force_kill, c_bool no_restart) - CRayStatus CancelTask(const CObjectID &object_id, c_bool force_kill) + CRayStatus CancelTask(const CObjectID &object_id, c_bool force_kill, + c_bool recursive) unique_ptr[CProfileEvent] CreateProfileEvent( const c_string &event_type) diff --git a/python/ray/tests/test_cancel.py b/python/ray/tests/test_cancel.py index b4e4aa439..99f5227b6 100644 --- a/python/ray/tests/test_cancel.py +++ b/python/ray/tests/test_cancel.py @@ -258,5 +258,37 @@ def test_remote_cancel(ray_start_regular, use_force): ray.get(inner, timeout=10) +@pytest.mark.parametrize("use_force", [True, False]) +def test_recursive_cancel(shutdown_only, use_force): + ray.init(num_cpus=4) + + @ray.remote(num_cpus=1) + def inner(): + while True: + time.sleep(0.1) + + @ray.remote(num_cpus=1) + def outer(): + + x = [inner.remote()] + print(x) + while True: + time.sleep(0.1) + + @ray.remote(num_cpus=4) + def many_resources(): + return 300 + + outer_fut = outer.remote() + many_fut = many_resources.remote() + with pytest.raises(GetTimeoutError): + ray.get(many_fut, timeout=1) + ray.cancel(outer_fut) + with pytest.raises(valid_exceptions(use_force)): + ray.get(outer_fut, timeout=10) + + assert ray.get(many_fut, timeout=30) + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/worker.py b/python/ray/worker.py index d8c767964..adf8279e8 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1575,7 +1575,7 @@ def kill(actor, *, no_restart=True): worker.core_worker.kill_actor(actor._ray_actor_id, no_restart) -def cancel(object_ref, *, force=False): +def cancel(object_ref, *, force=False, recursive=True): """Cancels a task according to the following conditions. If the specified task is pending execution, it will not be executed. If @@ -1595,6 +1595,8 @@ def cancel(object_ref, *, force=False): that should be canceled. force (boolean): Whether to force-kill a running task by killing the worker that is running the task. + recursive (boolean): Whether to try to cancel tasks submitted by the + task specified. Raises: TypeError: This is also raised for actor tasks. """ @@ -1605,7 +1607,7 @@ def cancel(object_ref, *, force=False): raise TypeError( "ray.cancel() only supported for non-actor object refs. " f"Got: {type(object_ref)}.") - return worker.core_worker.cancel_task(object_ref, force) + return worker.core_worker.cancel_task(object_ref, force, recursive) def _mode(worker=global_worker): diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 10834429a..a47d5c7c0 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1518,7 +1518,8 @@ void CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &fun } } -Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill) { +Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill, + bool recursive) { if (actor_manager_->CheckActorHandleExists(object_id.TaskId().ActorId())) { return Status::Invalid("Actor task cancellation is not supported."); } @@ -1527,16 +1528,36 @@ Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill) { return Status::Invalid("No owner found for object."); } if (obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) { - return direct_task_submitter_->CancelRemoteTask(object_id, obj_addr, force_kill); + return direct_task_submitter_->CancelRemoteTask(object_id, obj_addr, force_kill, + recursive); } auto task_spec = task_manager_->GetTaskSpec(object_id.TaskId()); if (task_spec.has_value() && !task_spec.value().IsActorCreationTask()) { - return direct_task_submitter_->CancelTask(task_spec.value(), force_kill); + return direct_task_submitter_->CancelTask(task_spec.value(), force_kill, recursive); } return Status::OK(); } +Status CoreWorker::CancelChildren(const TaskID &task_id, bool force_kill) { + bool recursive_success = true; + for (const auto &child_id : task_manager_->GetPendingChildrenTasks(task_id)) { + auto child_spec = task_manager_->GetTaskSpec(child_id); + if (child_spec.has_value()) { + auto result = + direct_task_submitter_->CancelTask(child_spec.value(), force_kill, true); + recursive_success = recursive_success && result.ok(); + } else { + recursive_success = false; + } + } + if (recursive_success) { + return Status::OK(); + } else { + return Status::UnknownError("Recursive task cancelation failed--check warning logs."); + } +} + Status CoreWorker::KillActor(const ActorID &actor_id, bool force_kill, bool no_restart) { if (options_.is_local_mode) { return KillActorLocalMode(actor_id); @@ -2157,8 +2178,8 @@ void CoreWorker::HandleWaitForRefRemoved(const rpc::WaitForRefRemovedRequest &re void CoreWorker::HandleRemoteCancelTask(const rpc::RemoteCancelTaskRequest &request, rpc::RemoteCancelTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { - auto status = - CancelTask(ObjectID::FromBinary(request.remote_object_id()), request.force_kill()); + auto status = CancelTask(ObjectID::FromBinary(request.remote_object_id()), + request.force_kill(), request.recursive()); send_reply_callback(status, nullptr, nullptr); } @@ -2174,6 +2195,12 @@ void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request, RAY_LOG(INFO) << "Interrupting a running task " << main_thread_task_id_; success = options_.kill_main(); } + if (request.recursive()) { + auto recursive_cancel = CancelChildren(task_id, request.force_kill()); + if (recursive_cancel.ok()) { + RAY_LOG(INFO) << "Recursive cancel failed!"; + } + } reply->set_attempt_succeeded(success); send_reply_callback(Status::OK(), nullptr, nullptr); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 7f624c3bb..92eb88ddb 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -714,8 +714,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// /// \param[in] object_id of the task to kill (must be a Non-Actor task) /// \param[in] force_kill Whether to force kill a task by killing the worker. + /// \param[in] recursive Whether to cancel tasks submitted by the task to cancel. /// \param[out] Status - Status CancelTask(const ObjectID &object_id, bool force_kill); + Status CancelTask(const ObjectID &object_id, bool force_kill, bool recursive); + /// Decrease the reference count for this actor. Should be called by the /// language frontend when a reference to the ActorHandle destroyed. /// @@ -946,6 +948,12 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { reference_counter_->AddLocalReference(object_id, call_site); } + /// Stops the children tasks from the given TaskID + /// + /// \param[in] task_id of the parent task + /// \param[in] force_kill Whether to force kill a task by killing the worker. + Status CancelChildren(const TaskID &task_id, bool force_kill); + /// /// Private methods related to task execution. Should not be used by driver processes. /// diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 2dcaa047c..be9c6858f 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -456,4 +456,19 @@ absl::optional TaskManager::GetTaskSpec(const TaskID &task_id return it->second.spec; } +std::vector TaskManager::GetPendingChildrenTasks( + const TaskID &parent_task_id) const { + std::vector ret_vec; + absl::MutexLock lock(&mu_); + RAY_LOG(ERROR) << " calling get children tasks"; + RAY_LOG(ERROR) << "NUMBER OF PENDING TASKS: " << num_pending_tasks_; + for (auto it : submissible_tasks_) { + RAY_LOG(ERROR) << "Getting tasks!! " << it.second.spec.TaskId(); + if (it.second.pending and it.second.spec.ParentTaskId() == parent_task_id) { + ret_vec.push_back(it.first); + } + } + return ret_vec; +} + } // namespace ray diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 83dadf0ee..88b32522c 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -140,6 +140,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// Return the spec for a pending task. absl::optional GetTaskSpec(const TaskID &task_id) const; + /// Return specs for pending children tasks of the given parent task. + std::vector GetPendingChildrenTasks(const TaskID &parent_task_id) const; + /// Return whether this task can be submitted for execution. /// /// \param[in] task_id ID of the task to query. diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 67ed7de02..958aa4fd6 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -1065,7 +1065,7 @@ TEST(DirectTaskTransportTest, TestKillExecutingTask) { ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil())); // Try force kill, exiting the worker - ASSERT_TRUE(submitter.CancelTask(task, true).ok()); + ASSERT_TRUE(submitter.CancelTask(task, true, false).ok()); ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(), task.TaskId().Binary()); ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("workerdying"), true)); @@ -1081,7 +1081,7 @@ TEST(DirectTaskTransportTest, TestKillExecutingTask) { ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil())); // Try non-force kill, worker returns normally - ASSERT_TRUE(submitter.CancelTask(task, false).ok()); + ASSERT_TRUE(submitter.CancelTask(task, false, false).ok()); ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(), task.TaskId().Binary()); @@ -1114,7 +1114,7 @@ TEST(DirectTaskTransportTest, TestKillPendingTask) { TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); ASSERT_TRUE(submitter.SubmitTask(task).ok()); - ASSERT_TRUE(submitter.CancelTask(task, true).ok()); + ASSERT_TRUE(submitter.CancelTask(task, true, false).ok()); ASSERT_EQ(worker_client->kill_requests.size(), 0); ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(raylet_client->num_workers_returned, 0); @@ -1152,7 +1152,7 @@ TEST(DirectTaskTransportTest, TestKillResolvingTask) { task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_EQ(task_finisher->num_inlined_dependencies, 0); - ASSERT_TRUE(submitter.CancelTask(task, true).ok()); + ASSERT_TRUE(submitter.CancelTask(task, true, false).ok()); auto data = GenerateRandomObject(); ASSERT_TRUE(store->Put(*data, obj1)); ASSERT_EQ(worker_client->kill_requests.size(), 0); diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index 7aa635f21..1d4cd40ef 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -408,7 +408,7 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask( } Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, - bool force_kill) { + bool force_kill, bool recursive) { RAY_LOG(INFO) << "Killing task: " << task_spec.TaskId(); const SchedulingKey scheduling_key( task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(), @@ -470,8 +470,9 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, auto request = rpc::CancelTaskRequest(); request.set_intended_task_id(task_spec.TaskId().Binary()); request.set_force_kill(force_kill); + request.set_recursive(recursive); client->CancelTask( - request, [this, task_spec, scheduling_key, force_kill]( + request, [this, task_spec, scheduling_key, force_kill, recursive]( const Status &status, const rpc::CancelTaskReply &reply) { absl::MutexLock lock(&mu_); cancelled_tasks_.erase(task_spec.TaskId()); @@ -483,8 +484,9 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, cancel_retry_timer_->expires_after(boost::asio::chrono::milliseconds( RayConfig::instance().cancellation_retry_ms())); } - cancel_retry_timer_->async_wait(boost::bind( - &CoreWorkerDirectTaskSubmitter::CancelTask, this, task_spec, force_kill)); + cancel_retry_timer_->async_wait( + boost::bind(&CoreWorkerDirectTaskSubmitter::CancelTask, this, task_spec, + force_kill, recursive)); } } // Retry is not attempted if !status.ok() because force-kill may kill the worker @@ -495,7 +497,7 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, Status CoreWorkerDirectTaskSubmitter::CancelRemoteTask(const ObjectID &object_id, const rpc::Address &worker_addr, - bool force_kill) { + bool force_kill, bool recursive) { auto maybe_client = client_cache_->GetByID(rpc::WorkerAddress(worker_addr).worker_id); if (!maybe_client.has_value()) { @@ -504,6 +506,7 @@ Status CoreWorkerDirectTaskSubmitter::CancelRemoteTask(const ObjectID &object_id auto client = maybe_client.value(); auto request = rpc::RemoteCancelTaskRequest(); request.set_force_kill(force_kill); + request.set_recursive(recursive); request.set_remote_object_id(object_id.Binary()); client->RemoteCancelTask(request, nullptr); return Status::OK(); diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index 244e6b5e2..19a2a7080 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -81,11 +81,10 @@ class CoreWorkerDirectTaskSubmitter { /// /// \param[in] task_spec The task to kill. /// \param[in] force_kill Whether to kill the worker executing the task. - Status CancelTask(TaskSpecification task_spec, bool force_kill); + Status CancelTask(TaskSpecification task_spec, bool force_kill, bool recursive); Status CancelRemoteTask(const ObjectID &object_id, const rpc::Address &worker_addr, - bool force_kill); - + bool force_kill, bool recursive); /// Check that the scheduling_key_entries_ hashmap is empty by calling the private /// CheckNoSchedulingKeyEntries function after acquiring the lock. bool CheckNoSchedulingKeyEntriesPublic() { diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 5a5c751e3..bb61612eb 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -205,6 +205,8 @@ message CancelTaskRequest { bytes intended_task_id = 1; // Whether to kill the worker. bool force_kill = 2; + // Whether to recursively cancel tasks. + bool recursive = 3; } message CancelTaskReply { @@ -217,6 +219,8 @@ message RemoteCancelTaskRequest { bytes remote_object_id = 1; // Whether to kill the worker. bool force_kill = 2; + // Whether to recursively cancel tasks. + bool recursive = 3; } message RemoteCancelTaskReply {