mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 18:06:25 +08:00
TaskCancellation (#7669)
* Smol comment * WIP, not passing ray.init * Fixed small problem * wip * Pseudo interrupt things * Basic prototype operational * correct proc title * Mostly done * Cleanup * cleaner raylet error * Cleaning up a few loose ends * Fixing Race Conds * Prelim testing * Fixing comments and adding second_check for kill * Working_new_impl * demo_ready * Fixing my english * Fixing a few problems * Small problems * Cleaning up * Response to changes * Fixing error passing * Merged to master * fixing lock * Cleaning up print statements * Format * Fixing Unit test build failure * mock_worker fix * java_fix * Canel * Switching to Cancel * Responding to Review * FixFormatting * Lease cancellation * FInal comments? * Moving exist check to CoreWorker * Fix Actor Transport Test * Fixing task manager test * chaning clock repr * Fix build * fix white space * lint fix * Updating to medium size * Fixing Java test compilation issue * lengthen bad timeouts
This commit is contained in:
@@ -66,6 +66,7 @@ from ray.worker import (
|
||||
LOCAL_MODE,
|
||||
SCRIPT_MODE,
|
||||
WORKER_MODE,
|
||||
cancel,
|
||||
connect,
|
||||
disconnect,
|
||||
get,
|
||||
@@ -113,6 +114,7 @@ __all__ = [
|
||||
"_config",
|
||||
"_get_runtime_context",
|
||||
"actor",
|
||||
"cancel",
|
||||
"connect",
|
||||
"disconnect",
|
||||
"get",
|
||||
|
||||
+39
-6
@@ -17,6 +17,8 @@ import logging
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import _thread
|
||||
import setproctitle
|
||||
|
||||
from libc.stdint cimport (
|
||||
int32_t,
|
||||
@@ -90,6 +92,7 @@ from ray.exceptions import (
|
||||
RayTaskError,
|
||||
ObjectStoreFullError,
|
||||
RayTimeoutError,
|
||||
RayCancellationError
|
||||
)
|
||||
from ray.utils import decode
|
||||
import gc
|
||||
@@ -453,13 +456,23 @@ cdef execute_task(
|
||||
class_name, repr(args), repr(kwargs))
|
||||
core_worker.set_actor_title(actor_title.encode("utf-8"))
|
||||
# Execute the task.
|
||||
with ray.worker._changeproctitle(title, next_title):
|
||||
with core_worker.profile_event(b"task:execute"):
|
||||
task_exception = True
|
||||
outputs = function_executor(*args, **kwargs)
|
||||
with core_worker.profile_event(b"task:execute"):
|
||||
task_exception = True
|
||||
try:
|
||||
with ray.worker._changeproctitle(title, next_title):
|
||||
outputs = function_executor(*args, **kwargs)
|
||||
task_exception = False
|
||||
if c_return_ids.size() == 1:
|
||||
outputs = (outputs,)
|
||||
except KeyboardInterrupt as e:
|
||||
raise RayCancellationError(
|
||||
core_worker.get_current_task_id())
|
||||
if c_return_ids.size() == 1:
|
||||
outputs = (outputs,)
|
||||
# Check for a cancellation that was called when the function
|
||||
# was exiting and was raised after the except block.
|
||||
if not check_signals().ok():
|
||||
task_exception = True
|
||||
raise RayCancellationError(
|
||||
core_worker.get_current_task_id())
|
||||
# Store the outputs in the object store.
|
||||
with core_worker.profile_event(b"task:store_outputs"):
|
||||
core_worker.store_task_outputs(
|
||||
@@ -551,6 +564,14 @@ cdef void async_plasma_callback(CObjectID object_id,
|
||||
event_handler._loop.call_soon_threadsafe(
|
||||
event_handler._complete_future, obj_id)
|
||||
|
||||
cdef c_bool kill_main_task() nogil:
    # Interrupt the main Python thread unless the worker looks idle.
    #
    # A process title of "ray::IDLE" is treated as "no task is running", in
    # which case nothing is interrupted and False is returned. Otherwise the
    # main thread is interrupted and True is returned.
    cdef c_bool is_busy
    with gil:
        is_busy = setproctitle.getproctitle() != "ray::IDLE"
        if is_busy:
            _thread.interrupt_main()
    return is_busy
|
||||
|
||||
|
||||
cdef CRayStatus check_signals() nogil:
|
||||
with gil:
|
||||
try:
|
||||
@@ -658,6 +679,7 @@ cdef class CoreWorker:
|
||||
options.ref_counting_enabled = True
|
||||
options.is_local_mode = local_mode
|
||||
options.num_workers = 1
|
||||
options.kill_main = kill_main_task
|
||||
|
||||
CCoreWorkerProcess.Initialize(options)
|
||||
|
||||
@@ -953,6 +975,17 @@ cdef class CoreWorker:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().KillActor(
|
||||
c_actor_id, True, no_reconstruction))
|
||||
|
||||
def cancel_task(self, ObjectID object_id, c_bool force_kill):
    """Ask the core worker to cancel the task that produces ``object_id``.

    Args:
        object_id: ObjectID whose producing task should be cancelled.
        force_kill: Forwarded unchanged to the core worker's CancelTask.

    Raises:
        TypeError: If the core worker reports a non-OK status; the status
            message becomes the exception text.
    """
    cdef:
        CObjectID native_id = object_id.native()
        CRayStatus result = CRayStatus.OK()

    result = CCoreWorkerProcess.GetCoreWorker().CancelTask(
        native_id, force_kill)

    if result.ok():
        return
    raise TypeError(result.message().decode())
|
||||
|
||||
def resource_ids(self):
|
||||
cdef:
|
||||
ResourceMappingType resource_mapping = (
|
||||
|
||||
@@ -16,6 +16,23 @@ class RayConnectionError(RayError):
|
||||
pass
|
||||
|
||||
|
||||
class RayCancellationError(RayError):
    """Raised when this task is cancelled.

    Attributes:
        task_id (TaskID): The TaskID of the function that was directly
            cancelled, or None when the cancelled task is not known.
    """

    def __init__(self, task_id=None):
        self.task_id = task_id

    def __str__(self):
        # Without a task id we cannot tell whether this task or one of its
        # dependencies was the one cancelled, so the message stays generic.
        if self.task_id is None:
            return "This task or its dependency was cancelled"
        return "Task: " + str(self.task_id) + " was cancelled"
|
||||
|
||||
|
||||
class RayTaskError(RayError):
|
||||
"""Indicates that a task threw an exception during execution.
|
||||
|
||||
|
||||
@@ -97,6 +97,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
CRayStatus KillActor(
|
||||
const CActorID &actor_id, c_bool force_kill,
|
||||
c_bool no_reconstruction)
|
||||
CRayStatus CancelTask(const CObjectID &object_id, c_bool force_kill)
|
||||
|
||||
unique_ptr[CProfileEvent] CreateProfileEvent(
|
||||
const c_string &event_type)
|
||||
@@ -214,6 +215,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
c_bool ref_counting_enabled
|
||||
c_bool is_local_mode
|
||||
int num_workers
|
||||
(c_bool() nogil) kill_main
|
||||
CCoreWorkerOptions()
|
||||
|
||||
cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess":
|
||||
|
||||
@@ -122,6 +122,8 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
|
||||
CTaskID ForNormalTask(CJobID job_id, CTaskID parent_task_id,
|
||||
int64_t parent_task_counter)
|
||||
|
||||
CActorID ActorId() const
|
||||
|
||||
cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]):
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -220,6 +220,9 @@ cdef class TaskID(BaseID):
|
||||
def is_nil(self):
|
||||
return self.data.IsNil()
|
||||
|
||||
def actor_id(self):
    """Return the ActorID built from this task ID's ``ActorId()`` bytes."""
    actor_binary = self.data.ActorId().Binary()
    return ActorID(actor_binary)
|
||||
|
||||
cdef size_t hash(self):
|
||||
return self.data.Hash()
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from ray.exceptions import (
|
||||
PlasmaObjectNotAvailable,
|
||||
RayTaskError,
|
||||
RayActorError,
|
||||
RayCancellationError,
|
||||
RayWorkerError,
|
||||
UnreconstructableError,
|
||||
)
|
||||
@@ -279,6 +280,8 @@ class SerializationContext:
|
||||
return RayWorkerError()
|
||||
elif error_type == ErrorType.Value("ACTOR_DIED"):
|
||||
return RayActorError()
|
||||
elif error_type == ErrorType.Value("TASK_CANCELLED"):
|
||||
return RayCancellationError()
|
||||
elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"):
|
||||
return UnreconstructableError(ray.ObjectID(object_id.binary()))
|
||||
else:
|
||||
|
||||
@@ -414,3 +414,11 @@ py_test(
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_cancel",
|
||||
size = "medium",
|
||||
srcs = ["test_cancel.py"],
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
import pytest
|
||||
import ray
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from ray.exceptions import RayTaskError, RayTimeoutError, \
|
||||
RayCancellationError, RayWorkerError
|
||||
from ray.test_utils import SignalActor
|
||||
|
||||
|
||||
def valid_exceptions(use_force):
    """Return the exception types accepted when getting a cancelled task.

    RayWorkerError is accepted in addition to the base pair only when the
    cancellation was forced.
    """
    accepted = (RayTaskError, RayCancellationError)
    if use_force:
        accepted = accepted + (RayWorkerError, )
    return accepted
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
def test_cancel_chain(ray_start_regular, use_force):
    """Cancelling a link in a task chain fails it and everything after it."""
    signaler = SignalActor.remote()

    @ray.remote
    def wait_for(t):
        return ray.get(t[0])

    def submit_chain(blocker):
        # Four tasks; the first blocks on the signal actor and every later
        # one blocks on the result of the task before it.
        links = [wait_for.remote([blocker.wait.remote()])]
        while len(links) < 4:
            links.append(wait_for.remote([links[-1]]))
        return links

    # Cancel the head of the chain: every link must fail.
    chain = submit_chain(signaler)
    assert len(ray.wait([chain[0]], timeout=.1)[0]) == 0
    ray.cancel(chain[0], use_force)
    for ob in chain:
        with pytest.raises(valid_exceptions(use_force)):
            ray.get(ob)

    # Cancel the third link: only it and its dependent fail, while the
    # first two links keep waiting for the signal.
    signaler2 = SignalActor.remote()
    obj1, obj2, obj3, obj4 = submit_chain(signaler2)

    assert len(ray.wait([obj3], timeout=.1)[0]) == 0
    ray.cancel(obj3, use_force)
    for ob in (obj3, obj4):
        with pytest.raises(valid_exceptions(use_force)):
            ray.get(ob)

    for still_blocked in (obj1, obj2):
        with pytest.raises(RayTimeoutError):
            ray.get(still_blocked, timeout=.1)

    signaler2.send.remote()
    ray.get(obj1, timeout=10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
def test_cancel_multiple_dependents(ray_start_regular, use_force):
    """Check cancellation with several tasks depending on one head task.

    Phase one cancels the head and expects every dependent to fail. Phase two
    cancels only the dependents of a second head and expects that head to
    still complete once it is signalled.
    """
    signaler = SignalActor.remote()

    @ray.remote
    def wait_for(t):
        return ray.get(t[0])

    head = wait_for.remote([signaler.wait.remote()])
    deps = [wait_for.remote([head]) for _ in range(3)]

    assert len(ray.wait([head], timeout=.1)[0]) == 0
    ray.cancel(head, use_force)
    for d in deps:
        with pytest.raises(valid_exceptions(use_force)):
            ray.get(d)

    head2 = wait_for.remote([signaler.wait.remote()])

    # The dependents must hang off head2 (which is still alive). Building
    # them on the already-cancelled `head` would make them fail regardless of
    # the cancel calls below, so the cancel path would never be exercised.
    deps2 = [wait_for.remote([head2]) for _ in range(3)]

    for d in deps2:
        ray.cancel(d, use_force)

    for d in deps2:
        with pytest.raises(valid_exceptions(use_force)):
            ray.get(d)

    # Cancelling the dependents must not have affected head2 itself.
    signaler.send.remote()
    ray.get(head2, timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
def test_single_cpu_cancel(shutdown_only, use_force):
    """Cancel the tail, then the head, of a chain on a one-CPU cluster."""
    ray.init(num_cpus=1)
    signaler = SignalActor.remote()

    @ray.remote
    def wait_for(t):
        return ray.get(t[0])

    obj1 = wait_for.remote([signaler.wait.remote()])
    obj2 = wait_for.remote([obj1])
    obj3 = wait_for.remote([obj2])
    # Unrelated task that must survive all of the cancellations below.
    indep = wait_for.remote([signaler.wait.remote()])

    ready, _ = ray.wait([obj3], timeout=.1)
    assert len(ready) == 0

    # Cancel the last link first ...
    ray.cancel(obj3, use_force)
    with pytest.raises(valid_exceptions(use_force)):
        ray.get(obj3, 10)

    # ... then the first link, which also takes the middle one down.
    ray.cancel(obj1, use_force)
    for failed in (obj1, obj2):
        with pytest.raises(valid_exceptions(use_force)):
            ray.get(failed)

    signaler.send.remote()
    ray.get(indep)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
def test_comprehensive(ray_start_regular, use_force):
    """Cancel one input of a combining task and check everything it feeds."""
    signaler = SignalActor.remote()

    @ray.remote
    def wait_for(t):
        ray.get(t[0])
        return "Result"

    @ray.remote
    def combine(a, b):
        return str(a) + str(b)

    a = wait_for.remote([signaler.wait.remote()])
    b = wait_for.remote([signaler.wait.remote()])
    combo = combine.remote(a, b)
    a2 = wait_for.remote([a])

    # Nothing may finish before the signal is sent.
    ready, _ = ray.wait([a, b, a2, combo], timeout=1)
    assert len(ready) == 0

    ray.cancel(a, use_force)
    # Both the cancelled task and the task waiting on it must fail.
    for failed in (a, a2):
        with pytest.raises(valid_exceptions(use_force)):
            ray.get(failed, 10)

    signaler.send.remote()

    # `combo` takes `a` as a direct argument, so it fails even though `b`
    # can now complete.
    with pytest.raises(valid_exceptions(use_force)):
        ray.get(combo, 10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
def test_stress(shutdown_only, use_force):
    """Randomly cancel a batch of sleeping and non-sleeping tasks on one CPU.

    Every cancelled task must raise one of the accepted exceptions; every
    task that was never cancelled must return normally.
    """
    ray.init(num_cpus=1)

    @ray.remote
    def infinite_sleep(y):
        if y:
            while True:
                time.sleep(1 / 10)

    first = infinite_sleep.remote(True)

    sleep_or_no = [random.randint(0, 1) for _ in range(100)]
    tasks = [infinite_sleep.remote(i) for i in sleep_or_no]
    cancelled = set()
    for t in tasks:
        if random.random() > 0.5:
            ray.cancel(t, use_force)
            cancelled.add(t)

    ray.cancel(first, use_force)
    cancelled.add(first)

    for done in cancelled:
        with pytest.raises(valid_exceptions(use_force)):
            ray.get(done, 10)

    # Walk tasks together with their sleep flag instead of indexing both
    # lists by position.
    for t, sleeps_forever in zip(tasks, sleep_or_no):
        if sleeps_forever:
            # A sleeping task never returns on its own, so it has to be
            # cancelled before it can be waited on.
            ray.cancel(t, use_force)
            cancelled.add(t)
        if t in cancelled:
            with pytest.raises(valid_exceptions(use_force)):
                ray.get(t, 10)
        else:
            ray.get(t)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
def test_fast(shutdown_only, use_force):
    """Cancel many short tasks, some right after submission.

    A task may finish before its cancellation takes effect, so each result is
    allowed either to succeed or to raise one of the accepted exceptions.
    """
    ray.init(num_cpus=2)

    @ray.remote
    def fast(y):
        return y

    signaler = SignalActor.remote()
    ids = []
    for _ in range(100):
        x = fast.remote("a")
        # Forward the parametrized flag; without it both parametrizations
        # would exercise only the non-force path.
        ray.cancel(x, use_force)
        ids.append(x)

    @ray.remote
    def wait_for(y):
        return y

    sig = signaler.wait.remote()
    for _ in range(5000):
        x = wait_for.remote(sig)
        ids.append(x)

    # Cancel roughly 5% of the tasks that are still waiting on the signal.
    for idx in range(100, 5100):
        if random.random() > 0.95:
            ray.cancel(ids[idx], use_force)
    signaler.send.remote()
    for obj_id in ids:
        try:
            ray.get(obj_id, 10)
        except Exception as e:
            assert isinstance(e, valid_exceptions(use_force))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
+22
-1
@@ -1661,12 +1661,33 @@ def kill(actor):
|
||||
if not isinstance(actor, ray.actor.ActorHandle):
|
||||
raise ValueError("ray.kill() only supported for actors. "
|
||||
"Got: {}.".format(type(actor)))
|
||||
|
||||
worker = ray.worker.global_worker
|
||||
worker.check_connected()
|
||||
worker.core_worker.kill_actor(actor._ray_actor_id, False)
|
||||
|
||||
|
||||
def cancel(object_id, force=False):
    """Cancel the non-actor task that produces the given object ID.

    Without ``force`` the worker running the task is asked to interrupt its
    main thread, so the task fails with a cancellation error. With ``force``
    the worker process running the task is made to exit instead.

    Args:
        object_id (ObjectID): ObjectID returned by the task to cancel.
        force (boolean): Whether to force-kill a running task by killing the
            worker that is executing it.

    Raises:
        TypeError: If ``object_id`` is not a ``ray.ObjectID``, or if the core
            worker rejects the request (its status message is used as the
            exception text).
    """
    worker = ray.worker.global_worker
    worker.check_connected()

    if not isinstance(object_id, ray.ObjectID):
        raise TypeError(
            "ray.cancel() only supported for non-actor object IDs. "
            "Got: {}.".format(type(object_id)))
    return worker.core_worker.cancel_task(object_id, force)
|
||||
|
||||
|
||||
def _mode(worker=global_worker):
|
||||
"""This is a wrapper around worker.mode.
|
||||
|
||||
|
||||
@@ -278,6 +278,9 @@ RAY_CONFIG(uint32_t, object_store_full_initial_delay_ms, 1000)
|
||||
/// Duration to wait between retries for failed tasks.
|
||||
RAY_CONFIG(uint32_t, task_retry_delay_ms, 5000)
|
||||
|
||||
/// Duration to wait between retrying to kill a task.
|
||||
RAY_CONFIG(uint32_t, cancellation_retry_ms, 2000)
|
||||
|
||||
/// Whether to enable gcs service.
|
||||
/// RAY_GCS_SERVICE_ENABLED is an env variable which only set in ci job.
|
||||
/// If the value of RAY_GCS_SERVICE_ENABLED is false, we will disable gcs service,
|
||||
|
||||
@@ -429,7 +429,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
|
||||
rpc_address_, local_raylet_client_, client_factory, raylet_client_factory,
|
||||
memory_store_, task_manager_, local_raylet_id,
|
||||
RayConfig::instance().worker_lease_timeout_milliseconds(),
|
||||
std::move(actor_create_callback)));
|
||||
std::move(actor_create_callback), boost::asio::steady_timer(io_service_)));
|
||||
future_resolver_.reset(new FutureResolver(memory_store_, client_factory));
|
||||
// Unfortunately the raylet client has to be constructed after the receivers.
|
||||
if (direct_task_receiver_ != nullptr) {
|
||||
@@ -590,10 +590,10 @@ const WorkerID &CoreWorker::GetWorkerID() const { return worker_context_.GetWork
|
||||
|
||||
void CoreWorker::SetCurrentTaskId(const TaskID &task_id) {
|
||||
worker_context_.SetCurrentTaskId(task_id);
|
||||
main_thread_task_id_ = task_id;
|
||||
bool not_actor_task = false;
|
||||
{
|
||||
absl::MutexLock lock(&mutex_);
|
||||
main_thread_task_id_ = task_id;
|
||||
not_actor_task = actor_id_.IsNil();
|
||||
}
|
||||
if (not_actor_task && task_id.IsNil()) {
|
||||
@@ -1056,6 +1056,7 @@ TaskID CoreWorker::GetCallerId() const {
|
||||
if (!actor_id.IsNil()) {
|
||||
caller_id = TaskID::ForActorCreationTask(actor_id);
|
||||
} else {
|
||||
absl::MutexLock lock(&mutex_);
|
||||
caller_id = main_thread_task_id_;
|
||||
}
|
||||
return caller_id;
|
||||
@@ -1203,6 +1204,25 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
|
||||
return status;
|
||||
}
|
||||
|
||||
Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill) {
  // Reject objects that were not created by a task, and objects whose task ID
  // resolves to an actor handle known to this worker.
  ActorHandle *actor_handle = nullptr;
  const bool rejected_as_actor_task =
      !object_id.CreatedByTask() ||
      GetActorHandle(object_id.TaskId().ActorId(), &actor_handle).ok();
  if (rejected_as_actor_task) {
    return Status::Invalid("Actor task cancellation is not supported.");
  }

  // The object must have a recorded owner whose address equals this worker's
  // own RPC address.
  rpc::Address owner_address;
  const bool owned_locally =
      reference_counter_->GetOwner(object_id, nullptr, &owner_address) &&
      owner_address.SerializeAsString() == rpc_address_.SerializeAsString();
  if (!owned_locally) {
    return Status::Invalid("Task is not locally submitted.");
  }

  // A task with no spec in the task manager, or an actor creation task, is
  // left alone and reported as OK.
  auto task_spec = task_manager_->GetTaskSpec(object_id.TaskId());
  if (!task_spec.has_value() || task_spec.value().IsActorCreationTask()) {
    return Status::OK();
  }
  return direct_task_submitter_->CancelTask(task_spec.value(), force_kill);
}
|
||||
|
||||
Status CoreWorker::KillActor(const ActorID &actor_id, bool force_kill,
|
||||
bool no_reconstruction) {
|
||||
ActorHandle *actor_handle = nullptr;
|
||||
@@ -1735,6 +1755,33 @@ void CoreWorker::HandleWaitForRefRemoved(const rpc::WaitForRefRemovedRequest &re
|
||||
owner_address, ref_removed_callback);
|
||||
}
|
||||
|
||||
/// gRPC handler for a request to cancel the task this worker is running.
///
/// The attempt only counts as successful when the requested task ID matches
/// the task currently executing on the main thread. A non-force request
/// interrupts the main thread through the language callback; a force request
/// replies first and then exits the process.
void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request,
                                  rpc::CancelTaskReply *reply,
                                  rpc::SendReplyCallback send_reply_callback) {
  absl::MutexLock lock(&mutex_);
  TaskID task_id = TaskID::FromBinary(request.intended_task_id());
  bool success = main_thread_task_id_ == task_id;

  // Try non-force kill
  if (success && !request.force_kill()) {
    if (options_.kill_main != nullptr) {
      RAY_LOG(INFO) << "Interrupting a running task " << main_thread_task_id_;
      success = options_.kill_main();
    } else {
      // Language frontends may leave kill_main unset. Invoking an empty
      // std::function throws std::bad_function_call, so report the attempt
      // as failed instead.
      success = false;
    }
  }

  reply->set_attempt_succeeded(success);
  send_reply_callback(Status::OK(), nullptr, nullptr);

  // Do force kill after reply callback sent
  if (success && request.force_kill()) {
    RAY_LOG(INFO) << "Force killing a worker running " << main_thread_task_id_;
    RAY_IGNORE_EXPR(local_raylet_client_->Disconnect());
    if (options_.log_dir != "") {
      RayLog::ShutDownRayLog();
    }
    exit(1);
  }
}
|
||||
|
||||
void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request,
|
||||
rpc::KillActorReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
|
||||
@@ -106,6 +106,8 @@ struct CoreWorkerOptions {
|
||||
std::function<void()> gc_collect;
|
||||
/// Language worker callback to get the current call stack.
|
||||
std::function<void(std::string *)> get_lang_stack;
|
||||
// Function that tries to interrupt the currently running Python thread.
|
||||
std::function<bool()> kill_main;
|
||||
/// Whether to enable object ref counting.
|
||||
bool ref_counting_enabled;
|
||||
/// Is local mode being used.
|
||||
@@ -501,7 +503,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
||||
/// For actors, this is the current actor ID. To make sure that all caller
|
||||
/// IDs have the same type, we embed the actor ID in a TaskID with the rest
|
||||
/// of the bytes zeroed out.
|
||||
TaskID GetCallerId() const;
|
||||
TaskID GetCallerId() const LOCKS_EXCLUDED(mutex_);
|
||||
|
||||
/// Push an error to the relevant driver.
|
||||
///
|
||||
@@ -586,6 +588,12 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
||||
/// \param[out] Status
|
||||
Status KillActor(const ActorID &actor_id, bool force_kill, bool no_reconstruction);
|
||||
|
||||
/// Stops the task associated with the given Object ID.
|
||||
///
|
||||
/// \param[in] object_id of the task to kill (must be a Non-Actor task)
|
||||
/// \param[in] force_kill Whether to force kill a task by killing the worker.
|
||||
/// \return Status
|
||||
Status CancelTask(const ObjectID &object_id, bool force_kill);
|
||||
/// Decrease the reference count for this actor. Should be called by the
|
||||
/// language frontend when a reference to the ActorHandle destroyed.
|
||||
///
|
||||
@@ -694,6 +702,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
||||
void HandleKillActor(const rpc::KillActorRequest &request, rpc::KillActorReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
/// Implements gRPC server handler.
|
||||
void HandleCancelTask(const rpc::CancelTaskRequest &request,
|
||||
rpc::CancelTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
/// Implements gRPC server handler.
|
||||
void HandlePlasmaObjectReady(const rpc::PlasmaObjectReadyRequest &request,
|
||||
rpc::PlasmaObjectReadyReply *reply,
|
||||
@@ -893,7 +906,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
||||
/// The ID of the current task being executed by the main thread. If there
|
||||
/// are multiple threads, they will have a thread-local task ID stored in the
|
||||
/// worker context.
|
||||
TaskID main_thread_task_id_;
|
||||
TaskID main_thread_task_id_ GUARDED_BY(mutex_);
|
||||
|
||||
// Flag indicating whether this worker has been shut down.
|
||||
bool shutdown_ = false;
|
||||
|
||||
@@ -122,6 +122,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
|
||||
nullptr, // check_signals
|
||||
nullptr, // gc_collect
|
||||
nullptr, // get_lang_stack
|
||||
nullptr, // kill_main
|
||||
false, // ref_counting_enabled
|
||||
false, // is_local_mode
|
||||
static_cast<int>(numWorkersPerProcess), // num_workers
|
||||
|
||||
@@ -412,6 +412,15 @@ void TaskManager::RemoveLineageReference(const ObjectID &object_id,
|
||||
}
|
||||
}
|
||||
|
||||
bool TaskManager::MarkTaskCanceled(const TaskID &task_id) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto it = submissible_tasks_.find(task_id);
|
||||
if (it != submissible_tasks_.end()) {
|
||||
it->second.num_retries_left = 0;
|
||||
}
|
||||
return it != submissible_tasks_.end();
|
||||
}
|
||||
|
||||
void TaskManager::MarkPendingTaskFailed(const TaskID &task_id,
|
||||
const TaskSpecification &spec,
|
||||
rpc::ErrorType error_type) {
|
||||
@@ -432,4 +441,13 @@ void TaskManager::MarkPendingTaskFailed(const TaskID &task_id,
|
||||
}
|
||||
}
|
||||
|
||||
absl::optional<TaskSpecification> TaskManager::GetTaskSpec(const TaskID &task_id) const {
  absl::MutexLock lock(&mu_);
  // Return the stored spec when the task is tracked, otherwise an empty optional.
  const auto entry = submissible_tasks_.find(task_id);
  if (entry != submissible_tasks_.end()) {
    return entry->second.spec;
  }
  return absl::optional<TaskSpecification>();
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -39,6 +39,8 @@ class TaskFinisherInterface {
|
||||
const std::vector<ObjectID> &inlined_dependency_ids,
|
||||
const std::vector<ObjectID> &contained_ids) = 0;
|
||||
|
||||
virtual bool MarkTaskCanceled(const TaskID &task_id) = 0;
|
||||
|
||||
virtual ~TaskFinisherInterface() {}
|
||||
};
|
||||
|
||||
@@ -129,6 +131,15 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
|
||||
void OnTaskDependenciesInlined(const std::vector<ObjectID> &inlined_dependency_ids,
|
||||
const std::vector<ObjectID> &contained_ids) override;
|
||||
|
||||
/// Set number of retries to zero for a task that is being canceled.
|
||||
///
|
||||
/// \param[in] task_id to cancel.
|
||||
/// \return Whether the task was pending and was marked for cancellation.
|
||||
bool MarkTaskCanceled(const TaskID &task_id);
|
||||
|
||||
/// Return the spec for a pending task.
|
||||
absl::optional<TaskSpecification> GetTaskSpec(const TaskID &task_id) const;
|
||||
|
||||
/// Return whether this task can be submitted for execution.
|
||||
///
|
||||
/// \param[in] task_id ID of the task to query.
|
||||
|
||||
@@ -271,6 +271,7 @@ class CoreWorkerTest : public ::testing::Test {
|
||||
nullptr, // check_signals
|
||||
nullptr, // gc_collect
|
||||
nullptr, // get_lang_stack
|
||||
nullptr, // kill_main
|
||||
true, // ref_counting_enabled
|
||||
false, // is_local_mode
|
||||
1, // num_workers
|
||||
|
||||
@@ -64,6 +64,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
|
||||
|
||||
MOCK_METHOD2(OnTaskDependenciesInlined,
|
||||
void(const std::vector<ObjectID> &, const std::vector<ObjectID> &));
|
||||
|
||||
MOCK_METHOD1(MarkTaskCanceled, bool(const TaskID &task_id));
|
||||
};
|
||||
|
||||
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
|
||||
|
||||
@@ -52,7 +52,15 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
|
||||
return true;
|
||||
}
|
||||
|
||||
ray::Status CancelTask(
|
||||
const rpc::CancelTaskRequest &request,
|
||||
const rpc::ClientCallback<rpc::CancelTaskReply> &callback) override {
|
||||
kill_requests.push_front(request);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::list<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
|
||||
std::list<rpc::CancelTaskRequest> kill_requests;
|
||||
};
|
||||
|
||||
class MockTaskFinisher : public TaskFinisherInterface {
|
||||
@@ -75,6 +83,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
|
||||
num_contained_ids += contained_ids.size();
|
||||
}
|
||||
|
||||
bool MarkTaskCanceled(const TaskID &task_id) override { return true; }
|
||||
|
||||
int num_tasks_complete = 0;
|
||||
int num_tasks_failed = 0;
|
||||
int num_inlined_dependencies = 0;
|
||||
@@ -935,6 +945,105 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
|
||||
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
|
||||
}
|
||||
|
||||
TEST(DirectTaskTransportTest, TestKillExecutingTask) {
|
||||
rpc::Address address;
|
||||
auto raylet_client = std::make_shared<MockRayletClient>();
|
||||
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||
auto factory = [&](const rpc::Address &addr) { return worker_client; };
|
||||
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store,
|
||||
task_finisher, ClientID::Nil(), kLongTimeout);
|
||||
std::unordered_map<std::string, double> empty_resources;
|
||||
ray::FunctionDescriptor empty_descriptor =
|
||||
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
|
||||
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
|
||||
|
||||
ASSERT_TRUE(submitter.SubmitTask(task).ok());
|
||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
|
||||
|
||||
// Try force kill, exiting the worker
|
||||
ASSERT_TRUE(submitter.CancelTask(task, true).ok());
|
||||
ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(),
|
||||
task.TaskId().Binary());
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("workerdying"), true));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
||||
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
|
||||
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
|
||||
|
||||
task.GetMutableMessage().set_task_id(
|
||||
TaskID::ForNormalTask(JobID::Nil(), TaskID::Nil(), 1).Binary());
|
||||
ASSERT_TRUE(submitter.SubmitTask(task).ok());
|
||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
|
||||
|
||||
// Try non-force kill, worker returns normally
|
||||
ASSERT_TRUE(submitter.CancelTask(task, false).ok());
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||
ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(),
|
||||
task.TaskId().Binary());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
||||
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
|
||||
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTransportTest, TestKillPendingTask) {
|
||||
rpc::Address address;
|
||||
auto raylet_client = std::make_shared<MockRayletClient>();
|
||||
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||
auto factory = [&](const rpc::Address &addr) { return worker_client; };
|
||||
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store,
|
||||
task_finisher, ClientID::Nil(), kLongTimeout);
|
||||
std::unordered_map<std::string, double> empty_resources;
|
||||
ray::FunctionDescriptor empty_descriptor =
|
||||
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
|
||||
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
|
||||
|
||||
ASSERT_TRUE(submitter.SubmitTask(task).ok());
|
||||
ASSERT_TRUE(submitter.CancelTask(task, true).ok());
|
||||
ASSERT_EQ(worker_client->kill_requests.size(), 0);
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
||||
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
|
||||
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
|
||||
ASSERT_EQ(raylet_client->num_leases_canceled, 1);
|
||||
ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease());
|
||||
}
|
||||
|
||||
// Cancels a task while one of its arguments is still missing from the
// in-memory store, then makes the argument available afterwards. The task is
// expected to be reported as failed exactly once, with no kill request sent to
// a worker and no worker returned to or disconnected from the raylet.
TEST(DirectTaskTransportTest, TestKillResolvingTask) {
  rpc::Address caller_address;
  auto lease_client = std::make_shared<MockRayletClient>();
  auto mock_worker = std::make_shared<MockWorkerClient>();
  auto memory_store = std::make_shared<CoreWorkerMemoryStore>();
  // Every worker address maps to the same mock worker client.
  auto make_client = [&](const rpc::Address &addr) { return mock_worker; };
  auto finisher = std::make_shared<MockTaskFinisher>();
  CoreWorkerDirectTaskSubmitter submitter(caller_address, lease_client, make_client,
                                          nullptr, memory_store, finisher,
                                          ClientID::Nil(), kLongTimeout);

  // Build a task with a single direct-call object argument that has not been
  // put into the store yet.
  std::unordered_map<std::string, double> no_resources;
  ray::FunctionDescriptor blank_descriptor =
      ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
  TaskSpecification pending_task = BuildTaskSpec(no_resources, blank_descriptor);
  ObjectID dependency_id =
      ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
  pending_task.GetMutableMessage().add_args()->add_object_ids(dependency_id.Binary());

  ASSERT_TRUE(submitter.SubmitTask(pending_task).ok());
  ASSERT_EQ(finisher->num_inlined_dependencies, 0);

  // Cancel first; the dependency only shows up in the store afterwards.
  ASSERT_TRUE(submitter.CancelTask(pending_task, true).ok());
  auto dependency_value = GenerateRandomObject();
  ASSERT_TRUE(memory_store->Put(*dependency_value, dependency_id));

  // The mock worker client recorded no kill requests and no callbacks.
  ASSERT_EQ(mock_worker->kill_requests.size(), 0);
  ASSERT_EQ(mock_worker->callbacks.size(), 0);
  // The raylet was not asked to take back or disconnect any worker.
  ASSERT_EQ(lease_client->num_workers_returned, 0);
  ASSERT_EQ(lease_client->num_workers_disconnected, 0);
  // The task finished as a failure, not as a completion.
  ASSERT_EQ(finisher->num_tasks_complete, 0);
  ASSERT_EQ(finisher->num_tasks_failed, 1);
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
|
||||
@@ -54,6 +54,7 @@ class MockWorker {
|
||||
nullptr, // check_signals
|
||||
nullptr, // gc_collect
|
||||
nullptr, // get_lang_stack
|
||||
nullptr, // kill_main
|
||||
true, // ref_counting_enabled
|
||||
false, // is_local_mode
|
||||
1, // num_workers
|
||||
|
||||
@@ -187,6 +187,31 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
|
||||
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
|
||||
}
|
||||
|
||||
// Marks a pending task as canceled and then fails it with TASK_CANCELLED.
// Although the task was added with a retry budget, it must no longer be
// pending after that single failure, and its return object must hold an
// exception carrying the same error type.
TEST_F(TaskManagerTest, TestTaskKill) {
  rpc::Address caller_address;
  TaskID caller_id = TaskID::Nil();
  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);

  auto task_spec = CreateTaskHelper(1, {});
  ASSERT_FALSE(manager_.IsTaskPending(task_spec.TaskId()));

  // Register the task as pending with a non-zero number of retries.
  const int max_retries = 3;
  manager_.AddPendingTask(caller_id, caller_address, task_spec, "", max_retries);
  ASSERT_TRUE(manager_.IsTaskPending(task_spec.TaskId()));
  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);

  auto return_id = task_spec.ReturnId(0, TaskTransportType::DIRECT);
  WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));

  // Cancel, then report the failure with the cancellation error type.
  manager_.MarkTaskCanceled(task_spec.TaskId());
  const rpc::ErrorType expected_error = rpc::ErrorType::TASK_CANCELLED;
  manager_.PendingTaskFailed(task_spec.TaskId(), expected_error);
  ASSERT_FALSE(manager_.IsTaskPending(task_spec.TaskId()));

  // The return object is available in the store and is an exception whose
  // error type matches the one passed to PendingTaskFailed.
  std::vector<std::shared_ptr<RayObject>> results;
  RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results));
  ASSERT_EQ(results.size(), 1);
  rpc::ErrorType stored_error;
  ASSERT_TRUE(results[0]->IsException(&stored_error));
  ASSERT_EQ(stored_error, expected_error);
}
|
||||
|
||||
// Test to make sure that the task spec and dependencies for an object are
|
||||
// evicted when lineage pinning is disabled in the ReferenceCounter.
|
||||
TEST_F(TaskManagerTest, TestLineageEvicted) {
|
||||
|
||||
@@ -23,17 +23,17 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
||||
resolver_.ResolveDependencies(task_spec, [this, task_spec]() {
|
||||
RAY_LOG(DEBUG) << "Task dependencies resolved " << task_spec.TaskId();
|
||||
if (actor_create_callback_ && task_spec.IsActorCreationTask()) {
|
||||
// If gcs actor management is enabled, the actor creation task will be sent to gcs
|
||||
// server directly after the in-memory dependent objects are resolved.
|
||||
// For more details please see the protocol of actor management based on gcs.
|
||||
// If gcs actor management is enabled, the actor creation task will be sent to
|
||||
// gcs server directly after the in-memory dependent objects are resolved. For
|
||||
// more details please see the protocol of actor management based on gcs.
|
||||
// https://docs.google.com/document/d/1EAWide-jy05akJp6OMtDn58XOK7bUyruWMia4E-fV28/edit?usp=sharing
|
||||
auto actor_id = task_spec.ActorCreationId();
|
||||
auto task_id = task_spec.TaskId();
|
||||
RAY_LOG(INFO) << "Submitting actor creation task to GCS: " << actor_id;
|
||||
auto status =
|
||||
actor_create_callback_(task_spec, [this, actor_id, task_id](Status status) {
|
||||
// If GCS is failed, GcsRpcClient may receive IOError status but it will not
|
||||
// trigger this callback, because GcsRpcClient has retry logic at the
|
||||
// If GCS is failed, GcsRpcClient may receive IOError status but it will
|
||||
// not trigger this callback, because GcsRpcClient has retry logic at the
|
||||
// bottom. So if this callback is invoked with an error there must be
|
||||
// something wrong with the protocol of gcs-based actor management.
|
||||
// So just check `status.ok()` here.
|
||||
@@ -46,18 +46,33 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
||||
return;
|
||||
}
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
// Note that the dependencies in the task spec are mutated to only contain
|
||||
// plasma dependencies after ResolveDependencies finishes.
|
||||
const SchedulingKey scheduling_key(
|
||||
task_spec.GetSchedulingClass(), task_spec.GetDependencies(),
|
||||
task_spec.IsActorCreationTask() ? task_spec.ActorCreationId() : ActorID::Nil());
|
||||
auto it = task_queues_.find(scheduling_key);
|
||||
if (it == task_queues_.end()) {
|
||||
it = task_queues_.emplace(scheduling_key, std::deque<TaskSpecification>()).first;
|
||||
bool keep_executing = true;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
if (cancelled_tasks_.find(task_spec.TaskId()) != cancelled_tasks_.end()) {
|
||||
cancelled_tasks_.erase(task_spec.TaskId());
|
||||
keep_executing = false;
|
||||
}
|
||||
if (keep_executing) {
|
||||
// Note that the dependencies in the task spec are mutated to only contain
|
||||
// plasma dependencies after ResolveDependencies finishes.
|
||||
const SchedulingKey scheduling_key(
|
||||
task_spec.GetSchedulingClass(), task_spec.GetDependencies(),
|
||||
task_spec.IsActorCreationTask() ? task_spec.ActorCreationId()
|
||||
: ActorID::Nil());
|
||||
auto it = task_queues_.find(scheduling_key);
|
||||
if (it == task_queues_.end()) {
|
||||
it =
|
||||
task_queues_.emplace(scheduling_key, std::deque<TaskSpecification>()).first;
|
||||
}
|
||||
it->second.push_back(task_spec);
|
||||
RequestNewWorkerIfNeeded(scheduling_key);
|
||||
}
|
||||
}
|
||||
if (!keep_executing) {
|
||||
task_finisher_->PendingTaskFailed(task_spec.TaskId(),
|
||||
rpc::ErrorType::TASK_CANCELLED, nullptr);
|
||||
}
|
||||
it->second.push_back(task_spec);
|
||||
RequestNewWorkerIfNeeded(scheduling_key);
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
@@ -93,8 +108,9 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(
|
||||
worker_to_lease_client_.erase(addr);
|
||||
} else {
|
||||
auto &client = *client_cache_[addr];
|
||||
PushNormalTask(addr, client, scheduling_key, queue_entry->second.front(),
|
||||
assigned_resources);
|
||||
auto task_spec = queue_entry->second.front();
|
||||
PushNormalTask(addr, client, scheduling_key, task_spec, assigned_resources);
|
||||
executing_tasks_.emplace(task_spec.TaskId(), addr);
|
||||
queue_entry->second.pop_front();
|
||||
// Delete the queue if it's now empty. Note that the queue cannot already be empty
|
||||
// because this is the only place tasks are removed from it.
|
||||
@@ -248,6 +264,10 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
|
||||
std::move(request),
|
||||
[this, task_id, is_actor, is_actor_creation, scheduling_key, addr,
|
||||
assigned_resources](Status status, const rpc::PushTaskReply &reply) {
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
executing_tasks_.erase(task_id);
|
||||
}
|
||||
if (reply.worker_exiting()) {
|
||||
// The worker is draining and will shutdown after it is done. Don't return
|
||||
// it to the Raylet since that will kill it early.
|
||||
@@ -273,4 +293,78 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
/// Cancel a task previously handed to SubmitTask.
///
/// Depending on how far the task has progressed, one of three things happens:
///  - The task is queued and waiting for a worker lease: it is removed from its
///    scheduling queue and failed with TASK_CANCELLED. If the queue becomes
///    empty, the outstanding lease request is cancelled as well.
///  - The task is executing on a worker whose client is cached: a CancelTask
///    RPC is sent to that worker. If the RPC succeeds but the worker reports
///    that the attempt did not succeed, the cancellation is retried through
///    `cancel_retry_timer_` (when one was provided).
///  - Otherwise (the task's dependencies are still being resolved): the task id
///    is only recorded in `cancelled_tasks_`; the dependency-resolution callback
///    in SubmitTask fails the task when it observes the entry.
///
/// The call is a no-op if the task is already recorded as cancelled or if
/// `task_finisher_->MarkTaskCanceled` returns false.
///
/// \param[in] task_spec The task to cancel.
/// \param[in] force_kill Forwarded to the worker in the CancelTask request.
/// \return Always Status::OK().
Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
                                                 bool force_kill) {
  RAY_LOG(INFO) << "Killing task: " << task_spec.TaskId();
  const SchedulingKey scheduling_key(
      task_spec.GetSchedulingClass(), task_spec.GetDependencies(),
      task_spec.IsActorCreationTask() ? task_spec.ActorCreationId() : ActorID::Nil());
  std::shared_ptr<rpc::CoreWorkerClientInterface> client = nullptr;
  {
    absl::MutexLock lock(&mu_);
    if (cancelled_tasks_.find(task_spec.TaskId()) != cancelled_tasks_.end() ||
        !task_finisher_->MarkTaskCanceled(task_spec.TaskId())) {
      return Status::OK();
    }

    auto scheduled_tasks = task_queues_.find(scheduling_key);
    // This cancels tasks that have completed dependencies and are awaiting
    // a worker lease.
    if (scheduled_tasks != task_queues_.end()) {
      for (auto spec = scheduled_tasks->second.begin();
           spec != scheduled_tasks->second.end(); spec++) {
        if (spec->TaskId() == task_spec.TaskId()) {
          // Erasing invalidates `spec`; this is safe only because we return
          // before the loop advances it again.
          scheduled_tasks->second.erase(spec);

          // Preserve the invariant that a queue in the map is never empty.
          if (scheduled_tasks->second.empty()) {
            task_queues_.erase(scheduling_key);
            CancelWorkerLeaseIfNeeded(scheduling_key);
          }
          task_finisher_->PendingTaskFailed(task_spec.TaskId(),
                                            rpc::ErrorType::TASK_CANCELLED);
          return Status::OK();
        }
      }
    }

    // This will get removed either when the RPC call to cancel is returned
    // or when all dependencies are resolved.
    RAY_CHECK(cancelled_tasks_.emplace(task_spec.TaskId()).second);

    // Looks for an RPC handle for the worker executing the task.
    auto executing_entry = executing_tasks_.find(task_spec.TaskId());
    if (executing_entry != executing_tasks_.end()) {
      auto cached_client = client_cache_.find(executing_entry->second);
      if (cached_client != client_cache_.end()) {
        client = cached_client->second;
      }
    }
  }

  // This case is reached for tasks that have unresolved dependencies.
  if (client == nullptr) {
    return Status::OK();
  }

  auto request = rpc::CancelTaskRequest();
  request.set_intended_task_id(task_spec.TaskId().Binary());
  request.set_force_kill(force_kill);
  RAY_UNUSED(client->CancelTask(
      request, [this, task_spec, force_kill](const Status &status,
                                             const rpc::CancelTaskReply &reply) {
        absl::MutexLock lock(&mu_);
        cancelled_tasks_.erase(task_spec.TaskId());

        if (status.ok() && !reply.attempt_succeeded()) {
          if (cancel_retry_timer_.has_value()) {
            // Compare the expiry against the timer's own clock. The timer's
            // time points are only meaningful relative to that clock;
            // std::chrono::high_resolution_clock may alias a clock with a
            // different epoch, which would make this check always true and
            // re-arm the timer on every reply.
            if (cancel_retry_timer_->expiry() <=
                boost::asio::steady_timer::clock_type::now()) {
              cancel_retry_timer_->expires_after(boost::asio::chrono::milliseconds(
                  RayConfig::instance().cancellation_retry_ms()));
            }
            cancel_retry_timer_->async_wait(boost::bind(
                &CoreWorkerDirectTaskSubmitter::CancelTask, this, task_spec, force_kill));
          }
        }
        // Retry is not attempted if !status.ok() because force-kill may kill the worker
        // before the reply is sent.
      }));
  return Status::OK();
}
|
||||
}; // namespace ray
|
||||
|
||||
@@ -56,7 +56,8 @@ class CoreWorkerDirectTaskSubmitter {
|
||||
std::shared_ptr<TaskFinisherInterface> task_finisher, ClientID local_raylet_id,
|
||||
int64_t lease_timeout_ms,
|
||||
std::function<Status(const TaskSpecification &, const gcs::StatusCallback &)>
|
||||
actor_create_callback = nullptr)
|
||||
actor_create_callback = nullptr,
|
||||
absl::optional<boost::asio::steady_timer> cancel_timer = absl::nullopt)
|
||||
: rpc_address_(rpc_address),
|
||||
local_lease_client_(lease_client),
|
||||
client_factory_(client_factory),
|
||||
@@ -65,13 +66,20 @@ class CoreWorkerDirectTaskSubmitter {
|
||||
task_finisher_(task_finisher),
|
||||
lease_timeout_ms_(lease_timeout_ms),
|
||||
local_raylet_id_(local_raylet_id),
|
||||
actor_create_callback_(std::move(actor_create_callback)) {}
|
||||
actor_create_callback_(std::move(actor_create_callback)),
|
||||
cancel_retry_timer_(std::move(cancel_timer)) {}
|
||||
|
||||
/// Schedule a task for direct submission to a worker.
|
||||
///
|
||||
/// \param[in] task_spec The task to schedule.
|
||||
Status SubmitTask(TaskSpecification task_spec);
|
||||
|
||||
/// Either remove a pending task or send an RPC to kill a running task
|
||||
///
|
||||
/// \param[in] task_spec The task to kill.
|
||||
/// \param[in] force_kill Whether to kill the worker executing the task.
|
||||
Status CancelTask(TaskSpecification task_spec, bool force_kill);
|
||||
|
||||
private:
|
||||
/// Schedule more work onto an idle worker or return it back to the raylet if
|
||||
/// no more tasks are queued for submission. If an error was encountered
|
||||
@@ -180,6 +188,15 @@ class CoreWorkerDirectTaskSubmitter {
|
||||
// Invariant: if a queue is in this map, it has at least one task.
|
||||
absl::flat_hash_map<SchedulingKey, std::deque<TaskSpecification>> task_queues_
|
||||
GUARDED_BY(mu_);
|
||||
|
||||
// Tasks that were cancelled while being resolved.
|
||||
absl::flat_hash_set<TaskID> cancelled_tasks_ GUARDED_BY(mu_);
|
||||
|
||||
// Keeps track of where currently executing tasks are being run.
|
||||
absl::flat_hash_map<TaskID, rpc::WorkerAddress> executing_tasks_ GUARDED_BY(mu_);
|
||||
|
||||
// Retries cancellation requests if they were not successful.
|
||||
absl::optional<boost::asio::steady_timer> cancel_retry_timer_;
|
||||
};
|
||||
|
||||
}; // namespace ray
|
||||
|
||||
@@ -111,7 +111,8 @@ message PushTaskRequest {
|
||||
// caller, which might happen theoretically when network has issues.
|
||||
// - For an actor, this is set to the timestamp when the actor is created,
|
||||
// so it can be used to differentiate which is the new reconstructed actor.
|
||||
// - For a non-actor task, it's set to the timestamp the task starts execution.
|
||||
// - For a non-actor task, it's set to the timestamp the task starts
|
||||
// execution.
|
||||
int64 caller_version = 7;
|
||||
}
|
||||
|
||||
@@ -121,9 +122,11 @@ message PushTaskReply {
|
||||
// Set to true if the worker will be exiting.
|
||||
bool worker_exiting = 2;
|
||||
// The references that the worker borrowed during the task execution. A
|
||||
// borrower is a process that is currently using the object ID, in one of 3 ways:
|
||||
// borrower is a process that is currently using the object ID, in one of 3
|
||||
// ways:
|
||||
// 1. Has an ObjectID copy in Python.
|
||||
// 2. Has submitted a task that depends on the object and that is still pending.
|
||||
// 2. Has submitted a task that depends on the object and that is still
|
||||
// pending.
|
||||
// 3. Owns another object that is in scope and whose value contains the
|
||||
// ObjectID.
|
||||
// This list includes the reference counts for any IDs that were passed to
|
||||
@@ -156,9 +159,7 @@ message GetObjectStatusRequest {
|
||||
}
|
||||
|
||||
message GetObjectStatusReply {
|
||||
enum ObjectStatus {
|
||||
CREATED = 0;
|
||||
}
|
||||
enum ObjectStatus { CREATED = 0; }
|
||||
ObjectStatus status = 1;
|
||||
}
|
||||
|
||||
@@ -184,6 +185,18 @@ message KillActorRequest {
|
||||
message KillActorReply {
|
||||
}
|
||||
|
||||
// Request for the CoreWorkerService.CancelTask RPC, asking a worker to cancel
// a task.
message CancelTaskRequest {
|
||||
// Binary ID of the task that should be cancelled.
|
||||
bytes intended_task_id = 1;
|
||||
// Whether to kill the worker executing the task.
|
||||
bool force_kill = 2;
|
||||
}
|
||||
|
||||
// Reply for the CoreWorkerService.CancelTask RPC.
message CancelTaskReply {
|
||||
// Whether the requested task is the currently running task.
// When the RPC itself succeeds but this is false, the task submitter
// schedules a retry of the cancellation if it has a retry timer.
|
||||
bool attempt_succeeded = 1;
|
||||
}
|
||||
|
||||
message GetCoreWorkerStatsRequest {
|
||||
// The ID of the worker this message is intended for.
|
||||
bytes intended_worker_id = 1;
|
||||
@@ -281,6 +294,8 @@ service CoreWorkerService {
|
||||
returns (WaitForObjectEvictionReply);
|
||||
// Request that the worker shut down without completing outstanding work.
|
||||
rpc KillActor(KillActorRequest) returns (KillActorReply);
|
||||
// Request that a worker cancels a task.
|
||||
rpc CancelTask(CancelTaskRequest) returns (CancelTaskReply);
|
||||
// Get metrics from core workers.
|
||||
rpc GetCoreWorkerStats(GetCoreWorkerStatsRequest) returns (GetCoreWorkerStatsReply);
|
||||
// Wait for a borrower to finish using an object. Sent by the object's owner.
|
||||
|
||||
@@ -340,4 +340,7 @@ enum ErrorType {
|
||||
// exposed to user code; it is only used internally to indicate the result of a direct
|
||||
// call has been placed in plasma.
|
||||
OBJECT_IN_PLASMA = 4;
|
||||
|
||||
// Indicates that a task has been cancelled.
|
||||
TASK_CANCELLED = 5;
|
||||
}
|
||||
|
||||
@@ -157,6 +157,11 @@ class CoreWorkerClientInterface {
|
||||
return Status::NotImplemented("");
|
||||
}
|
||||
|
||||
/// Ask the remote worker to cancel a task.
///
/// This default implementation does not send anything and reports
/// Status::NotImplemented; concrete clients are expected to override it.
///
/// \param[in] request The request message identifying the task to cancel.
/// \param[in] callback The callback for the reply; not invoked by this default.
/// \return Status::NotImplemented in this default implementation.
virtual ray::Status CancelTask(const CancelTaskRequest &request,
|
||||
const ClientCallback<CancelTaskReply> &callback) {
|
||||
return Status::NotImplemented("");
|
||||
}
|
||||
|
||||
virtual ray::Status GetCoreWorkerStats(
|
||||
const GetCoreWorkerStatsRequest &request,
|
||||
const ClientCallback<GetCoreWorkerStatsReply> &callback) {
|
||||
@@ -210,6 +215,8 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
|
||||
|
||||
RPC_CLIENT_METHOD(CoreWorkerService, KillActor, grpc_client_, override)
|
||||
|
||||
RPC_CLIENT_METHOD(CoreWorkerService, CancelTask, grpc_client_, override)
|
||||
|
||||
RPC_CLIENT_METHOD(CoreWorkerService, WaitForObjectEviction, grpc_client_, override)
|
||||
|
||||
RPC_CLIENT_METHOD(CoreWorkerService, GetCoreWorkerStats, grpc_client_, override)
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
|
||||
#include "ray/rpc/grpc_server.h"
|
||||
#include "ray/rpc/server_call.h"
|
||||
|
||||
#include "src/ray/protobuf/core_worker.grpc.pb.h"
|
||||
#include "src/ray/protobuf/core_worker.pb.h"
|
||||
|
||||
@@ -36,6 +35,7 @@ namespace rpc {
|
||||
RPC_SERVICE_HANDLER(CoreWorkerService, WaitForObjectEviction) \
|
||||
RPC_SERVICE_HANDLER(CoreWorkerService, WaitForRefRemoved) \
|
||||
RPC_SERVICE_HANDLER(CoreWorkerService, KillActor) \
|
||||
RPC_SERVICE_HANDLER(CoreWorkerService, CancelTask) \
|
||||
RPC_SERVICE_HANDLER(CoreWorkerService, GetCoreWorkerStats) \
|
||||
RPC_SERVICE_HANDLER(CoreWorkerService, LocalGC) \
|
||||
RPC_SERVICE_HANDLER(CoreWorkerService, PlasmaObjectReady)
|
||||
@@ -48,6 +48,7 @@ namespace rpc {
|
||||
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForObjectEviction) \
|
||||
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForRefRemoved) \
|
||||
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(KillActor) \
|
||||
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(CancelTask) \
|
||||
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetCoreWorkerStats) \
|
||||
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(LocalGC) \
|
||||
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PlasmaObjectReady)
|
||||
|
||||
@@ -313,6 +313,7 @@ class StreamingWorker {
|
||||
nullptr, // check_signals
|
||||
nullptr, // gc_collect
|
||||
nullptr, // get_lang_stack
|
||||
nullptr, // kill_main
|
||||
true, // ref_counting_enabled
|
||||
false, // is_local_mode
|
||||
1, // num_workers
|
||||
|
||||
@@ -326,6 +326,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
|
||||
nullptr, // check_signals
|
||||
nullptr, // gc_collect
|
||||
nullptr, // get_lang_stack
|
||||
nullptr, // kill_main
|
||||
true, // ref_counting_enabled
|
||||
false, // is_local_mode
|
||||
1, // num_workers
|
||||
|
||||
Reference in New Issue
Block a user