mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 16:24:23 +08:00
[core] Add Recursive task cancelation (#11923)
This commit is contained in:
@@ -1170,13 +1170,14 @@ cdef class CoreWorker:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().KillActor(
|
||||
c_actor_id, True, no_restart))
|
||||
|
||||
def cancel_task(self, ObjectRef object_ref, c_bool force_kill):
|
||||
def cancel_task(self, ObjectRef object_ref, c_bool force_kill,
|
||||
c_bool recursive):
|
||||
cdef:
|
||||
CObjectID c_object_id = object_ref.native()
|
||||
CRayStatus status = CRayStatus.OK()
|
||||
|
||||
status = CCoreWorkerProcess.GetCoreWorker().CancelTask(
|
||||
c_object_id, force_kill)
|
||||
c_object_id, force_kill, recursive)
|
||||
|
||||
if not status.ok():
|
||||
raise TypeError(status.message().decode())
|
||||
|
||||
@@ -110,7 +110,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
CRayStatus KillActor(
|
||||
const CActorID &actor_id, c_bool force_kill,
|
||||
c_bool no_restart)
|
||||
CRayStatus CancelTask(const CObjectID &object_id, c_bool force_kill)
|
||||
CRayStatus CancelTask(const CObjectID &object_id, c_bool force_kill,
|
||||
c_bool recursive)
|
||||
|
||||
unique_ptr[CProfileEvent] CreateProfileEvent(
|
||||
const c_string &event_type)
|
||||
|
||||
@@ -258,5 +258,37 @@ def test_remote_cancel(ray_start_regular, use_force):
|
||||
ray.get(inner, timeout=10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
|
||||
def test_recursive_cancel(shutdown_only, use_force):
|
||||
ray.init(num_cpus=4)
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
def inner():
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
def outer():
|
||||
|
||||
x = [inner.remote()]
|
||||
print(x)
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
|
||||
@ray.remote(num_cpus=4)
|
||||
def many_resources():
|
||||
return 300
|
||||
|
||||
outer_fut = outer.remote()
|
||||
many_fut = many_resources.remote()
|
||||
with pytest.raises(GetTimeoutError):
|
||||
ray.get(many_fut, timeout=1)
|
||||
ray.cancel(outer_fut)
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(outer_fut, timeout=10)
|
||||
|
||||
assert ray.get(many_fut, timeout=30)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -1575,7 +1575,7 @@ def kill(actor, *, no_restart=True):
|
||||
worker.core_worker.kill_actor(actor._ray_actor_id, no_restart)
|
||||
|
||||
|
||||
def cancel(object_ref, *, force=False):
|
||||
def cancel(object_ref, *, force=False, recursive=True):
|
||||
"""Cancels a task according to the following conditions.
|
||||
|
||||
If the specified task is pending execution, it will not be executed. If
|
||||
@@ -1595,6 +1595,8 @@ def cancel(object_ref, *, force=False):
|
||||
that should be canceled.
|
||||
force (boolean): Whether to force-kill a running task by killing
|
||||
the worker that is running the task.
|
||||
recursive (boolean): Whether to try to cancel tasks submitted by the
|
||||
task specified.
|
||||
Raises:
|
||||
TypeError: This is also raised for actor tasks.
|
||||
"""
|
||||
@@ -1605,7 +1607,7 @@ def cancel(object_ref, *, force=False):
|
||||
raise TypeError(
|
||||
"ray.cancel() only supported for non-actor object refs. "
|
||||
f"Got: {type(object_ref)}.")
|
||||
return worker.core_worker.cancel_task(object_ref, force)
|
||||
return worker.core_worker.cancel_task(object_ref, force, recursive)
|
||||
|
||||
|
||||
def _mode(worker=global_worker):
|
||||
|
||||
@@ -1518,7 +1518,8 @@ void CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &fun
|
||||
}
|
||||
}
|
||||
|
||||
Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill) {
|
||||
Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill,
|
||||
bool recursive) {
|
||||
if (actor_manager_->CheckActorHandleExists(object_id.TaskId().ActorId())) {
|
||||
return Status::Invalid("Actor task cancellation is not supported.");
|
||||
}
|
||||
@@ -1527,16 +1528,36 @@ Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill) {
|
||||
return Status::Invalid("No owner found for object.");
|
||||
}
|
||||
if (obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) {
|
||||
return direct_task_submitter_->CancelRemoteTask(object_id, obj_addr, force_kill);
|
||||
return direct_task_submitter_->CancelRemoteTask(object_id, obj_addr, force_kill,
|
||||
recursive);
|
||||
}
|
||||
|
||||
auto task_spec = task_manager_->GetTaskSpec(object_id.TaskId());
|
||||
if (task_spec.has_value() && !task_spec.value().IsActorCreationTask()) {
|
||||
return direct_task_submitter_->CancelTask(task_spec.value(), force_kill);
|
||||
return direct_task_submitter_->CancelTask(task_spec.value(), force_kill, recursive);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CoreWorker::CancelChildren(const TaskID &task_id, bool force_kill) {
|
||||
bool recursive_success = true;
|
||||
for (const auto &child_id : task_manager_->GetPendingChildrenTasks(task_id)) {
|
||||
auto child_spec = task_manager_->GetTaskSpec(child_id);
|
||||
if (child_spec.has_value()) {
|
||||
auto result =
|
||||
direct_task_submitter_->CancelTask(child_spec.value(), force_kill, true);
|
||||
recursive_success = recursive_success && result.ok();
|
||||
} else {
|
||||
recursive_success = false;
|
||||
}
|
||||
}
|
||||
if (recursive_success) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
return Status::UnknownError("Recursive task cancelation failed--check warning logs.");
|
||||
}
|
||||
}
|
||||
|
||||
Status CoreWorker::KillActor(const ActorID &actor_id, bool force_kill, bool no_restart) {
|
||||
if (options_.is_local_mode) {
|
||||
return KillActorLocalMode(actor_id);
|
||||
@@ -2157,8 +2178,8 @@ void CoreWorker::HandleWaitForRefRemoved(const rpc::WaitForRefRemovedRequest &re
|
||||
void CoreWorker::HandleRemoteCancelTask(const rpc::RemoteCancelTaskRequest &request,
|
||||
rpc::RemoteCancelTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
auto status =
|
||||
CancelTask(ObjectID::FromBinary(request.remote_object_id()), request.force_kill());
|
||||
auto status = CancelTask(ObjectID::FromBinary(request.remote_object_id()),
|
||||
request.force_kill(), request.recursive());
|
||||
send_reply_callback(status, nullptr, nullptr);
|
||||
}
|
||||
|
||||
@@ -2174,6 +2195,12 @@ void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request,
|
||||
RAY_LOG(INFO) << "Interrupting a running task " << main_thread_task_id_;
|
||||
success = options_.kill_main();
|
||||
}
|
||||
if (request.recursive()) {
|
||||
auto recursive_cancel = CancelChildren(task_id, request.force_kill());
|
||||
if (recursive_cancel.ok()) {
|
||||
RAY_LOG(INFO) << "Recursive cancel failed!";
|
||||
}
|
||||
}
|
||||
|
||||
reply->set_attempt_succeeded(success);
|
||||
send_reply_callback(Status::OK(), nullptr, nullptr);
|
||||
|
||||
@@ -714,8 +714,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
||||
///
|
||||
/// \param[in] object_id of the task to kill (must be a Non-Actor task)
|
||||
/// \param[in] force_kill Whether to force kill a task by killing the worker.
|
||||
/// \param[in] recursive Whether to cancel tasks submitted by the task to cancel.
|
||||
/// \param[out] Status
|
||||
Status CancelTask(const ObjectID &object_id, bool force_kill);
|
||||
Status CancelTask(const ObjectID &object_id, bool force_kill, bool recursive);
|
||||
|
||||
/// Decrease the reference count for this actor. Should be called by the
|
||||
/// language frontend when a reference to the ActorHandle destroyed.
|
||||
///
|
||||
@@ -946,6 +948,12 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
||||
reference_counter_->AddLocalReference(object_id, call_site);
|
||||
}
|
||||
|
||||
/// Stops the children tasks from the given TaskID
|
||||
///
|
||||
/// \param[in] task_id of the parent task
|
||||
/// \param[in] force_kill Whether to force kill a task by killing the worker.
|
||||
Status CancelChildren(const TaskID &task_id, bool force_kill);
|
||||
|
||||
///
|
||||
/// Private methods related to task execution. Should not be used by driver processes.
|
||||
///
|
||||
|
||||
@@ -456,4 +456,19 @@ absl::optional<TaskSpecification> TaskManager::GetTaskSpec(const TaskID &task_id
|
||||
return it->second.spec;
|
||||
}
|
||||
|
||||
std::vector<TaskID> TaskManager::GetPendingChildrenTasks(
|
||||
const TaskID &parent_task_id) const {
|
||||
std::vector<TaskID> ret_vec;
|
||||
absl::MutexLock lock(&mu_);
|
||||
RAY_LOG(ERROR) << " calling get children tasks";
|
||||
RAY_LOG(ERROR) << "NUMBER OF PENDING TASKS: " << num_pending_tasks_;
|
||||
for (auto it : submissible_tasks_) {
|
||||
RAY_LOG(ERROR) << "Getting tasks!! " << it.second.spec.TaskId();
|
||||
if (it.second.pending and it.second.spec.ParentTaskId() == parent_task_id) {
|
||||
ret_vec.push_back(it.first);
|
||||
}
|
||||
}
|
||||
return ret_vec;
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -140,6 +140,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
|
||||
/// Return the spec for a pending task.
|
||||
absl::optional<TaskSpecification> GetTaskSpec(const TaskID &task_id) const;
|
||||
|
||||
/// Return specs for pending children tasks of the given parent task.
|
||||
std::vector<TaskID> GetPendingChildrenTasks(const TaskID &parent_task_id) const;
|
||||
|
||||
/// Return whether this task can be submitted for execution.
|
||||
///
|
||||
/// \param[in] task_id ID of the task to query.
|
||||
|
||||
@@ -1065,7 +1065,7 @@ TEST(DirectTaskTransportTest, TestKillExecutingTask) {
|
||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil()));
|
||||
|
||||
// Try force kill, exiting the worker
|
||||
ASSERT_TRUE(submitter.CancelTask(task, true).ok());
|
||||
ASSERT_TRUE(submitter.CancelTask(task, true, false).ok());
|
||||
ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(),
|
||||
task.TaskId().Binary());
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("workerdying"), true));
|
||||
@@ -1081,7 +1081,7 @@ TEST(DirectTaskTransportTest, TestKillExecutingTask) {
|
||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil()));
|
||||
|
||||
// Try non-force kill, worker returns normally
|
||||
ASSERT_TRUE(submitter.CancelTask(task, false).ok());
|
||||
ASSERT_TRUE(submitter.CancelTask(task, false, false).ok());
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||
ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(),
|
||||
task.TaskId().Binary());
|
||||
@@ -1114,7 +1114,7 @@ TEST(DirectTaskTransportTest, TestKillPendingTask) {
|
||||
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
|
||||
|
||||
ASSERT_TRUE(submitter.SubmitTask(task).ok());
|
||||
ASSERT_TRUE(submitter.CancelTask(task, true).ok());
|
||||
ASSERT_TRUE(submitter.CancelTask(task, true, false).ok());
|
||||
ASSERT_EQ(worker_client->kill_requests.size(), 0);
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||
@@ -1152,7 +1152,7 @@ TEST(DirectTaskTransportTest, TestKillResolvingTask) {
|
||||
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary());
|
||||
ASSERT_TRUE(submitter.SubmitTask(task).ok());
|
||||
ASSERT_EQ(task_finisher->num_inlined_dependencies, 0);
|
||||
ASSERT_TRUE(submitter.CancelTask(task, true).ok());
|
||||
ASSERT_TRUE(submitter.CancelTask(task, true, false).ok());
|
||||
auto data = GenerateRandomObject();
|
||||
ASSERT_TRUE(store->Put(*data, obj1));
|
||||
ASSERT_EQ(worker_client->kill_requests.size(), 0);
|
||||
|
||||
@@ -408,7 +408,7 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
|
||||
}
|
||||
|
||||
Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
|
||||
bool force_kill) {
|
||||
bool force_kill, bool recursive) {
|
||||
RAY_LOG(INFO) << "Killing task: " << task_spec.TaskId();
|
||||
const SchedulingKey scheduling_key(
|
||||
task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(),
|
||||
@@ -470,8 +470,9 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
|
||||
auto request = rpc::CancelTaskRequest();
|
||||
request.set_intended_task_id(task_spec.TaskId().Binary());
|
||||
request.set_force_kill(force_kill);
|
||||
request.set_recursive(recursive);
|
||||
client->CancelTask(
|
||||
request, [this, task_spec, scheduling_key, force_kill](
|
||||
request, [this, task_spec, scheduling_key, force_kill, recursive](
|
||||
const Status &status, const rpc::CancelTaskReply &reply) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
cancelled_tasks_.erase(task_spec.TaskId());
|
||||
@@ -483,8 +484,9 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
|
||||
cancel_retry_timer_->expires_after(boost::asio::chrono::milliseconds(
|
||||
RayConfig::instance().cancellation_retry_ms()));
|
||||
}
|
||||
cancel_retry_timer_->async_wait(boost::bind(
|
||||
&CoreWorkerDirectTaskSubmitter::CancelTask, this, task_spec, force_kill));
|
||||
cancel_retry_timer_->async_wait(
|
||||
boost::bind(&CoreWorkerDirectTaskSubmitter::CancelTask, this, task_spec,
|
||||
force_kill, recursive));
|
||||
}
|
||||
}
|
||||
// Retry is not attempted if !status.ok() because force-kill may kill the worker
|
||||
@@ -495,7 +497,7 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
|
||||
|
||||
Status CoreWorkerDirectTaskSubmitter::CancelRemoteTask(const ObjectID &object_id,
|
||||
const rpc::Address &worker_addr,
|
||||
bool force_kill) {
|
||||
bool force_kill, bool recursive) {
|
||||
auto maybe_client = client_cache_->GetByID(rpc::WorkerAddress(worker_addr).worker_id);
|
||||
|
||||
if (!maybe_client.has_value()) {
|
||||
@@ -504,6 +506,7 @@ Status CoreWorkerDirectTaskSubmitter::CancelRemoteTask(const ObjectID &object_id
|
||||
auto client = maybe_client.value();
|
||||
auto request = rpc::RemoteCancelTaskRequest();
|
||||
request.set_force_kill(force_kill);
|
||||
request.set_recursive(recursive);
|
||||
request.set_remote_object_id(object_id.Binary());
|
||||
client->RemoteCancelTask(request, nullptr);
|
||||
return Status::OK();
|
||||
|
||||
@@ -81,11 +81,10 @@ class CoreWorkerDirectTaskSubmitter {
|
||||
///
|
||||
/// \param[in] task_spec The task to kill.
|
||||
/// \param[in] force_kill Whether to kill the worker executing the task.
|
||||
Status CancelTask(TaskSpecification task_spec, bool force_kill);
|
||||
Status CancelTask(TaskSpecification task_spec, bool force_kill, bool recursive);
|
||||
|
||||
Status CancelRemoteTask(const ObjectID &object_id, const rpc::Address &worker_addr,
|
||||
bool force_kill);
|
||||
|
||||
bool force_kill, bool recursive);
|
||||
/// Check that the scheduling_key_entries_ hashmap is empty by calling the private
|
||||
/// CheckNoSchedulingKeyEntries function after acquiring the lock.
|
||||
bool CheckNoSchedulingKeyEntriesPublic() {
|
||||
|
||||
@@ -205,6 +205,8 @@ message CancelTaskRequest {
|
||||
bytes intended_task_id = 1;
|
||||
// Whether to kill the worker.
|
||||
bool force_kill = 2;
|
||||
// Whether to recursively cancel tasks.
|
||||
bool recursive = 3;
|
||||
}
|
||||
|
||||
message CancelTaskReply {
|
||||
@@ -217,6 +219,8 @@ message RemoteCancelTaskRequest {
|
||||
bytes remote_object_id = 1;
|
||||
// Whether to kill the worker.
|
||||
bool force_kill = 2;
|
||||
// Whether to recursively cancel tasks.
|
||||
bool recursive = 3;
|
||||
}
|
||||
|
||||
message RemoteCancelTaskReply {
|
||||
|
||||
Reference in New Issue
Block a user