From 69ff7e3e3596454782a4765802d24a81bec4d273 Mon Sep 17 00:00:00 2001 From: ijrsvt Date: Sat, 25 Apr 2020 16:04:52 -0700 Subject: [PATCH] TaskCancellation (#7669) * Smol comment * WIP, not passing ray.init * Fixed small problem * wip * Pseudo interrupt things * Basic prototype operational * correct proc title * Mostly done * Cleanup * cleaner raylet error * Cleaning up a few loose ends * Fixing Race Conds * Prelim testing * Fixing comments and adding second_check for kill * Working_new_impl * demo_ready * Fixing my english * Fixing a few problems * Small problems * Cleaning up * Response to changes * Fixing error passing * Merged to master * fixing lock * Cleaning up print statements * Format * Fixing Unit test build failure * mock_worker fix * java_fix * Canel * Switching to Cancel * Responding to Review * FixFormatting * Lease cancellation * FInal comments? * Moving exist check to CoreWorker * Fix Actor Transport Test * Fixing task manager test * chaning clock repr * Fix build * fix white space * lint fix * Updating to medium size * Fixing Java test compilation issue * lengthen bad timeouts --- python/ray/__init__.py | 2 + python/ray/_raylet.pyx | 45 +++- python/ray/exceptions.py | 17 ++ python/ray/includes/libcoreworker.pxd | 2 + python/ray/includes/unique_ids.pxd | 2 + python/ray/includes/unique_ids.pxi | 3 + python/ray/serialization.py | 3 + python/ray/tests/BUILD | 8 + python/ray/tests/test_cancel.py | 232 ++++++++++++++++++ python/ray/worker.py | 23 +- src/ray/common/ray_config_def.h | 3 + src/ray/core_worker/core_worker.cc | 51 +++- src/ray/core_worker/core_worker.h | 17 +- .../java/io_ray_runtime_RayNativeRuntime.cc | 1 + src/ray/core_worker/task_manager.cc | 18 ++ src/ray/core_worker/task_manager.h | 11 + src/ray/core_worker/test/core_worker_test.cc | 1 + .../test/direct_actor_transport_test.cc | 2 + .../test/direct_task_transport_test.cc | 109 ++++++++ src/ray/core_worker/test/mock_worker.cc | 1 + src/ray/core_worker/test/task_manager_test.cc | 25 ++ .../transport/direct_task_transport.cc | 130 ++++++++-- .../transport/direct_task_transport.h | 21 +- src/ray/protobuf/core_worker.proto | 27 +- src/ray/protobuf/gcs.proto | 3 + src/ray/rpc/worker/core_worker_client.h | 7 + src/ray/rpc/worker/core_worker_server.h | 3 +- streaming/src/test/mock_actor.cc | 1 + streaming/src/test/queue_tests_base.h | 1 + 29 files changed, 731 insertions(+), 38 deletions(-) create mode 100644 python/ray/tests/test_cancel.py diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 3a2a69f82..2c9033030 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -66,6 +66,7 @@ from ray.worker import ( LOCAL_MODE, SCRIPT_MODE, WORKER_MODE, + cancel, connect, disconnect, get, @@ -113,6 +114,7 @@ __all__ = [ "_config", "_get_runtime_context", "actor", + "cancel", "connect", "disconnect", "get", diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 6b90cdf75..085cf4784 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -17,6 +17,8 @@ import logging import os import pickle import sys +import _thread +import setproctitle from libc.stdint cimport ( int32_t, @@ -90,6 +92,7 @@ from ray.exceptions import ( RayTaskError, ObjectStoreFullError, RayTimeoutError, + RayCancellationError ) from ray.utils import decode import gc @@ -453,13 +456,23 @@ cdef execute_task( class_name, repr(args), repr(kwargs)) core_worker.set_actor_title(actor_title.encode("utf-8")) # Execute the task. - with ray.worker._changeproctitle(title, next_title): - with core_worker.profile_event(b"task:execute"): - task_exception = True - outputs = function_executor(*args, **kwargs) + with core_worker.profile_event(b"task:execute"): + task_exception = True + try: + with ray.worker._changeproctitle(title, next_title): + outputs = function_executor(*args, **kwargs) task_exception = False - if c_return_ids.size() == 1: - outputs = (outputs,) + except KeyboardInterrupt as e: + raise RayCancellationError( + core_worker.get_current_task_id()) + if c_return_ids.size() == 1: + outputs = (outputs,) + # Check for a cancellation that was called when the function + # was exiting and was raised after the except block. + if not check_signals().ok(): + task_exception = True + raise RayCancellationError( + core_worker.get_current_task_id()) # Store the outputs in the object store. with core_worker.profile_event(b"task:store_outputs"): core_worker.store_task_outputs( @@ -551,6 +564,14 @@ cdef void async_plasma_callback(CObjectID object_id, event_handler._loop.call_soon_threadsafe( event_handler._complete_future, obj_id) +cdef c_bool kill_main_task() nogil: + with gil: + if setproctitle.getproctitle() != "ray::IDLE": + _thread.interrupt_main() + return True + return False + + cdef CRayStatus check_signals() nogil: with gil: try: @@ -658,6 +679,7 @@ cdef class CoreWorker: options.ref_counting_enabled = True options.is_local_mode = local_mode options.num_workers = 1 + options.kill_main = kill_main_task CCoreWorkerProcess.Initialize(options) @@ -953,6 +975,17 @@ cdef class CoreWorker: check_status(CCoreWorkerProcess.GetCoreWorker().KillActor( c_actor_id, True, no_reconstruction)) + def cancel_task(self, ObjectID object_id, c_bool force_kill): + cdef: + CObjectID c_object_id = object_id.native() + CRayStatus status = CRayStatus.OK() + + status = CCoreWorkerProcess.GetCoreWorker().CancelTask( + c_object_id, force_kill) + + if not status.ok(): + raise TypeError(status.message().decode()) + def resource_ids(self): cdef: ResourceMappingType resource_mapping = ( diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index be013ad7c..8adc46044 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -16,6 +16,23 @@ class RayConnectionError(RayError): pass +class RayCancellationError(RayError): + """Raised when this task is cancelled. + + Attributes: + task_id (TaskID): The TaskID of the function that was directly + cancelled. + """ + + def __init__(self, task_id=None): + self.task_id = task_id + + def __str__(self): + if self.task_id is None: + return "This task or its dependency was cancelled by" + return "Task: " + str(self.task_id) + " was cancelled" + + class RayTaskError(RayError): """Indicates that a task threw an exception during execution. diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index e4759095a..6524139e0 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -97,6 +97,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CRayStatus KillActor( const CActorID &actor_id, c_bool force_kill, c_bool no_reconstruction) + CRayStatus CancelTask(const CObjectID &object_id, c_bool force_kill) unique_ptr[CProfileEvent] CreateProfileEvent( const c_string &event_type) @@ -214,6 +215,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_bool ref_counting_enabled c_bool is_local_mode int num_workers + (c_bool() nogil) kill_main CCoreWorkerOptions() cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess": diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index 81bb0a806..512d4dd2d 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -122,6 +122,8 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: CTaskID ForNormalTask(CJobID job_id, CTaskID parent_task_id, int64_t parent_task_counter) + CActorID ActorId() const + cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]): @staticmethod diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index fc87b49dd..d1d268ca5 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -220,6 +220,9 @@ cdef class TaskID(BaseID): def is_nil(self): return self.data.IsNil() + def actor_id(self): + return ActorID(self.data.ActorId().Binary()) + cdef size_t hash(self): return self.data.Hash() diff --git a/python/ray/serialization.py b/python/ray/serialization.py index 9b6589910..a3052f13d 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -12,6 +12,7 @@ from ray.exceptions import ( PlasmaObjectNotAvailable, RayTaskError, RayActorError, + RayCancellationError, RayWorkerError, UnreconstructableError, ) @@ -279,6 +280,8 @@ class SerializationContext: return RayWorkerError() elif error_type == ErrorType.Value("ACTOR_DIED"): return RayActorError() + elif error_type == ErrorType.Value("TASK_CANCELLED"): + return RayCancellationError() elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"): return UnreconstructableError(ray.ObjectID(object_id.binary())) else: diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index b1fd2ddb2..f2fefb044 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -414,3 +414,11 @@ py_test( tags = ["exclusive"], deps = ["//:ray_lib"], ) + +py_test( + name = "test_cancel", + size = "medium", + srcs = ["test_cancel.py"], + tags = ["exclusive"], + deps = ["//:ray_lib"], +) diff --git a/python/ray/tests/test_cancel.py b/python/ray/tests/test_cancel.py new file mode 100644 index 000000000..adec360a3 --- /dev/null +++ b/python/ray/tests/test_cancel.py @@ -0,0 +1,232 @@ +import pytest +import ray +import random +import sys +import time +from ray.exceptions import RayTaskError, RayTimeoutError, \ + RayCancellationError, RayWorkerError +from ray.test_utils import SignalActor + + +def valid_exceptions(use_force): + if use_force: + return (RayTaskError, RayCancellationError, RayWorkerError) + else: + return (RayTaskError, RayCancellationError) + + +@pytest.mark.parametrize("use_force", [True, False]) +def test_cancel_chain(ray_start_regular, use_force): + signaler = SignalActor.remote() + + @ray.remote + def wait_for(t): + return ray.get(t[0]) + + obj1 = wait_for.remote([signaler.wait.remote()]) + obj2 = wait_for.remote([obj1]) + obj3 = wait_for.remote([obj2]) + obj4 = wait_for.remote([obj3]) + + assert len(ray.wait([obj1], timeout=.1)[0]) == 0 + ray.cancel(obj1, use_force) + for ob in [obj1, obj2, obj3, obj4]: + with pytest.raises(valid_exceptions(use_force)): + ray.get(ob) + + signaler2 = SignalActor.remote() + obj1 = wait_for.remote([signaler2.wait.remote()]) + obj2 = wait_for.remote([obj1]) + obj3 = wait_for.remote([obj2]) + obj4 = wait_for.remote([obj3]) + + assert len(ray.wait([obj3], timeout=.1)[0]) == 0 + ray.cancel(obj3, use_force) + for ob in [obj3, obj4]: + with pytest.raises(valid_exceptions(use_force)): + ray.get(ob) + + with pytest.raises(RayTimeoutError): + ray.get(obj1, timeout=.1) + + with pytest.raises(RayTimeoutError): + ray.get(obj2, timeout=.1) + + signaler2.send.remote() + ray.get(obj1, timeout=10) + + +@pytest.mark.parametrize("use_force", [True, False]) +def test_cancel_multiple_dependents(ray_start_regular, use_force): + signaler = SignalActor.remote() + + @ray.remote + def wait_for(t): + return ray.get(t[0]) + + head = wait_for.remote([signaler.wait.remote()]) + deps = [] + for _ in range(3): + deps.append(wait_for.remote([head])) + + assert len(ray.wait([head], timeout=.1)[0]) == 0 + ray.cancel(head, use_force) + for d in deps: + with pytest.raises(valid_exceptions(use_force)): + ray.get(d) + + head2 = wait_for.remote([signaler.wait.remote()]) + + deps2 = [] + for _ in range(3): + deps2.append(wait_for.remote([head])) + + for d in deps2: + ray.cancel(d, use_force) + + for d in deps2: + with pytest.raises(valid_exceptions(use_force)): + ray.get(d) + + signaler.send.remote() + ray.get(head2, timeout=1) + + +@pytest.mark.parametrize("use_force", [True, False]) +def test_single_cpu_cancel(shutdown_only, use_force): + ray.init(num_cpus=1) + signaler = SignalActor.remote() + + @ray.remote + def wait_for(t): + return ray.get(t[0]) + + obj1 = wait_for.remote([signaler.wait.remote()]) + obj2 = wait_for.remote([obj1]) + obj3 = wait_for.remote([obj2]) + indep = wait_for.remote([signaler.wait.remote()]) + + assert len(ray.wait([obj3], timeout=.1)[0]) == 0 + ray.cancel(obj3, use_force) + with pytest.raises(valid_exceptions(use_force)): + ray.get(obj3, 10) + + ray.cancel(obj1, use_force) + + for d in [obj1, obj2]: + with pytest.raises(valid_exceptions(use_force)): + ray.get(d) + + signaler.send.remote() + ray.get(indep) + + +@pytest.mark.parametrize("use_force", [True, False]) +def test_comprehensive(ray_start_regular, use_force): + signaler = SignalActor.remote() + + @ray.remote + def wait_for(t): + ray.get(t[0]) + return "Result" + + @ray.remote + def combine(a, b): + return str(a) + str(b) + + a = wait_for.remote([signaler.wait.remote()]) + b = wait_for.remote([signaler.wait.remote()]) + combo = combine.remote(a, b) + a2 = wait_for.remote([a]) + + assert len(ray.wait([a, b, a2, combo], timeout=1)[0]) == 0 + + ray.cancel(a, use_force) + with pytest.raises(valid_exceptions(use_force)): + ray.get(a, 10) + + with pytest.raises(valid_exceptions(use_force)): + ray.get(a2, 10) + + signaler.send.remote() + + with pytest.raises(valid_exceptions(use_force)): + ray.get(combo, 10) + + +@pytest.mark.parametrize("use_force", [True, False]) +def test_stress(shutdown_only, use_force): + ray.init(num_cpus=1) + + @ray.remote + def infinite_sleep(y): + if y: + while True: + time.sleep(1 / 10) + + first = infinite_sleep.remote(True) + + sleep_or_no = [random.randint(0, 1) for _ in range(100)] + tasks = [infinite_sleep.remote(i) for i in sleep_or_no] + cancelled = set() + for t in tasks: + if random.random() > 0.5: + ray.cancel(t, use_force) + cancelled.add(t) + + ray.cancel(first, use_force) + cancelled.add(first) + + for done in cancelled: + with pytest.raises(valid_exceptions(use_force)): + ray.get(done, 10) + + for indx in range(len(tasks)): + t = tasks[indx] + if sleep_or_no[indx]: + ray.cancel(t, use_force) + cancelled.add(t) + if t in cancelled: + with pytest.raises(valid_exceptions(use_force)): + ray.get(t, 10) + else: + ray.get(t) + + +@pytest.mark.parametrize("use_force", [True, False]) +def test_fast(shutdown_only, use_force): + ray.init(num_cpus=2) + + @ray.remote + def fast(y): + return y + + signaler = SignalActor.remote() + ids = list() + for _ in range(100): + x = fast.remote("a") + ray.cancel(x) + ids.append(x) + + @ray.remote + def wait_for(y): + return y + + sig = signaler.wait.remote() + for _ in range(5000): + x = wait_for.remote(sig) + ids.append(x) + + for idx in range(100, 5100): + if random.random() > 0.95: + ray.cancel(ids[idx]) + signaler.send.remote() + for obj_id in ids: + try: + ray.get(obj_id, 10) + except Exception as e: + assert isinstance(e, valid_exceptions(use_force)) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/worker.py b/python/ray/worker.py index ee86ce887..4c4ecb6f0 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1661,12 +1661,33 @@ def kill(actor): if not isinstance(actor, ray.actor.ActorHandle): raise ValueError("ray.kill() only supported for actors. " "Got: {}.".format(type(actor))) - worker = ray.worker.global_worker worker.check_connected() worker.core_worker.kill_actor(actor._ray_actor_id, False) +def cancel(object_id, force=False): + """Kill a task forcefully. + + This will interrupt any running tasks on the actor, causing them to fail + immediately. Any atexit handlers installed in the actor will still be run. + + If this actor is reconstructable, it will be attempted to be reconstructed. + + Args: + id (ActorHandle or ObjectID): Handle for the actor to kill or ObjectID + of the task to kill. + """ + worker = ray.worker.global_worker + worker.check_connected() + + if not isinstance(object_id, ray.ObjectID): + raise TypeError( + "ray.cancel() only supported for non-actor object IDs. " + "Got: {}.".format(type(object_id))) + return worker.core_worker.cancel_task(object_id, force) + + def _mode(worker=global_worker): """This is a wrapper around worker.mode. diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index d9adc019d..0d6791306 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -278,6 +278,9 @@ RAY_CONFIG(uint32_t, object_store_full_initial_delay_ms, 1000) /// Duration to wait between retries for failed tasks. RAY_CONFIG(uint32_t, task_retry_delay_ms, 5000) +/// Duration to wait between retrying to kill a task. +RAY_CONFIG(uint32_t, cancellation_retry_ms, 2000) + /// Whether to enable gcs service. /// RAY_GCS_SERVICE_ENABLED is an env variable which only set in ci job. /// If the value of RAY_GCS_SERVICE_ENABLED is false, we will disable gcs service, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 3c4042a0b..9757c5df4 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -429,7 +429,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ rpc_address_, local_raylet_client_, client_factory, raylet_client_factory, memory_store_, task_manager_, local_raylet_id, RayConfig::instance().worker_lease_timeout_milliseconds(), - std::move(actor_create_callback))); + std::move(actor_create_callback), boost::asio::steady_timer(io_service_))); future_resolver_.reset(new FutureResolver(memory_store_, client_factory)); // Unfortunately the raylet client has to be constructed after the receivers. if (direct_task_receiver_ != nullptr) { @@ -590,10 +590,10 @@ const WorkerID &CoreWorker::GetWorkerID() const { return worker_context_.GetWork void CoreWorker::SetCurrentTaskId(const TaskID &task_id) { worker_context_.SetCurrentTaskId(task_id); - main_thread_task_id_ = task_id; bool not_actor_task = false; { absl::MutexLock lock(&mutex_); + main_thread_task_id_ = task_id; not_actor_task = actor_id_.IsNil(); } if (not_actor_task && task_id.IsNil()) { @@ -1056,6 +1056,7 @@ TaskID CoreWorker::GetCallerId() const { if (!actor_id.IsNil()) { caller_id = TaskID::ForActorCreationTask(actor_id); } else { + absl::MutexLock lock(&mutex_); caller_id = main_thread_task_id_; } return caller_id; @@ -1203,6 +1204,25 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f return status; } +Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill) { + ActorHandle *h = nullptr; + if (!object_id.CreatedByTask() || + GetActorHandle(object_id.TaskId().ActorId(), &h).ok()) { + return Status::Invalid("Actor task cancellation is not supported."); + } + rpc::Address obj_addr; + if (!reference_counter_->GetOwner(object_id, nullptr, &obj_addr) || + obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) { + return Status::Invalid("Task is not locally submitted."); + } + + auto task_spec = task_manager_->GetTaskSpec(object_id.TaskId()); + if (task_spec.has_value() && !task_spec.value().IsActorCreationTask()) { + return direct_task_submitter_->CancelTask(task_spec.value(), force_kill); + } + return Status::OK(); +} + Status CoreWorker::KillActor(const ActorID &actor_id, bool force_kill, bool no_reconstruction) { ActorHandle *actor_handle = nullptr; @@ -1735,6 +1755,33 @@ void CoreWorker::HandleWaitForRefRemoved(const rpc::WaitForRefRemovedRequest &re owner_address, ref_removed_callback); } +void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request, + rpc::CancelTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { + absl::MutexLock lock(&mutex_); + TaskID task_id = TaskID::FromBinary(request.intended_task_id()); + bool success = main_thread_task_id_ == task_id; + + // Try non-force kill + if (success && !request.force_kill()) { + RAY_LOG(INFO) << "Interrupting a running task " << main_thread_task_id_; + success = options_.kill_main(); + } + + reply->set_attempt_succeeded(success); + send_reply_callback(Status::OK(), nullptr, nullptr); + + // Do force kill after reply callback sent + if (success && request.force_kill()) { + RAY_LOG(INFO) << "Force killing a worker running " << main_thread_task_id_; + RAY_IGNORE_EXPR(local_raylet_client_->Disconnect()); + if (options_.log_dir != "") { + RayLog::ShutDownRayLog(); + } + exit(1); + } +} + void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request, rpc::KillActorReply *reply, rpc::SendReplyCallback send_reply_callback) { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 538f6f60e..94368d02a 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -106,6 +106,8 @@ struct CoreWorkerOptions { std::function gc_collect; /// Language worker callback to get the current call stack. std::function get_lang_stack; + // Function that tries to interrupt the currently running Python thread. + std::function kill_main; /// Whether to enable object ref counting. bool ref_counting_enabled; /// Is local mode being used. @@ -501,7 +503,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// For actors, this is the current actor ID. To make sure that all caller /// IDs have the same type, we embed the actor ID in a TaskID with the rest /// of the bytes zeroed out. - TaskID GetCallerId() const; + TaskID GetCallerId() const LOCKS_EXCLUDED(mutex_); /// Push an error to the relevant driver. /// @@ -586,6 +588,12 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[out] Status Status KillActor(const ActorID &actor_id, bool force_kill, bool no_reconstruction); + /// Stops the task associated with the given Object ID. + /// + /// \param[in] object_id of the task to kill (must be a Non-Actor task) + /// \param[in] force_kill Whether to force kill a task by killing the worker. + /// \param[out] Status + Status CancelTask(const ObjectID &object_id, bool force_kill); /// Decrease the reference count for this actor. Should be called by the /// language frontend when a reference to the ActorHandle destroyed. /// @@ -694,6 +702,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { void HandleKillActor(const rpc::KillActorRequest &request, rpc::KillActorReply *reply, rpc::SendReplyCallback send_reply_callback) override; + /// Implements gRPC server handler. + void HandleCancelTask(const rpc::CancelTaskRequest &request, + rpc::CancelTaskReply *reply, + rpc::SendReplyCallback send_reply_callback); + /// Implements gRPC server handler. void HandlePlasmaObjectReady(const rpc::PlasmaObjectReadyRequest &request, rpc::PlasmaObjectReadyReply *reply, @@ -893,7 +906,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// The ID of the current task being executed by the main thread. If there /// are multiple threads, they will have a thread-local task ID stored in the /// worker context. - TaskID main_thread_task_id_; + TaskID main_thread_task_id_ GUARDED_BY(mutex_); // Flag indicating whether this worker has been shut down. bool shutdown_ = false; diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index e441f0d18..0dbc10ae8 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -122,6 +122,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( nullptr, // check_signals nullptr, // gc_collect nullptr, // get_lang_stack + nullptr, // kill_main false, // ref_counting_enabled false, // is_local_mode static_cast(numWorkersPerProcess), // num_workers diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 9361d7917..4e69e0b43 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -412,6 +412,15 @@ void TaskManager::RemoveLineageReference(const ObjectID &object_id, } } +bool TaskManager::MarkTaskCanceled(const TaskID &task_id) { + absl::MutexLock lock(&mu_); + auto it = submissible_tasks_.find(task_id); + if (it != submissible_tasks_.end()) { + it->second.num_retries_left = 0; + } + return it != submissible_tasks_.end(); +} + void TaskManager::MarkPendingTaskFailed(const TaskID &task_id, const TaskSpecification &spec, rpc::ErrorType error_type) { @@ -432,4 +441,13 @@ void TaskManager::MarkPendingTaskFailed(const TaskID &task_id, } } +absl::optional TaskManager::GetTaskSpec(const TaskID &task_id) const { + absl::MutexLock lock(&mu_); + auto it = submissible_tasks_.find(task_id); + if (it == submissible_tasks_.end()) { + return absl::optional(); + } + return it->second.spec; +} + } // namespace ray diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index fea5e18c3..495f8b64a 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -39,6 +39,8 @@ class TaskFinisherInterface { const std::vector &inlined_dependency_ids, const std::vector &contained_ids) = 0; + virtual bool MarkTaskCanceled(const TaskID &task_id) = 0; + virtual ~TaskFinisherInterface() {} }; @@ -129,6 +131,15 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa void OnTaskDependenciesInlined(const std::vector &inlined_dependency_ids, const std::vector &contained_ids) override; + /// Set number of retries to zero for a task that is being canceled. + /// + /// \param[in] task_id to cancel. + /// \return Whether the task was pending and was marked for cancellation. + bool MarkTaskCanceled(const TaskID &task_id); + + /// Return the spec for a pending task. + absl::optional GetTaskSpec(const TaskID &task_id) const; + /// Return whether this task can be submitted for execution. /// /// \param[in] task_id ID of the task to query. diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index d4c167d63..0d95aabe7 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -271,6 +271,7 @@ class CoreWorkerTest : public ::testing::Test { nullptr, // check_signals nullptr, // gc_collect nullptr, // get_lang_stack + nullptr, // kill_main true, // ref_counting_enabled false, // is_local_mode 1, // num_workers diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index 50031c74a..80a9065e2 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -64,6 +64,8 @@ class MockTaskFinisher : public TaskFinisherInterface { MOCK_METHOD2(OnTaskDependenciesInlined, void(const std::vector &, const std::vector &)); + + MOCK_METHOD1(MarkTaskCanceled, bool(const TaskID &task_id)); }; TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) { diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 8146630d0..433841a4c 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -52,7 +52,15 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { return true; } + ray::Status CancelTask( + const rpc::CancelTaskRequest &request, + const rpc::ClientCallback &callback) override { + kill_requests.push_front(request); + return Status::OK(); + } + std::list> callbacks; + std::list kill_requests; }; class MockTaskFinisher : public TaskFinisherInterface { @@ -75,6 +83,8 @@ class MockTaskFinisher : public TaskFinisherInterface { num_contained_ids += contained_ids.size(); } + bool MarkTaskCanceled(const TaskID &task_id) override { return true; } + int num_tasks_complete = 0; int num_tasks_failed = 0; int num_inlined_dependencies = 0; @@ -935,6 +945,105 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease()); } +TEST(DirectTaskTransportTest, TestKillExecutingTask) { + rpc::Address address; + auto raylet_client = std::make_shared(); + auto worker_client = std::make_shared(); + auto store = std::make_shared(); + auto factory = [&](const rpc::Address &addr) { return worker_client; }; + auto task_finisher = std::make_shared(); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, + task_finisher, ClientID::Nil(), kLongTimeout); + std::unordered_map empty_resources; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); + TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); + + ASSERT_TRUE(submitter.SubmitTask(task).ok()); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil())); + + // Try force kill, exiting the worker + ASSERT_TRUE(submitter.CancelTask(task, true).ok()); + ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(), + task.TaskId().Binary()); + ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("workerdying"), true)); + ASSERT_EQ(worker_client->callbacks.size(), 0); + ASSERT_EQ(raylet_client->num_workers_returned, 0); + ASSERT_EQ(raylet_client->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 0); + ASSERT_EQ(task_finisher->num_tasks_failed, 1); + + task.GetMutableMessage().set_task_id( + TaskID::ForNormalTask(JobID::Nil(), TaskID::Nil(), 1).Binary()); + ASSERT_TRUE(submitter.SubmitTask(task).ok()); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil())); + + // Try non-force kill, worker returns normally + ASSERT_TRUE(submitter.CancelTask(task, false).ok()); + ASSERT_TRUE(worker_client->ReplyPushTask()); + ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(), + task.TaskId().Binary()); + ASSERT_EQ(worker_client->callbacks.size(), 0); + ASSERT_EQ(raylet_client->num_workers_returned, 1); + ASSERT_EQ(raylet_client->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 1); + ASSERT_EQ(task_finisher->num_tasks_failed, 1); +} + +TEST(DirectTaskTransportTest, TestKillPendingTask) { + rpc::Address address; + auto raylet_client = std::make_shared(); + auto worker_client = std::make_shared(); + auto store = std::make_shared(); + auto factory = [&](const rpc::Address &addr) { return worker_client; }; + auto task_finisher = std::make_shared(); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, + task_finisher, ClientID::Nil(), kLongTimeout); + std::unordered_map empty_resources; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); + TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); + + ASSERT_TRUE(submitter.SubmitTask(task).ok()); + ASSERT_TRUE(submitter.CancelTask(task, true).ok()); + ASSERT_EQ(worker_client->kill_requests.size(), 0); + ASSERT_EQ(worker_client->callbacks.size(), 0); + ASSERT_EQ(raylet_client->num_workers_returned, 0); + ASSERT_EQ(raylet_client->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 0); + ASSERT_EQ(task_finisher->num_tasks_failed, 1); + ASSERT_EQ(raylet_client->num_leases_canceled, 1); + ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease()); +} + +TEST(DirectTaskTransportTest, TestKillResolvingTask) { + rpc::Address address; + auto raylet_client = std::make_shared(); + auto worker_client = std::make_shared(); + auto store = std::make_shared(); + auto factory = [&](const rpc::Address &addr) { return worker_client; }; + auto task_finisher = std::make_shared(); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, + task_finisher, ClientID::Nil(), kLongTimeout); + std::unordered_map empty_resources; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); + TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); + ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + ASSERT_TRUE(submitter.SubmitTask(task).ok()); + ASSERT_EQ(task_finisher->num_inlined_dependencies, 0); + ASSERT_TRUE(submitter.CancelTask(task, true).ok()); + auto data = GenerateRandomObject(); + ASSERT_TRUE(store->Put(*data, obj1)); + ASSERT_EQ(worker_client->kill_requests.size(), 0); + ASSERT_EQ(worker_client->callbacks.size(), 0); + ASSERT_EQ(raylet_client->num_workers_returned, 0); + ASSERT_EQ(raylet_client->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 0); + ASSERT_EQ(task_finisher->num_tasks_failed, 1); +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index cdb82187f..71f6fd040 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -54,6 +54,7 @@ class MockWorker { nullptr, // check_signals nullptr, // gc_collect nullptr, // get_lang_stack + nullptr, // kill_main true, // ref_counting_enabled false, // is_local_mode 1, // num_workers diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index 9e9e0d48f..dbeb121a9 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -187,6 +187,31 @@ TEST_F(TaskManagerTest, TestTaskRetry) { ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0); } +TEST_F(TaskManagerTest, TestTaskKill) { + TaskID caller_id = TaskID::Nil(); + rpc::Address caller_address; + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0); + auto spec = CreateTaskHelper(1, {}); + ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); + int num_retries = 3; + manager_.AddPendingTask(caller_id, caller_address, spec, "", num_retries); + ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1); + auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); + WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); + + manager_.MarkTaskCanceled(spec.TaskId()); + auto error = rpc::ErrorType::TASK_CANCELLED; + manager_.PendingTaskFailed(spec.TaskId(), error); + ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); + std::vector> results; + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + ASSERT_EQ(results.size(), 1); + rpc::ErrorType stored_error; + ASSERT_TRUE(results[0]->IsException(&stored_error)); + ASSERT_EQ(stored_error, error); +} + // Test to make sure that the task spec and dependencies for an object are // evicted when lineage pinning is disabled in the ReferenceCounter. TEST_F(TaskManagerTest, TestLineageEvicted) { diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index b5b98b6e4..a9abdeb1e 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -23,17 +23,17 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { resolver_.ResolveDependencies(task_spec, [this, task_spec]() { RAY_LOG(DEBUG) << "Task dependencies resolved " << task_spec.TaskId(); if (actor_create_callback_ && task_spec.IsActorCreationTask()) { - // If gcs actor management is enabled, the actor creation task will be sent to gcs - // server directly after the in-memory dependent objects are resolved. - // For more details please see the protocol of actor management based on gcs. + // If gcs actor management is enabled, the actor creation task will be sent to + // gcs server directly after the in-memory dependent objects are resolved. For + // more details please see the protocol of actor management based on gcs. // https://docs.google.com/document/d/1EAWide-jy05akJp6OMtDn58XOK7bUyruWMia4E-fV28/edit?usp=sharing auto actor_id = task_spec.ActorCreationId(); auto task_id = task_spec.TaskId(); RAY_LOG(INFO) << "Submitting actor creation task to GCS: " << actor_id; auto status = actor_create_callback_(task_spec, [this, actor_id, task_id](Status status) { - // If GCS is failed, GcsRpcClient may receive IOError status but it will not - // trigger this callback, because GcsRpcClient has retry logic at the + // If GCS is failed, GcsRpcClient may receive IOError status but it will + // not trigger this callback, because GcsRpcClient has retry logic at the // bottom. So if this callback is invoked with an error there must be // something wrong with the protocol of gcs-based actor management. // So just check `status.ok()` here. @@ -46,18 +46,33 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { return; } - absl::MutexLock lock(&mu_); - // Note that the dependencies in the task spec are mutated to only contain - // plasma dependencies after ResolveDependencies finishes. - const SchedulingKey scheduling_key( - task_spec.GetSchedulingClass(), task_spec.GetDependencies(), - task_spec.IsActorCreationTask() ? task_spec.ActorCreationId() : ActorID::Nil()); - auto it = task_queues_.find(scheduling_key); - if (it == task_queues_.end()) { - it = task_queues_.emplace(scheduling_key, std::deque()).first; + bool keep_executing = true; + { + absl::MutexLock lock(&mu_); + if (cancelled_tasks_.find(task_spec.TaskId()) != cancelled_tasks_.end()) { + cancelled_tasks_.erase(task_spec.TaskId()); + keep_executing = false; + } + if (keep_executing) { + // Note that the dependencies in the task spec are mutated to only contain + // plasma dependencies after ResolveDependencies finishes. + const SchedulingKey scheduling_key( + task_spec.GetSchedulingClass(), task_spec.GetDependencies(), + task_spec.IsActorCreationTask() ? task_spec.ActorCreationId() + : ActorID::Nil()); + auto it = task_queues_.find(scheduling_key); + if (it == task_queues_.end()) { + it = + task_queues_.emplace(scheduling_key, std::deque()).first; + } + it->second.push_back(task_spec); + RequestNewWorkerIfNeeded(scheduling_key); + } + } + if (!keep_executing) { + task_finisher_->PendingTaskFailed(task_spec.TaskId(), + rpc::ErrorType::TASK_CANCELLED, nullptr); } - it->second.push_back(task_spec); - RequestNewWorkerIfNeeded(scheduling_key); }); return Status::OK(); } @@ -93,8 +108,9 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle( worker_to_lease_client_.erase(addr); } else { auto &client = *client_cache_[addr]; - PushNormalTask(addr, client, scheduling_key, queue_entry->second.front(), - assigned_resources); + auto task_spec = queue_entry->second.front(); + PushNormalTask(addr, client, scheduling_key, task_spec, assigned_resources); + executing_tasks_.emplace(task_spec.TaskId(), addr); queue_entry->second.pop_front(); // Delete the queue if it's now empty. Note that the queue cannot already be empty // because this is the only place tasks are removed from it. @@ -248,6 +264,10 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask( std::move(request), [this, task_id, is_actor, is_actor_creation, scheduling_key, addr, assigned_resources](Status status, const rpc::PushTaskReply &reply) { + { + absl::MutexLock lock(&mu_); + executing_tasks_.erase(task_id); + } if (reply.worker_exiting()) { // The worker is draining and will shutdown after it is done. Don't return // it to the Raylet since that will kill it early. @@ -273,4 +293,78 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask( } })); } + +Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, + bool force_kill) { + RAY_LOG(INFO) << "Killing task: " << task_spec.TaskId(); + const SchedulingKey scheduling_key( + task_spec.GetSchedulingClass(), task_spec.GetDependencies(), + task_spec.IsActorCreationTask() ? task_spec.ActorCreationId() : ActorID::Nil()); + std::shared_ptr client = nullptr; + { + absl::MutexLock lock(&mu_); + if (cancelled_tasks_.find(task_spec.TaskId()) != cancelled_tasks_.end() || + !task_finisher_->MarkTaskCanceled(task_spec.TaskId())) { + return Status::OK(); + } + + auto scheduled_tasks = task_queues_.find(scheduling_key); + // This cancels tasks that have completed dependencies and are awaiting + // a worker lease. + if (scheduled_tasks != task_queues_.end()) { + for (auto spec = scheduled_tasks->second.begin(); + spec != scheduled_tasks->second.end(); spec++) { + if (spec->TaskId() == task_spec.TaskId()) { + scheduled_tasks->second.erase(spec); + + if (scheduled_tasks->second.empty()) { + task_queues_.erase(scheduling_key); + CancelWorkerLeaseIfNeeded(scheduling_key); + } + task_finisher_->PendingTaskFailed(task_spec.TaskId(), + rpc::ErrorType::TASK_CANCELLED); + return Status::OK(); + } + } + } + // This will get removed either when the RPC call to cancel is returned + // or when all dependencies are resolved. + RAY_CHECK(cancelled_tasks_.emplace(task_spec.TaskId()).second); + auto rpc_client = executing_tasks_.find(task_spec.TaskId()); + // Looks for an RPC handle for the worker executing the task. + if (rpc_client != executing_tasks_.end() && + client_cache_.find(rpc_client->second) != client_cache_.end()) { + client = client_cache_.find(rpc_client->second)->second; + } + } + + // This case is reached for tasks that have unresolved dependencies. + if (client == nullptr) { + return Status::OK(); + } + + auto request = rpc::CancelTaskRequest(); + request.set_intended_task_id(task_spec.TaskId().Binary()); + request.set_force_kill(force_kill); + RAY_UNUSED(client->CancelTask( + request, [this, task_spec, force_kill](const Status &status, + const rpc::CancelTaskReply &reply) { + absl::MutexLock lock(&mu_); + cancelled_tasks_.erase(task_spec.TaskId()); + if (status.ok() && !reply.attempt_succeeded()) { + if (cancel_retry_timer_.has_value()) { + if (cancel_retry_timer_->expiry().time_since_epoch() <= + std::chrono::high_resolution_clock::now().time_since_epoch()) { + cancel_retry_timer_->expires_after(boost::asio::chrono::milliseconds( + RayConfig::instance().cancellation_retry_ms())); + } + cancel_retry_timer_->async_wait(boost::bind( + &CoreWorkerDirectTaskSubmitter::CancelTask, this, task_spec, force_kill)); + } + } + // Retry is not attempted if !status.ok() because force-kill may kill the worker + // before the reply is sent. + })); + return Status::OK(); +} }; // namespace ray diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index 6f4ec27df..2a90477ef 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -56,7 +56,8 @@ class CoreWorkerDirectTaskSubmitter { std::shared_ptr task_finisher, ClientID local_raylet_id, int64_t lease_timeout_ms, std::function - actor_create_callback = nullptr) + actor_create_callback = nullptr, + absl::optional cancel_timer = absl::nullopt) : rpc_address_(rpc_address), local_lease_client_(lease_client), client_factory_(client_factory), @@ -65,13 +66,20 @@ class CoreWorkerDirectTaskSubmitter { task_finisher_(task_finisher), lease_timeout_ms_(lease_timeout_ms), local_raylet_id_(local_raylet_id), - actor_create_callback_(std::move(actor_create_callback)) {} + actor_create_callback_(std::move(actor_create_callback)), + cancel_retry_timer_(std::move(cancel_timer)) {} /// Schedule a task for direct submission to a worker. /// /// \param[in] task_spec The task to schedule. Status SubmitTask(TaskSpecification task_spec); + /// Either remove a pending task or send an RPC to kill a running task + /// + /// \param[in] task_spec The task to kill. + /// \param[in] force_kill Whether to kill the worker executing the task. + Status CancelTask(TaskSpecification task_spec, bool force_kill); + private: /// Schedule more work onto an idle worker or return it back to the raylet if /// no more tasks are queued for submission. If an error was encountered @@ -180,6 +188,15 @@ class CoreWorkerDirectTaskSubmitter { // Invariant: if a queue is in this map, it has at least one task. absl::flat_hash_map> task_queues_ GUARDED_BY(mu_); + + // Tasks that were cancelled while being resolved. + absl::flat_hash_set cancelled_tasks_ GUARDED_BY(mu_); + + // Keeps track of where currently executing tasks are being run. + absl::flat_hash_map executing_tasks_ GUARDED_BY(mu_); + + // Retries cancelation requests if they were not successful. + absl::optional cancel_retry_timer_; }; }; // namespace ray diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 8c927b5c9..501830015 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -111,7 +111,8 @@ message PushTaskRequest { // caller, which might happen theoretically when network has issues. // - For an actor, this is set to the timestamp when the actor is created, // so it can be used to differentiate which is the new reconstructed actor. - // - For a non-actor task, it's set to the timestamp the task starts execution. + // - For a non-actor task, it's set to the timestamp the task starts + // execution. int64 caller_version = 7; } @@ -121,9 +122,11 @@ message PushTaskReply { // Set to true if the worker will be exiting. bool worker_exiting = 2; // The references that the worker borrowed during the task execution. A - // borrower is a process that is currently using the object ID, in one of 3 ways: + // borrower is a process that is currently using the object ID, in one of 3 + // ways: // 1. Has an ObjectID copy in Python. - // 2. Has submitted a task that depends on the object and that is still pending. + // 2. Has submitted a task that depends on the object and that is still + // pending. // 3. Owns another object that is in scope and whose value contains the // ObjectID. // This list includes the reference counts for any IDs that were passed to @@ -156,9 +159,7 @@ message GetObjectStatusRequest { } message GetObjectStatusReply { - enum ObjectStatus { - CREATED = 0; - } + enum ObjectStatus { CREATED = 0; } ObjectStatus status = 1; } @@ -184,6 +185,18 @@ message KillActorRequest { message KillActorReply { } +message CancelTaskRequest { + // ID of task that should be killed. + bytes intended_task_id = 1; + // Whether to kill the worker. + bool force_kill = 2; +} + +message CancelTaskReply { + // Whether the requested task is the currently running task. + bool attempt_succeeded = 1; +} + message GetCoreWorkerStatsRequest { // The ID of the worker this message is intended for. bytes intended_worker_id = 1; @@ -281,6 +294,8 @@ service CoreWorkerService { returns (WaitForObjectEvictionReply); // Request that the worker shut down without completing outstanding work. rpc KillActor(KillActorRequest) returns (KillActorReply); + // Request that a worker cancels a task. + rpc CancelTask(CancelTaskRequest) returns (CancelTaskReply); // Get metrics from core workers. rpc GetCoreWorkerStats(GetCoreWorkerStatsRequest) returns (GetCoreWorkerStatsReply); // Wait for a borrower to finish using an object. Sent by the object's owner. diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 3829a74e9..89c93311c 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -340,4 +340,7 @@ enum ErrorType { // exposed to user code; it is only used internally to indicate the result of a direct // call has been placed in plasma. OBJECT_IN_PLASMA = 4; + + // Indicates that an object has been cancelled. + TASK_CANCELLED = 5; } diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index 6f532bdf8..db4ce5ee4 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -157,6 +157,11 @@ class CoreWorkerClientInterface { return Status::NotImplemented(""); } + virtual ray::Status CancelTask(const CancelTaskRequest &request, + const ClientCallback &callback) { + return Status::NotImplemented(""); + } + virtual ray::Status GetCoreWorkerStats( const GetCoreWorkerStatsRequest &request, const ClientCallback &callback) { @@ -210,6 +215,8 @@ class CoreWorkerClient : public std::enable_shared_from_this, RPC_CLIENT_METHOD(CoreWorkerService, KillActor, grpc_client_, override) + RPC_CLIENT_METHOD(CoreWorkerService, CancelTask, grpc_client_, override) + RPC_CLIENT_METHOD(CoreWorkerService, WaitForObjectEviction, grpc_client_, override) RPC_CLIENT_METHOD(CoreWorkerService, GetCoreWorkerStats, grpc_client_, override) diff --git a/src/ray/rpc/worker/core_worker_server.h b/src/ray/rpc/worker/core_worker_server.h index 7ea42f8d0..152c971cb 100644 --- a/src/ray/rpc/worker/core_worker_server.h +++ b/src/ray/rpc/worker/core_worker_server.h @@ -17,7 +17,6 @@ #include "ray/rpc/grpc_server.h" #include "ray/rpc/server_call.h" - #include "src/ray/protobuf/core_worker.grpc.pb.h" #include "src/ray/protobuf/core_worker.pb.h" @@ -36,6 +35,7 @@ namespace rpc { RPC_SERVICE_HANDLER(CoreWorkerService, WaitForObjectEviction) \ RPC_SERVICE_HANDLER(CoreWorkerService, WaitForRefRemoved) \ RPC_SERVICE_HANDLER(CoreWorkerService, KillActor) \ + RPC_SERVICE_HANDLER(CoreWorkerService, CancelTask) \ RPC_SERVICE_HANDLER(CoreWorkerService, GetCoreWorkerStats) \ RPC_SERVICE_HANDLER(CoreWorkerService, LocalGC) \ RPC_SERVICE_HANDLER(CoreWorkerService, PlasmaObjectReady) @@ -48,6 +48,7 @@ namespace rpc { DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForObjectEviction) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForRefRemoved) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(KillActor) \ + DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(CancelTask) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetCoreWorkerStats) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(LocalGC) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PlasmaObjectReady) diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc index 266a4897b..db482da48 100644 --- a/streaming/src/test/mock_actor.cc +++ b/streaming/src/test/mock_actor.cc @@ -313,6 +313,7 @@ class StreamingWorker { nullptr, // check_signals nullptr, // gc_collect nullptr, // get_lang_stack + nullptr, // kill_main true, // ref_counting_enabled false, // is_local_mode 1, // num_workers diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h index 8e7ef168e..dd03f1c1a 100644 --- a/streaming/src/test/queue_tests_base.h +++ b/streaming/src/test/queue_tests_base.h @@ -326,6 +326,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { nullptr, // check_signals nullptr, // gc_collect nullptr, // get_lang_stack + nullptr, // kill_main true, // ref_counting_enabled false, // is_local_mode 1, // num_workers