TaskCancellation (#7669)

* Smol comment

* WIP, not passing ray.init

* Fixed small problem

* wip

* Pseudo interrupt things

* Basic prototype operational

* correct proc title

* Mostly done

* Cleanup

* cleaner raylet error

* Cleaning up a few loose ends

* Fixing Race Conds

* Prelim testing

* Fixing comments and adding second_check for kill

* Working_new_impl

* demo_ready

* Fixing my english

* Fixing a few problems

* Small problems

* Cleaning up

* Response to changes

* Fixing error passing

* Merged to master

* fixing lock

* Cleaning up print statements

* Format

* Fixing Unit test build failure

* mock_worker fix

* java_fix

* Canel

* Switching to Cancel

* Responding to Review

* FixFormatting

* Lease cancellation

* Final comments?

* Moving exist check to CoreWorker

* Fix Actor Transport Test

* Fixing task manager test

* changing clock repr

* Fix build

* fix white space

* lint fix

* Updating to medium size

* Fixing Java test compilation issue

* lengthen bad timeouts
This commit is contained in:
ijrsvt
2020-04-25 16:04:52 -07:00
committed by GitHub
parent 9dd3490c38
commit 69ff7e3e35
29 changed files with 731 additions and 38 deletions
+2
View File
@@ -66,6 +66,7 @@ from ray.worker import (
LOCAL_MODE,
SCRIPT_MODE,
WORKER_MODE,
cancel,
connect,
disconnect,
get,
@@ -113,6 +114,7 @@ __all__ = [
"_config",
"_get_runtime_context",
"actor",
"cancel",
"connect",
"disconnect",
"get",
+39 -6
View File
@@ -17,6 +17,8 @@ import logging
import os
import pickle
import sys
import _thread
import setproctitle
from libc.stdint cimport (
int32_t,
@@ -90,6 +92,7 @@ from ray.exceptions import (
RayTaskError,
ObjectStoreFullError,
RayTimeoutError,
RayCancellationError
)
from ray.utils import decode
import gc
@@ -453,13 +456,23 @@ cdef execute_task(
class_name, repr(args), repr(kwargs))
core_worker.set_actor_title(actor_title.encode("utf-8"))
# Execute the task.
with ray.worker._changeproctitle(title, next_title):
with core_worker.profile_event(b"task:execute"):
task_exception = True
outputs = function_executor(*args, **kwargs)
with core_worker.profile_event(b"task:execute"):
task_exception = True
try:
with ray.worker._changeproctitle(title, next_title):
outputs = function_executor(*args, **kwargs)
task_exception = False
if c_return_ids.size() == 1:
outputs = (outputs,)
except KeyboardInterrupt as e:
raise RayCancellationError(
core_worker.get_current_task_id())
if c_return_ids.size() == 1:
outputs = (outputs,)
# Check for a cancellation that was called when the function
# was exiting and was raised after the except block.
if not check_signals().ok():
task_exception = True
raise RayCancellationError(
core_worker.get_current_task_id())
# Store the outputs in the object store.
with core_worker.profile_event(b"task:store_outputs"):
core_worker.store_task_outputs(
@@ -551,6 +564,14 @@ cdef void async_plasma_callback(CObjectID object_id,
event_handler._loop.call_soon_threadsafe(
event_handler._complete_future, obj_id)
cdef c_bool kill_main_task() nogil:
    # Installed as the core worker's `kill_main` option. Interrupts the task
    # running on the main Python thread and reports whether an interrupt was
    # actually delivered.
    with gil:
        # A worker whose process title is "ray::IDLE" is not running a task,
        # so there is nothing to interrupt.
        if setproctitle.getproctitle() == "ray::IDLE":
            return False
        # Ask the interpreter to raise KeyboardInterrupt in the main thread.
        _thread.interrupt_main()
        return True
cdef CRayStatus check_signals() nogil:
with gil:
try:
@@ -658,6 +679,7 @@ cdef class CoreWorker:
options.ref_counting_enabled = True
options.is_local_mode = local_mode
options.num_workers = 1
options.kill_main = kill_main_task
CCoreWorkerProcess.Initialize(options)
@@ -953,6 +975,17 @@ cdef class CoreWorker:
check_status(CCoreWorkerProcess.GetCoreWorker().KillActor(
c_actor_id, True, no_reconstruction))
def cancel_task(self, ObjectID object_id, c_bool force_kill):
    """Ask the core worker to cancel the task that returns `object_id`.

    Args:
        object_id (ObjectID): Return object of the task to cancel.
        force_kill (bool): Forwarded unchanged to the core worker's
            CancelTask call.

    Raises:
        TypeError: If the core worker returns a non-OK status; the status
            message is used as the exception text.
    """
    cdef:
        CObjectID native_id = object_id.native()
        CRayStatus outcome = CRayStatus.OK()

    outcome = CCoreWorkerProcess.GetCoreWorker().CancelTask(
        native_id, force_kill)

    if outcome.ok():
        return
    raise TypeError(outcome.message().decode())
def resource_ids(self):
cdef:
ResourceMappingType resource_mapping = (
+17
View File
@@ -16,6 +16,23 @@ class RayConnectionError(RayError):
pass
class RayCancellationError(RayError):
    """Raised when this task is cancelled.

    Attributes:
        task_id (TaskID): The TaskID of the function that was directly
            cancelled. It is None when the error is created without a task
            ID, in which case the message cannot name a specific task.
    """

    def __init__(self, task_id=None):
        self.task_id = task_id

    def __str__(self):
        # Without a task ID we cannot tell whether this task or one of its
        # dependencies was the one cancelled, so the message stays generic.
        if self.task_id is None:
            return "This task or its dependency was cancelled"
        return "Task: " + str(self.task_id) + " was cancelled"
class RayTaskError(RayError):
"""Indicates that a task threw an exception during execution.
+2
View File
@@ -97,6 +97,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
CRayStatus KillActor(
const CActorID &actor_id, c_bool force_kill,
c_bool no_reconstruction)
CRayStatus CancelTask(const CObjectID &object_id, c_bool force_kill)
unique_ptr[CProfileEvent] CreateProfileEvent(
const c_string &event_type)
@@ -214,6 +215,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
c_bool ref_counting_enabled
c_bool is_local_mode
int num_workers
(c_bool() nogil) kill_main
CCoreWorkerOptions()
cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess":
+2
View File
@@ -122,6 +122,8 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
CTaskID ForNormalTask(CJobID job_id, CTaskID parent_task_id,
int64_t parent_task_counter)
CActorID ActorId() const
cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]):
@staticmethod
+3
View File
@@ -220,6 +220,9 @@ cdef class TaskID(BaseID):
def is_nil(self):
return self.data.IsNil()
def actor_id(self):
return ActorID(self.data.ActorId().Binary())
cdef size_t hash(self):
return self.data.Hash()
+3
View File
@@ -12,6 +12,7 @@ from ray.exceptions import (
PlasmaObjectNotAvailable,
RayTaskError,
RayActorError,
RayCancellationError,
RayWorkerError,
UnreconstructableError,
)
@@ -279,6 +280,8 @@ class SerializationContext:
return RayWorkerError()
elif error_type == ErrorType.Value("ACTOR_DIED"):
return RayActorError()
elif error_type == ErrorType.Value("TASK_CANCELLED"):
return RayCancellationError()
elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"):
return UnreconstructableError(ray.ObjectID(object_id.binary()))
else:
+8
View File
@@ -414,3 +414,11 @@ py_test(
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
py_test(
name = "test_cancel",
size = "medium",
srcs = ["test_cancel.py"],
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
+232
View File
@@ -0,0 +1,232 @@
import pytest
import ray
import random
import sys
import time
from ray.exceptions import RayTaskError, RayTimeoutError, \
RayCancellationError, RayWorkerError
from ray.test_utils import SignalActor
def valid_exceptions(use_force):
    """Return the exception types a cancelled task's `ray.get` may raise.

    Args:
        use_force (bool): Whether the cancellation was issued with force.
            When true, RayWorkerError is accepted as well.

    Returns:
        A tuple of exception classes usable with `pytest.raises` and
        `isinstance`.
    """
    accepted = (RayTaskError, RayCancellationError)
    if not use_force:
        return accepted
    return accepted + (RayWorkerError, )
@pytest.mark.parametrize("use_force", [True, False])
def test_cancel_chain(ray_start_regular, use_force):
    """Cancel the head, then the middle, of a four-task dependency chain."""
    signaler = SignalActor.remote()

    @ray.remote
    def wait_for(t):
        return ray.get(t[0])

    def submit_chain(signal_actor):
        # Four tasks; each one blocks in ray.get on the previous link, and
        # the first blocks on the signal actor.
        links = [wait_for.remote([signal_actor.wait.remote()])]
        for _ in range(3):
            links.append(wait_for.remote([links[-1]]))
        return links

    expected_errors = valid_exceptions(use_force)

    # Cancelling the head makes every task in the chain fail.
    obj1, obj2, obj3, obj4 = submit_chain(signaler)
    ready, _ = ray.wait([obj1], timeout=.1)
    assert len(ready) == 0
    ray.cancel(obj1, use_force)
    for ref in (obj1, obj2, obj3, obj4):
        with pytest.raises(expected_errors):
            ray.get(ref)

    # Cancelling the third link only fails the third and fourth links.
    signaler2 = SignalActor.remote()
    obj1, obj2, obj3, obj4 = submit_chain(signaler2)
    ready, _ = ray.wait([obj3], timeout=.1)
    assert len(ready) == 0
    ray.cancel(obj3, use_force)
    for ref in (obj3, obj4):
        with pytest.raises(expected_errors):
            ray.get(ref)

    # The first two links are still blocked on the signal, not cancelled.
    for ref in (obj1, obj2):
        with pytest.raises(RayTimeoutError):
            ray.get(ref, timeout=.1)

    signaler2.send.remote()
    ray.get(obj1, timeout=10)
@pytest.mark.parametrize("use_force", [True, False])
def test_cancel_multiple_dependents(ray_start_regular, use_force):
    """Check cancellation with several tasks depending on one head task.

    Phase 1 cancels the head and expects every dependent to fail.
    Phase 2 cancels each dependent individually and expects the head they
    depend on to still complete once it is signalled.
    """
    signaler = SignalActor.remote()

    @ray.remote
    def wait_for(t):
        return ray.get(t[0])

    # Phase 1: cancelling the head fails all of its dependents.
    head = wait_for.remote([signaler.wait.remote()])
    deps = []
    for _ in range(3):
        deps.append(wait_for.remote([head]))

    assert len(ray.wait([head], timeout=.1)[0]) == 0
    ray.cancel(head, use_force)
    for d in deps:
        with pytest.raises(valid_exceptions(use_force)):
            ray.get(d)

    # Phase 2: the dependents must wait on the live `head2` (not on the
    # already-cancelled `head`), otherwise they fail on their own and the
    # cancel calls below are never actually exercised.
    head2 = wait_for.remote([signaler.wait.remote()])
    deps2 = []
    for _ in range(3):
        deps2.append(wait_for.remote([head2]))

    for d in deps2:
        ray.cancel(d, use_force)

    for d in deps2:
        with pytest.raises(valid_exceptions(use_force)):
            ray.get(d)

    # Cancelling dependents must not affect the task they depend on.
    signaler.send.remote()
    ray.get(head2, timeout=1)
@pytest.mark.parametrize("use_force", [True, False])
def test_single_cpu_cancel(shutdown_only, use_force):
    """Cancel the tail and then the head of a chain on a one-CPU cluster."""
    ray.init(num_cpus=1)
    signaler = SignalActor.remote()

    @ray.remote
    def wait_for(t):
        return ray.get(t[0])

    expected_errors = valid_exceptions(use_force)

    obj1 = wait_for.remote([signaler.wait.remote()])
    obj2 = wait_for.remote([obj1])
    obj3 = wait_for.remote([obj2])
    # An unrelated task that must survive the cancellations below.
    indep = wait_for.remote([signaler.wait.remote()])

    ready, _ = ray.wait([obj3], timeout=.1)
    assert len(ready) == 0

    # Cancel the tail of the chain first.
    ray.cancel(obj3, use_force)
    with pytest.raises(expected_errors):
        ray.get(obj3, 10)

    # Cancelling the head also fails the link that depends on it.
    ray.cancel(obj1, use_force)
    for ref in (obj1, obj2):
        with pytest.raises(expected_errors):
            ray.get(ref)

    signaler.send.remote()
    ray.get(indep)
@pytest.mark.parametrize("use_force", [True, False])
def test_comprehensive(ray_start_regular, use_force):
    """Cancel one input of a two-input task and check that errors propagate."""
    signaler = SignalActor.remote()

    @ray.remote
    def wait_for(t):
        ray.get(t[0])
        return "Result"

    @ray.remote
    def combine(a, b):
        return str(a) + str(b)

    expected_errors = valid_exceptions(use_force)

    a = wait_for.remote([signaler.wait.remote()])
    b = wait_for.remote([signaler.wait.remote()])
    combo = combine.remote(a, b)
    a2 = wait_for.remote([a])

    # Nothing can finish until the signal is sent.
    ready, _ = ray.wait([a, b, a2, combo], timeout=1)
    assert len(ready) == 0

    ray.cancel(a, use_force)
    for ref in (a, a2):
        with pytest.raises(expected_errors):
            ray.get(ref, 10)

    # `b` can now complete, but `combo` still depends on the cancelled `a`.
    signaler.send.remote()
    with pytest.raises(expected_errors):
        ray.get(combo, 10)
@pytest.mark.parametrize("use_force", [True, False])
def test_stress(shutdown_only, use_force):
    """Randomly cancel a mix of never-ending and instant tasks on one CPU."""
    ray.init(num_cpus=1)

    @ray.remote
    def infinite_sleep(y):
        # Loops forever when `y` is truthy, returns immediately otherwise.
        if y:
            while True:
                time.sleep(1 / 10)

    expected_errors = valid_exceptions(use_force)

    first = infinite_sleep.remote(True)

    sleep_or_no = [random.randint(0, 1) for _ in range(100)]
    tasks = [infinite_sleep.remote(flag) for flag in sleep_or_no]
    cancelled = set()

    # Cancel roughly half of the tasks, whether or not they would sleep.
    for task in tasks:
        if random.random() > 0.5:
            ray.cancel(task, use_force)
            cancelled.add(task)

    ray.cancel(first, use_force)
    cancelled.add(first)

    for ref in cancelled:
        with pytest.raises(expected_errors):
            ray.get(ref, 10)

    # Every remaining sleeper has to be cancelled before it can be fetched;
    # uncancelled non-sleepers must return normally.
    for task, sleeps in zip(tasks, sleep_or_no):
        if sleeps:
            ray.cancel(task, use_force)
            cancelled.add(task)
        if task not in cancelled:
            ray.get(task)
            continue
        with pytest.raises(expected_errors):
            ray.get(task, 10)
@pytest.mark.parametrize("use_force", [True, False])
def test_fast(shutdown_only, use_force):
    """Cancel tasks that finish quickly, and a sample of many queued tasks."""
    ray.init(num_cpus=2)

    @ray.remote
    def fast(y):
        return y

    signaler = SignalActor.remote()
    ids = []

    # Cancel each short task right after submitting it.
    for _ in range(100):
        ref = fast.remote("a")
        ray.cancel(ref)
        ids.append(ref)

    @ray.remote
    def wait_for(y):
        return y

    # 5000 tasks that all take the same pending signal as an argument.
    sig = signaler.wait.remote()
    ids.extend(wait_for.remote(sig) for _ in range(5000))

    # Cancel about 5% of the signal-dependent tasks.
    for ref in ids[100:5100]:
        if random.random() > 0.95:
            ray.cancel(ref)

    signaler.send.remote()

    # Each task either returns or fails with one of the accepted errors.
    accepted = valid_exceptions(use_force)
    for obj_id in ids:
        try:
            ray.get(obj_id, 10)
        except Exception as e:
            assert isinstance(e, accepted)
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
+22 -1
View File
@@ -1661,12 +1661,33 @@ def kill(actor):
if not isinstance(actor, ray.actor.ActorHandle):
raise ValueError("ray.kill() only supported for actors. "
"Got: {}.".format(type(actor)))
worker = ray.worker.global_worker
worker.check_connected()
worker.core_worker.kill_actor(actor._ray_actor_id, False)
def cancel(object_id, force=False):
    """Cancel the task that produces the given object.

    Only non-actor tasks that were submitted by the calling worker can be
    cancelled. Calling ray.get on the object of a cancelled task raises an
    error instead of returning a value.

    Args:
        object_id (ObjectID): ObjectID returned by the task that should be
            cancelled.
        force (bool): If False, a running task is interrupted inside its
            worker. If True, the worker process running the task is made to
            exit instead.

    Raises:
        TypeError: If object_id is not a ray.ObjectID, or if the core worker
            rejects the request (for example for an actor task or a task
            that was not submitted locally).
    """
    worker = ray.worker.global_worker
    worker.check_connected()

    if not isinstance(object_id, ray.ObjectID):
        raise TypeError(
            "ray.cancel() only supported for non-actor object IDs. "
            "Got: {}.".format(type(object_id)))
    return worker.core_worker.cancel_task(object_id, force)
def _mode(worker=global_worker):
"""This is a wrapper around worker.mode.
+3
View File
@@ -278,6 +278,9 @@ RAY_CONFIG(uint32_t, object_store_full_initial_delay_ms, 1000)
/// Duration to wait between retries for failed tasks.
RAY_CONFIG(uint32_t, task_retry_delay_ms, 5000)
/// Duration to wait between retrying to kill a task.
RAY_CONFIG(uint32_t, cancellation_retry_ms, 2000)
/// Whether to enable gcs service.
/// RAY_GCS_SERVICE_ENABLED is an env variable which only set in ci job.
/// If the value of RAY_GCS_SERVICE_ENABLED is false, we will disable gcs service,
+49 -2
View File
@@ -429,7 +429,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
rpc_address_, local_raylet_client_, client_factory, raylet_client_factory,
memory_store_, task_manager_, local_raylet_id,
RayConfig::instance().worker_lease_timeout_milliseconds(),
std::move(actor_create_callback)));
std::move(actor_create_callback), boost::asio::steady_timer(io_service_)));
future_resolver_.reset(new FutureResolver(memory_store_, client_factory));
// Unfortunately the raylet client has to be constructed after the receivers.
if (direct_task_receiver_ != nullptr) {
@@ -590,10 +590,10 @@ const WorkerID &CoreWorker::GetWorkerID() const { return worker_context_.GetWork
void CoreWorker::SetCurrentTaskId(const TaskID &task_id) {
worker_context_.SetCurrentTaskId(task_id);
main_thread_task_id_ = task_id;
bool not_actor_task = false;
{
absl::MutexLock lock(&mutex_);
main_thread_task_id_ = task_id;
not_actor_task = actor_id_.IsNil();
}
if (not_actor_task && task_id.IsNil()) {
@@ -1056,6 +1056,7 @@ TaskID CoreWorker::GetCallerId() const {
if (!actor_id.IsNil()) {
caller_id = TaskID::ForActorCreationTask(actor_id);
} else {
absl::MutexLock lock(&mutex_);
caller_id = main_thread_task_id_;
}
return caller_id;
@@ -1203,6 +1204,25 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
return status;
}
Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill) {
  // Reject objects that were not created by a task, and objects whose creating
  // task belongs to an actor this worker holds a handle for.
  ActorHandle *actor_handle = nullptr;
  const bool unsupported =
      !object_id.CreatedByTask() ||
      GetActorHandle(object_id.TaskId().ActorId(), &actor_handle).ok();
  if (unsupported) {
    return Status::Invalid("Actor task cancellation is not supported.");
  }

  // Only the worker recorded as the object's owner may cancel its task.
  rpc::Address owner_address;
  const bool has_owner =
      reference_counter_->GetOwner(object_id, nullptr, &owner_address);
  if (!has_owner ||
      owner_address.SerializeAsString() != rpc_address_.SerializeAsString()) {
    return Status::Invalid("Task is not locally submitted.");
  }

  // If the task manager no longer has a spec for the task, or it is an actor
  // creation task, there is nothing to cancel and the call succeeds.
  auto pending_spec = task_manager_->GetTaskSpec(object_id.TaskId());
  if (!pending_spec.has_value() || pending_spec.value().IsActorCreationTask()) {
    return Status::OK();
  }
  return direct_task_submitter_->CancelTask(pending_spec.value(), force_kill);
}
Status CoreWorker::KillActor(const ActorID &actor_id, bool force_kill,
bool no_reconstruction) {
ActorHandle *actor_handle = nullptr;
@@ -1735,6 +1755,33 @@ void CoreWorker::HandleWaitForRefRemoved(const rpc::WaitForRefRemovedRequest &re
owner_address, ref_removed_callback);
}
void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request,
                                  rpc::CancelTaskReply *reply,
                                  rpc::SendReplyCallback send_reply_callback) {
  // NOTE(review): mutex_ stays held for the whole handler, including the
  // kill_main callback and the reply — confirm this cannot contend with the
  // main thread.
  absl::MutexLock lock(&mutex_);
  const TaskID intended_task_id = TaskID::FromBinary(request.intended_task_id());
  const bool force_kill = request.force_kill();

  // The request only applies if the main thread is running the intended task.
  bool success = (main_thread_task_id_ == intended_task_id);

  // Try non-force kill
  if (success && !force_kill) {
    RAY_LOG(INFO) << "Interrupting a running task " << main_thread_task_id_;
    success = options_.kill_main();
  }

  reply->set_attempt_succeeded(success);
  send_reply_callback(Status::OK(), nullptr, nullptr);

  if (!(success && force_kill)) {
    return;
  }

  // Do force kill after reply callback sent
  RAY_LOG(INFO) << "Force killing a worker running " << main_thread_task_id_;
  RAY_IGNORE_EXPR(local_raylet_client_->Disconnect());
  if (options_.log_dir != "") {
    RayLog::ShutDownRayLog();
  }
  exit(1);
}
void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request,
rpc::KillActorReply *reply,
rpc::SendReplyCallback send_reply_callback) {
+15 -2
View File
@@ -106,6 +106,8 @@ struct CoreWorkerOptions {
std::function<void()> gc_collect;
/// Language worker callback to get the current call stack.
std::function<void(std::string *)> get_lang_stack;
/// Language worker callback that tries to interrupt the task currently running
/// on the main Python thread. Returns whether an interrupt was delivered.
std::function<bool()> kill_main;
/// Whether to enable object ref counting.
bool ref_counting_enabled;
/// Is local mode being used.
@@ -501,7 +503,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
/// For actors, this is the current actor ID. To make sure that all caller
/// IDs have the same type, we embed the actor ID in a TaskID with the rest
/// of the bytes zeroed out.
TaskID GetCallerId() const;
TaskID GetCallerId() const LOCKS_EXCLUDED(mutex_);
/// Push an error to the relevant driver.
///
@@ -586,6 +588,12 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
/// \param[out] Status
Status KillActor(const ActorID &actor_id, bool force_kill, bool no_reconstruction);
/// Stops the task associated with the given Object ID.
///
/// \param[in] object_id of the task to kill (must be a Non-Actor task)
/// \param[in] force_kill Whether to force kill a task by killing the worker.
/// \param[out] Status
Status CancelTask(const ObjectID &object_id, bool force_kill);
/// Decrease the reference count for this actor. Should be called by the
/// language frontend when a reference to the ActorHandle destroyed.
///
@@ -694,6 +702,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
void HandleKillActor(const rpc::KillActorRequest &request, rpc::KillActorReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
/// Implements gRPC server handler.
///
/// Interrupts (or, when force_kill is set, exits the worker running) the task
/// named in the request if it is the one executing on the main thread.
void HandleCancelTask(const rpc::CancelTaskRequest &request,
                      rpc::CancelTaskReply *reply,
                      rpc::SendReplyCallback send_reply_callback) override;
/// Implements gRPC server handler.
void HandlePlasmaObjectReady(const rpc::PlasmaObjectReadyRequest &request,
rpc::PlasmaObjectReadyReply *reply,
@@ -893,7 +906,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
/// The ID of the current task being executed by the main thread. If there
/// are multiple threads, they will have a thread-local task ID stored in the
/// worker context.
TaskID main_thread_task_id_;
TaskID main_thread_task_id_ GUARDED_BY(mutex_);
// Flag indicating whether this worker has been shut down.
bool shutdown_ = false;
@@ -122,6 +122,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
nullptr, // check_signals
nullptr, // gc_collect
nullptr, // get_lang_stack
nullptr, // kill_main
false, // ref_counting_enabled
false, // is_local_mode
static_cast<int>(numWorkersPerProcess), // num_workers
+18
View File
@@ -412,6 +412,15 @@ void TaskManager::RemoveLineageReference(const ObjectID &object_id,
}
}
bool TaskManager::MarkTaskCanceled(const TaskID &task_id) {
  absl::MutexLock lock(&mu_);
  const auto entry = submissible_tasks_.find(task_id);
  if (entry == submissible_tasks_.end()) {
    // Unknown task: nothing to mark.
    return false;
  }
  // A canceled task must not be retried when it fails.
  entry->second.num_retries_left = 0;
  return true;
}
void TaskManager::MarkPendingTaskFailed(const TaskID &task_id,
const TaskSpecification &spec,
rpc::ErrorType error_type) {
@@ -432,4 +441,13 @@ void TaskManager::MarkPendingTaskFailed(const TaskID &task_id,
}
}
absl::optional<TaskSpecification> TaskManager::GetTaskSpec(const TaskID &task_id) const {
  absl::MutexLock lock(&mu_);
  const auto entry = submissible_tasks_.find(task_id);
  // Return a copy of the stored spec, or an empty optional for unknown tasks.
  if (entry != submissible_tasks_.end()) {
    return entry->second.spec;
  }
  return absl::optional<TaskSpecification>();
}
} // namespace ray
+11
View File
@@ -39,6 +39,8 @@ class TaskFinisherInterface {
const std::vector<ObjectID> &inlined_dependency_ids,
const std::vector<ObjectID> &contained_ids) = 0;
virtual bool MarkTaskCanceled(const TaskID &task_id) = 0;
virtual ~TaskFinisherInterface() {}
};
@@ -129,6 +131,15 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
void OnTaskDependenciesInlined(const std::vector<ObjectID> &inlined_dependency_ids,
const std::vector<ObjectID> &contained_ids) override;
/// Set number of retries to zero for a task that is being canceled.
///
/// \param[in] task_id to cancel.
/// \return Whether the task was pending and was marked for cancellation.
bool MarkTaskCanceled(const TaskID &task_id);
/// Return the spec for a pending task.
absl::optional<TaskSpecification> GetTaskSpec(const TaskID &task_id) const;
/// Return whether this task can be submitted for execution.
///
/// \param[in] task_id ID of the task to query.
@@ -271,6 +271,7 @@ class CoreWorkerTest : public ::testing::Test {
nullptr, // check_signals
nullptr, // gc_collect
nullptr, // get_lang_stack
nullptr, // kill_main
true, // ref_counting_enabled
false, // is_local_mode
1, // num_workers
@@ -64,6 +64,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
MOCK_METHOD2(OnTaskDependenciesInlined,
void(const std::vector<ObjectID> &, const std::vector<ObjectID> &));
MOCK_METHOD1(MarkTaskCanceled, bool(const TaskID &task_id));
};
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
@@ -52,7 +52,15 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
return true;
}
// Records the cancel request so tests can inspect it; the reply callback is
// never invoked by this mock.
ray::Status CancelTask(
    const rpc::CancelTaskRequest &request,
    const rpc::ClientCallback<rpc::CancelTaskReply> &callback) override {
  // Newest request goes first, so `kill_requests.front()` is the latest one.
  kill_requests.emplace_front(request);
  return Status::OK();
}
std::list<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
std::list<rpc::CancelTaskRequest> kill_requests;
};
class MockTaskFinisher : public TaskFinisherInterface {
@@ -75,6 +83,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
num_contained_ids += contained_ids.size();
}
bool MarkTaskCanceled(const TaskID &task_id) override { return true; }
int num_tasks_complete = 0;
int num_tasks_failed = 0;
int num_inlined_dependencies = 0;
@@ -935,6 +945,105 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
}
// Cancels a task after a worker lease has been granted for it, once with
// force_kill and once without, and checks the cancel request that reaches the
// worker plus the resulting worker/task accounting.
TEST(DirectTaskTransportTest, TestKillExecutingTask) {
rpc::Address address;
auto raylet_client = std::make_shared<MockRayletClient>();
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto factory = [&](const rpc::Address &addr) { return worker_client; };
auto task_finisher = std::make_shared<MockTaskFinisher>();
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store,
task_finisher, ClientID::Nil(), kLongTimeout);
std::unordered_map<std::string, double> empty_resources;
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
// Try force kill, exiting the worker
ASSERT_TRUE(submitter.CancelTask(task, true).ok());
// The cancel request sent to the worker must name the task being cancelled.
ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(),
task.TaskId().Binary());
// Simulate the PushTask reply failing with an IOError.
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("workerdying"), true));
ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
// Reuse the spec with a different task ID for the non-force case.
task.GetMutableMessage().set_task_id(
TaskID::ForNormalTask(JobID::Nil(), TaskID::Nil(), 1).Binary());
ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
// Try non-force kill, worker returns normally
ASSERT_TRUE(submitter.CancelTask(task, false).ok());
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(),
task.TaskId().Binary());
ASSERT_EQ(worker_client->callbacks.size(), 0);
// The worker is returned and the second task counts as complete; the failure
// count still reflects only the first (force-killed) task.
ASSERT_EQ(raylet_client->num_workers_returned, 1);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
}
// Cancels a task that was submitted but never granted a worker lease, and
// checks that the task is failed and the lease request is cancelled without
// any cancel request being sent to a worker.
TEST(DirectTaskTransportTest, TestKillPendingTask) {
rpc::Address address;
auto raylet_client = std::make_shared<MockRayletClient>();
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto factory = [&](const rpc::Address &addr) { return worker_client; };
auto task_finisher = std::make_shared<MockTaskFinisher>();
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store,
task_finisher, ClientID::Nil(), kLongTimeout);
std::unordered_map<std::string, double> empty_resources;
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task).ok());
// No lease is granted before the cancel, so no worker is involved.
ASSERT_TRUE(submitter.CancelTask(task, true).ok());
ASSERT_EQ(worker_client->kill_requests.size(), 0);
ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
// The outstanding worker lease request must have been cancelled.
ASSERT_EQ(raylet_client->num_leases_canceled, 1);
ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease());
}
// Cancels a task while one of its argument objects is not yet in the memory
// store, then makes the argument available and checks that the task is failed
// without ever reaching a worker.
TEST(DirectTaskTransportTest, TestKillResolvingTask) {
rpc::Address address;
auto raylet_client = std::make_shared<MockRayletClient>();
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto factory = [&](const rpc::Address &addr) { return worker_client; };
auto task_finisher = std::make_shared<MockTaskFinisher>();
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store,
task_finisher, ClientID::Nil(), kLongTimeout);
std::unordered_map<std::string, double> empty_resources;
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
// Give the task an argument object that has not been put in the store yet.
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_EQ(task_finisher->num_inlined_dependencies, 0);
ASSERT_TRUE(submitter.CancelTask(task, true).ok());
// Make the argument available only after the cancel was issued.
auto data = GenerateRandomObject();
ASSERT_TRUE(store->Put(*data, obj1));
ASSERT_EQ(worker_client->kill_requests.size(), 0);
ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
}
} // namespace ray
int main(int argc, char **argv) {
+1
View File
@@ -54,6 +54,7 @@ class MockWorker {
nullptr, // check_signals
nullptr, // gc_collect
nullptr, // get_lang_stack
nullptr, // kill_main
true, // ref_counting_enabled
false, // is_local_mode
1, // num_workers
@@ -187,6 +187,31 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
}
// Marks a pending task with retries as canceled, fails it with TASK_CANCELLED,
// and checks that it stops being pending and that its return object holds the
// TASK_CANCELLED error.
TEST_F(TaskManagerTest, TestTaskKill) {
TaskID caller_id = TaskID::Nil();
rpc::Address caller_address;
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
auto spec = CreateTaskHelper(1, {});
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
// Register the task with retries available.
int num_retries = 3;
manager_.AddPendingTask(caller_id, caller_address, spec, "", num_retries);
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
// Cancel first, then report the failure; the task must not stay pending.
manager_.MarkTaskCanceled(spec.TaskId());
auto error = rpc::ErrorType::TASK_CANCELLED;
manager_.PendingTaskFailed(spec.TaskId(), error);
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
// The return object must now contain the cancellation error.
std::vector<std::shared_ptr<RayObject>> results;
RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results));
ASSERT_EQ(results.size(), 1);
rpc::ErrorType stored_error;
ASSERT_TRUE(results[0]->IsException(&stored_error));
ASSERT_EQ(stored_error, error);
}
// Test to make sure that the task spec and dependencies for an object are
// evicted when lineage pinning is disabled in the ReferenceCounter.
TEST_F(TaskManagerTest, TestLineageEvicted) {
@@ -23,17 +23,17 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
resolver_.ResolveDependencies(task_spec, [this, task_spec]() {
RAY_LOG(DEBUG) << "Task dependencies resolved " << task_spec.TaskId();
if (actor_create_callback_ && task_spec.IsActorCreationTask()) {
// If gcs actor management is enabled, the actor creation task will be sent to gcs
// server directly after the in-memory dependent objects are resolved.
// For more details please see the protocol of actor management based on gcs.
// If gcs actor management is enabled, the actor creation task will be sent to
// gcs server directly after the in-memory dependent objects are resolved. For
// more details please see the protocol of actor management based on gcs.
// https://docs.google.com/document/d/1EAWide-jy05akJp6OMtDn58XOK7bUyruWMia4E-fV28/edit?usp=sharing
auto actor_id = task_spec.ActorCreationId();
auto task_id = task_spec.TaskId();
RAY_LOG(INFO) << "Submitting actor creation task to GCS: " << actor_id;
auto status =
actor_create_callback_(task_spec, [this, actor_id, task_id](Status status) {
// If GCS is failed, GcsRpcClient may receive IOError status but it will not
// trigger this callback, because GcsRpcClient has retry logic at the
// If GCS is failed, GcsRpcClient may receive IOError status but it will
// not trigger this callback, because GcsRpcClient has retry logic at the
// bottom. So if this callback is invoked with an error there must be
// something wrong with the protocol of gcs-based actor management.
// So just check `status.ok()` here.
@@ -46,18 +46,33 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
return;
}
absl::MutexLock lock(&mu_);
// Note that the dependencies in the task spec are mutated to only contain
// plasma dependencies after ResolveDependencies finishes.
const SchedulingKey scheduling_key(
task_spec.GetSchedulingClass(), task_spec.GetDependencies(),
task_spec.IsActorCreationTask() ? task_spec.ActorCreationId() : ActorID::Nil());
auto it = task_queues_.find(scheduling_key);
if (it == task_queues_.end()) {
it = task_queues_.emplace(scheduling_key, std::deque<TaskSpecification>()).first;
bool keep_executing = true;
{
absl::MutexLock lock(&mu_);
if (cancelled_tasks_.find(task_spec.TaskId()) != cancelled_tasks_.end()) {
cancelled_tasks_.erase(task_spec.TaskId());
keep_executing = false;
}
if (keep_executing) {
// Note that the dependencies in the task spec are mutated to only contain
// plasma dependencies after ResolveDependencies finishes.
const SchedulingKey scheduling_key(
task_spec.GetSchedulingClass(), task_spec.GetDependencies(),
task_spec.IsActorCreationTask() ? task_spec.ActorCreationId()
: ActorID::Nil());
auto it = task_queues_.find(scheduling_key);
if (it == task_queues_.end()) {
it =
task_queues_.emplace(scheduling_key, std::deque<TaskSpecification>()).first;
}
it->second.push_back(task_spec);
RequestNewWorkerIfNeeded(scheduling_key);
}
}
if (!keep_executing) {
task_finisher_->PendingTaskFailed(task_spec.TaskId(),
rpc::ErrorType::TASK_CANCELLED, nullptr);
}
it->second.push_back(task_spec);
RequestNewWorkerIfNeeded(scheduling_key);
});
return Status::OK();
}
@@ -93,8 +108,9 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(
worker_to_lease_client_.erase(addr);
} else {
auto &client = *client_cache_[addr];
PushNormalTask(addr, client, scheduling_key, queue_entry->second.front(),
assigned_resources);
auto task_spec = queue_entry->second.front();
PushNormalTask(addr, client, scheduling_key, task_spec, assigned_resources);
executing_tasks_.emplace(task_spec.TaskId(), addr);
queue_entry->second.pop_front();
// Delete the queue if it's now empty. Note that the queue cannot already be empty
// because this is the only place tasks are removed from it.
@@ -248,6 +264,10 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
std::move(request),
[this, task_id, is_actor, is_actor_creation, scheduling_key, addr,
assigned_resources](Status status, const rpc::PushTaskReply &reply) {
{
absl::MutexLock lock(&mu_);
executing_tasks_.erase(task_id);
}
if (reply.worker_exiting()) {
// The worker is draining and will shutdown after it is done. Don't return
// it to the Raylet since that will kill it early.
@@ -273,4 +293,78 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
}
}));
}
/// Cancel a task that was submitted through this submitter.
///
/// Depending on how far the task has progressed, this either
///   - removes it from the local queue (dependencies resolved, waiting for a
///     worker lease) and fails it with TASK_CANCELLED,
///   - sends a CancelTask RPC to the worker it was pushed to, or
///   - only records it in `cancelled_tasks_` so that it is dropped once its
///     dependencies finish resolving.
///
/// \param[in] task_spec The task to cancel.
/// \param[in] force_kill Whether to kill the worker executing the task; this
/// flag is forwarded in the cancel RPC.
/// \return Status::OK() on every path.
Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
                                                 bool force_kill) {
  RAY_LOG(INFO) << "Killing task: " << task_spec.TaskId();
  const SchedulingKey scheduling_key(
      task_spec.GetSchedulingClass(), task_spec.GetDependencies(),
      task_spec.IsActorCreationTask() ? task_spec.ActorCreationId() : ActorID::Nil());
  std::shared_ptr<rpc::CoreWorkerClientInterface> client = nullptr;
  {
    absl::MutexLock lock(&mu_);
    // Nothing to do if a cancellation is already in flight for this task, or if
    // the task finisher refuses to mark it as canceled.
    if (cancelled_tasks_.find(task_spec.TaskId()) != cancelled_tasks_.end() ||
        !task_finisher_->MarkTaskCanceled(task_spec.TaskId())) {
      return Status::OK();
    }

    auto scheduled_tasks = task_queues_.find(scheduling_key);
    // This cancels tasks that have completed dependencies and are awaiting
    // a worker lease.
    if (scheduled_tasks != task_queues_.end()) {
      for (auto spec = scheduled_tasks->second.begin();
           spec != scheduled_tasks->second.end(); ++spec) {
        if (spec->TaskId() == task_spec.TaskId()) {
          // `spec` is invalidated by the erase; we return before touching it again.
          scheduled_tasks->second.erase(spec);

          // Preserve the invariant that a queue in the map is never empty.
          if (scheduled_tasks->second.empty()) {
            task_queues_.erase(scheduling_key);
            CancelWorkerLeaseIfNeeded(scheduling_key);
          }
          task_finisher_->PendingTaskFailed(task_spec.TaskId(),
                                            rpc::ErrorType::TASK_CANCELLED);
          return Status::OK();
        }
      }
    }

    // This will get removed either when the RPC call to cancel is returned
    // or when all dependencies are resolved.
    RAY_CHECK(cancelled_tasks_.emplace(task_spec.TaskId()).second);
    // Looks for an RPC handle for the worker executing the task.
    auto rpc_client = executing_tasks_.find(task_spec.TaskId());
    if (rpc_client != executing_tasks_.end()) {
      auto cached_client = client_cache_.find(rpc_client->second);
      if (cached_client != client_cache_.end()) {
        client = cached_client->second;
      }
    }
  }

  // This case is reached for tasks that have unresolved dependencies.
  if (client == nullptr) {
    return Status::OK();
  }

  auto request = rpc::CancelTaskRequest();
  request.set_intended_task_id(task_spec.TaskId().Binary());
  request.set_force_kill(force_kill);
  RAY_UNUSED(client->CancelTask(
      request, [this, task_spec, force_kill](const Status &status,
                                             const rpc::CancelTaskReply &reply) {
        absl::MutexLock lock(&mu_);
        cancelled_tasks_.erase(task_spec.TaskId());

        if (status.ok() && !reply.attempt_succeeded()) {
          if (cancel_retry_timer_.has_value()) {
            // Compare against the timer's own clock. Durations since epoch taken
            // from two different clocks (e.g. steady vs. system) are not
            // comparable, so the expiry must not be checked against
            // high_resolution_clock.
            if (cancel_retry_timer_->expiry() <=
                boost::asio::steady_timer::clock_type::now()) {
              cancel_retry_timer_->expires_after(boost::asio::chrono::milliseconds(
                  RayConfig::instance().cancellation_retry_ms()));
            }
            cancel_retry_timer_->async_wait(boost::bind(
                &CoreWorkerDirectTaskSubmitter::CancelTask, this, task_spec, force_kill));
          }
        }
        // Retry is not attempted if !status.ok() because force-kill may kill the worker
        // before the reply is sent.
      }));
  return Status::OK();
}
}; // namespace ray
@@ -56,7 +56,8 @@ class CoreWorkerDirectTaskSubmitter {
std::shared_ptr<TaskFinisherInterface> task_finisher, ClientID local_raylet_id,
int64_t lease_timeout_ms,
std::function<Status(const TaskSpecification &, const gcs::StatusCallback &)>
actor_create_callback = nullptr)
actor_create_callback = nullptr,
absl::optional<boost::asio::steady_timer> cancel_timer = absl::nullopt)
: rpc_address_(rpc_address),
local_lease_client_(lease_client),
client_factory_(client_factory),
@@ -65,13 +66,20 @@ class CoreWorkerDirectTaskSubmitter {
task_finisher_(task_finisher),
lease_timeout_ms_(lease_timeout_ms),
local_raylet_id_(local_raylet_id),
actor_create_callback_(std::move(actor_create_callback)) {}
actor_create_callback_(std::move(actor_create_callback)),
cancel_retry_timer_(std::move(cancel_timer)) {}
/// Schedule a task for direct submission to a worker.
///
/// \param[in] task_spec The task to schedule.
Status SubmitTask(TaskSpecification task_spec);
/// Cancel a task: either remove it while it is still pending locally, or send
/// an RPC asking the worker that is running it to cancel it.
///
/// \param[in] task_spec The task to kill.
/// \param[in] force_kill Whether to kill the worker executing the task.
/// \return Status::OK() on every path of the current implementation.
Status CancelTask(TaskSpecification task_spec, bool force_kill);
private:
/// Schedule more work onto an idle worker or return it back to the raylet if
/// no more tasks are queued for submission. If an error was encountered
@@ -180,6 +188,15 @@ class CoreWorkerDirectTaskSubmitter {
// Invariant: if a queue is in this map, it has at least one task.
absl::flat_hash_map<SchedulingKey, std::deque<TaskSpecification>> task_queues_
GUARDED_BY(mu_);
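// Tasks with a cancellation in flight: either cancelled while their
// dependencies were still being resolved, or awaiting the reply to a cancel RPC.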
absl::flat_hash_set<TaskID> cancelled_tasks_ GUARDED_BY(mu_);
// Keeps track of where currently executing tasks are being run.
absl::flat_hash_map<TaskID, rpc::WorkerAddress> executing_tasks_ GUARDED_BY(mu_);
// Retries cancelation requests if they were not successful.
absl::optional<boost::asio::steady_timer> cancel_retry_timer_;
};
}; // namespace ray
+21 -6
View File
@@ -111,7 +111,8 @@ message PushTaskRequest {
// caller, which might happen theoretically when network has issues.
// - For an actor, this is set to the timestamp when the actor is created,
// so it can be used to differentiate which is the new reconstructed actor.
// - For a non-actor task, it's set to the timestamp the task starts execution.
// - For a non-actor task, it's set to the timestamp the task starts
// execution.
int64 caller_version = 7;
}
@@ -121,9 +122,11 @@ message PushTaskReply {
// Set to true if the worker will be exiting.
bool worker_exiting = 2;
// The references that the worker borrowed during the task execution. A
// borrower is a process that is currently using the object ID, in one of 3 ways:
// borrower is a process that is currently using the object ID, in one of 3
// ways:
// 1. Has an ObjectID copy in Python.
// 2. Has submitted a task that depends on the object and that is still pending.
// 2. Has submitted a task that depends on the object and that is still
// pending.
// 3. Owns another object that is in scope and whose value contains the
// ObjectID.
// This list includes the reference counts for any IDs that were passed to
@@ -156,9 +159,7 @@ message GetObjectStatusRequest {
}
message GetObjectStatusReply {
enum ObjectStatus {
CREATED = 0;
}
enum ObjectStatus { CREATED = 0; }
ObjectStatus status = 1;
}
@@ -184,6 +185,18 @@ message KillActorRequest {
message KillActorReply {
}
// Request sent through CoreWorkerService.CancelTask asking a worker to cancel
// a task.
message CancelTaskRequest {
// Binary ID of the task that should be cancelled.
bytes intended_task_id = 1;
// Whether to kill the worker executing the task.
bool force_kill = 2;
}
// Reply to a CancelTaskRequest.
message CancelTaskReply {
// Whether the cancellation attempt succeeded, i.e. the requested task is the
// currently running task. When this is false and the RPC itself succeeded,
// the task submitter retries the cancellation if it has a retry timer.
// NOTE(review): confirm the exact success condition against the RPC handler.
bool attempt_succeeded = 1;
}
message GetCoreWorkerStatsRequest {
// The ID of the worker this message is intended for.
bytes intended_worker_id = 1;
@@ -281,6 +294,8 @@ service CoreWorkerService {
returns (WaitForObjectEvictionReply);
// Request that the worker shut down without completing outstanding work.
rpc KillActor(KillActorRequest) returns (KillActorReply);
// Request that a worker cancels a task.
rpc CancelTask(CancelTaskRequest) returns (CancelTaskReply);
// Get metrics from core workers.
rpc GetCoreWorkerStats(GetCoreWorkerStatsRequest) returns (GetCoreWorkerStatsReply);
// Wait for a borrower to finish using an object. Sent by the object's owner.
+3
View File
@@ -340,4 +340,7 @@ enum ErrorType {
// exposed to user code; it is only used internally to indicate the result of a direct
// call has been placed in plasma.
OBJECT_IN_PLASMA = 4;
// Indicates that the task that was to produce this object was cancelled.
TASK_CANCELLED = 5;
}
+7
View File
@@ -157,6 +157,11 @@ class CoreWorkerClientInterface {
return Status::NotImplemented("");
}
/// Ask the remote worker to cancel a task.
///
/// This base-interface version does nothing and reports NotImplemented;
/// concrete clients override it.
///
/// \param[in] request The cancel request message.
/// \param[in] callback The callback to invoke with the reply.
/// \return Status::NotImplemented in this default implementation.
virtual ray::Status CancelTask(const CancelTaskRequest &request,
const ClientCallback<CancelTaskReply> &callback) {
return Status::NotImplemented("");
}
virtual ray::Status GetCoreWorkerStats(
const GetCoreWorkerStatsRequest &request,
const ClientCallback<GetCoreWorkerStatsReply> &callback) {
@@ -210,6 +215,8 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
RPC_CLIENT_METHOD(CoreWorkerService, KillActor, grpc_client_, override)
RPC_CLIENT_METHOD(CoreWorkerService, CancelTask, grpc_client_, override)
RPC_CLIENT_METHOD(CoreWorkerService, WaitForObjectEviction, grpc_client_, override)
RPC_CLIENT_METHOD(CoreWorkerService, GetCoreWorkerStats, grpc_client_, override)
+2 -1
View File
@@ -17,7 +17,6 @@
#include "ray/rpc/grpc_server.h"
#include "ray/rpc/server_call.h"
#include "src/ray/protobuf/core_worker.grpc.pb.h"
#include "src/ray/protobuf/core_worker.pb.h"
@@ -36,6 +35,7 @@ namespace rpc {
RPC_SERVICE_HANDLER(CoreWorkerService, WaitForObjectEviction) \
RPC_SERVICE_HANDLER(CoreWorkerService, WaitForRefRemoved) \
RPC_SERVICE_HANDLER(CoreWorkerService, KillActor) \
RPC_SERVICE_HANDLER(CoreWorkerService, CancelTask) \
RPC_SERVICE_HANDLER(CoreWorkerService, GetCoreWorkerStats) \
RPC_SERVICE_HANDLER(CoreWorkerService, LocalGC) \
RPC_SERVICE_HANDLER(CoreWorkerService, PlasmaObjectReady)
@@ -48,6 +48,7 @@ namespace rpc {
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForObjectEviction) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForRefRemoved) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(KillActor) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(CancelTask) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetCoreWorkerStats) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(LocalGC) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PlasmaObjectReady)
+1
View File
@@ -313,6 +313,7 @@ class StreamingWorker {
nullptr, // check_signals
nullptr, // gc_collect
nullptr, // get_lang_stack
nullptr, // kill_main
true, // ref_counting_enabled
false, // is_local_mode
1, // num_workers
+1
View File
@@ -326,6 +326,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
nullptr, // check_signals
nullptr, // gc_collect
nullptr, // get_lang_stack
nullptr, // kill_main
true, // ref_counting_enabled
false, // is_local_mode
1, // num_workers