TaskCancellation (#7669)

* Smol comment

* WIP, not passing ray.init

* Fixed small problem

* wip

* Pseudo interrupt things

* Basic prototype operational

* correct proc title

* Mostly done

* Cleanup

* cleaner raylet error

* Cleaning up a few loose ends

* Fixing Race Conds

* Prelim testing

* Fixing comments and adding second_check for kill

* Working_new_impl

* demo_ready

* Fixing my english

* Fixing a few problems

* Small problems

* Cleaning up

* Response to changes

* Fixing error passing

* Merged to master

* fixing lock

* Cleaning up print statements

* Format

* Fixing Unit test build failure

* mock_worker fix

* java_fix

* Canel

* Switching to Cancel

* Responding to Review

* FixFormatting

* Lease cancellation

* FInal comments?

* Moving exist check to CoreWorker

* Fix Actor Transport Test

* Fixing task manager test

* chaning clock repr

* Fix build

* fix white space

* lint fix

* Updating to medium size

* Fixing Java test compilation issue

* lengthen bad timeouts
This commit is contained in:
ijrsvt
2020-04-25 16:04:52 -07:00
committed by GitHub
parent 9dd3490c38
commit 69ff7e3e35
29 changed files with 731 additions and 38 deletions
+2
View File
@@ -66,6 +66,7 @@ from ray.worker import (
LOCAL_MODE,
SCRIPT_MODE,
WORKER_MODE,
cancel,
connect,
disconnect,
get,
@@ -113,6 +114,7 @@ __all__ = [
"_config",
"_get_runtime_context",
"actor",
"cancel",
"connect",
"disconnect",
"get",
+39 -6
View File
@@ -17,6 +17,8 @@ import logging
import os
import pickle
import sys
import _thread
import setproctitle
from libc.stdint cimport (
int32_t,
@@ -90,6 +92,7 @@ from ray.exceptions import (
RayTaskError,
ObjectStoreFullError,
RayTimeoutError,
RayCancellationError
)
from ray.utils import decode
import gc
@@ -453,13 +456,23 @@ cdef execute_task(
class_name, repr(args), repr(kwargs))
core_worker.set_actor_title(actor_title.encode("utf-8"))
# Execute the task.
with ray.worker._changeproctitle(title, next_title):
with core_worker.profile_event(b"task:execute"):
task_exception = True
outputs = function_executor(*args, **kwargs)
with core_worker.profile_event(b"task:execute"):
task_exception = True
try:
with ray.worker._changeproctitle(title, next_title):
outputs = function_executor(*args, **kwargs)
task_exception = False
if c_return_ids.size() == 1:
outputs = (outputs,)
except KeyboardInterrupt as e:
raise RayCancellationError(
core_worker.get_current_task_id())
if c_return_ids.size() == 1:
outputs = (outputs,)
# Check for a cancellation that was called when the function
# was exiting and was raised after the except block.
if not check_signals().ok():
task_exception = True
raise RayCancellationError(
core_worker.get_current_task_id())
# Store the outputs in the object store.
with core_worker.profile_event(b"task:store_outputs"):
core_worker.store_task_outputs(
@@ -551,6 +564,14 @@ cdef void async_plasma_callback(CObjectID object_id,
event_handler._loop.call_soon_threadsafe(
event_handler._complete_future, obj_id)
cdef c_bool kill_main_task() nogil:
with gil:
if setproctitle.getproctitle() != "ray::IDLE":
_thread.interrupt_main()
return True
return False
cdef CRayStatus check_signals() nogil:
with gil:
try:
@@ -658,6 +679,7 @@ cdef class CoreWorker:
options.ref_counting_enabled = True
options.is_local_mode = local_mode
options.num_workers = 1
options.kill_main = kill_main_task
CCoreWorkerProcess.Initialize(options)
@@ -953,6 +975,17 @@ cdef class CoreWorker:
check_status(CCoreWorkerProcess.GetCoreWorker().KillActor(
c_actor_id, True, no_reconstruction))
def cancel_task(self, ObjectID object_id, c_bool force_kill):
cdef:
CObjectID c_object_id = object_id.native()
CRayStatus status = CRayStatus.OK()
status = CCoreWorkerProcess.GetCoreWorker().CancelTask(
c_object_id, force_kill)
if not status.ok():
raise TypeError(status.message().decode())
def resource_ids(self):
cdef:
ResourceMappingType resource_mapping = (
+17
View File
@@ -16,6 +16,23 @@ class RayConnectionError(RayError):
pass
class RayCancellationError(RayError):
"""Raised when this task is cancelled.
Attributes:
task_id (TaskID): The TaskID of the function that was directly
cancelled.
"""
def __init__(self, task_id=None):
self.task_id = task_id
def __str__(self):
if self.task_id is None:
return "This task or its dependency was cancelled by"
return "Task: " + str(self.task_id) + " was cancelled"
class RayTaskError(RayError):
"""Indicates that a task threw an exception during execution.
+2
View File
@@ -97,6 +97,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
CRayStatus KillActor(
const CActorID &actor_id, c_bool force_kill,
c_bool no_reconstruction)
CRayStatus CancelTask(const CObjectID &object_id, c_bool force_kill)
unique_ptr[CProfileEvent] CreateProfileEvent(
const c_string &event_type)
@@ -214,6 +215,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
c_bool ref_counting_enabled
c_bool is_local_mode
int num_workers
(c_bool() nogil) kill_main
CCoreWorkerOptions()
cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess":
+2
View File
@@ -122,6 +122,8 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
CTaskID ForNormalTask(CJobID job_id, CTaskID parent_task_id,
int64_t parent_task_counter)
CActorID ActorId() const
cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]):
@staticmethod
+3
View File
@@ -220,6 +220,9 @@ cdef class TaskID(BaseID):
def is_nil(self):
return self.data.IsNil()
def actor_id(self):
return ActorID(self.data.ActorId().Binary())
cdef size_t hash(self):
return self.data.Hash()
+3
View File
@@ -12,6 +12,7 @@ from ray.exceptions import (
PlasmaObjectNotAvailable,
RayTaskError,
RayActorError,
RayCancellationError,
RayWorkerError,
UnreconstructableError,
)
@@ -279,6 +280,8 @@ class SerializationContext:
return RayWorkerError()
elif error_type == ErrorType.Value("ACTOR_DIED"):
return RayActorError()
elif error_type == ErrorType.Value("TASK_CANCELLED"):
return RayCancellationError()
elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"):
return UnreconstructableError(ray.ObjectID(object_id.binary()))
else:
+8
View File
@@ -414,3 +414,11 @@ py_test(
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
py_test(
name = "test_cancel",
size = "medium",
srcs = ["test_cancel.py"],
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
+232
View File
@@ -0,0 +1,232 @@
import pytest
import ray
import random
import sys
import time
from ray.exceptions import RayTaskError, RayTimeoutError, \
RayCancellationError, RayWorkerError
from ray.test_utils import SignalActor
def valid_exceptions(use_force):
if use_force:
return (RayTaskError, RayCancellationError, RayWorkerError)
else:
return (RayTaskError, RayCancellationError)
@pytest.mark.parametrize("use_force", [True, False])
def test_cancel_chain(ray_start_regular, use_force):
signaler = SignalActor.remote()
@ray.remote
def wait_for(t):
return ray.get(t[0])
obj1 = wait_for.remote([signaler.wait.remote()])
obj2 = wait_for.remote([obj1])
obj3 = wait_for.remote([obj2])
obj4 = wait_for.remote([obj3])
assert len(ray.wait([obj1], timeout=.1)[0]) == 0
ray.cancel(obj1, use_force)
for ob in [obj1, obj2, obj3, obj4]:
with pytest.raises(valid_exceptions(use_force)):
ray.get(ob)
signaler2 = SignalActor.remote()
obj1 = wait_for.remote([signaler2.wait.remote()])
obj2 = wait_for.remote([obj1])
obj3 = wait_for.remote([obj2])
obj4 = wait_for.remote([obj3])
assert len(ray.wait([obj3], timeout=.1)[0]) == 0
ray.cancel(obj3, use_force)
for ob in [obj3, obj4]:
with pytest.raises(valid_exceptions(use_force)):
ray.get(ob)
with pytest.raises(RayTimeoutError):
ray.get(obj1, timeout=.1)
with pytest.raises(RayTimeoutError):
ray.get(obj2, timeout=.1)
signaler2.send.remote()
ray.get(obj1, timeout=10)
@pytest.mark.parametrize("use_force", [True, False])
def test_cancel_multiple_dependents(ray_start_regular, use_force):
signaler = SignalActor.remote()
@ray.remote
def wait_for(t):
return ray.get(t[0])
head = wait_for.remote([signaler.wait.remote()])
deps = []
for _ in range(3):
deps.append(wait_for.remote([head]))
assert len(ray.wait([head], timeout=.1)[0]) == 0
ray.cancel(head, use_force)
for d in deps:
with pytest.raises(valid_exceptions(use_force)):
ray.get(d)
head2 = wait_for.remote([signaler.wait.remote()])
deps2 = []
for _ in range(3):
deps2.append(wait_for.remote([head]))
for d in deps2:
ray.cancel(d, use_force)
for d in deps2:
with pytest.raises(valid_exceptions(use_force)):
ray.get(d)
signaler.send.remote()
ray.get(head2, timeout=1)
@pytest.mark.parametrize("use_force", [True, False])
def test_single_cpu_cancel(shutdown_only, use_force):
ray.init(num_cpus=1)
signaler = SignalActor.remote()
@ray.remote
def wait_for(t):
return ray.get(t[0])
obj1 = wait_for.remote([signaler.wait.remote()])
obj2 = wait_for.remote([obj1])
obj3 = wait_for.remote([obj2])
indep = wait_for.remote([signaler.wait.remote()])
assert len(ray.wait([obj3], timeout=.1)[0]) == 0
ray.cancel(obj3, use_force)
with pytest.raises(valid_exceptions(use_force)):
ray.get(obj3, 10)
ray.cancel(obj1, use_force)
for d in [obj1, obj2]:
with pytest.raises(valid_exceptions(use_force)):
ray.get(d)
signaler.send.remote()
ray.get(indep)
@pytest.mark.parametrize("use_force", [True, False])
def test_comprehensive(ray_start_regular, use_force):
signaler = SignalActor.remote()
@ray.remote
def wait_for(t):
ray.get(t[0])
return "Result"
@ray.remote
def combine(a, b):
return str(a) + str(b)
a = wait_for.remote([signaler.wait.remote()])
b = wait_for.remote([signaler.wait.remote()])
combo = combine.remote(a, b)
a2 = wait_for.remote([a])
assert len(ray.wait([a, b, a2, combo], timeout=1)[0]) == 0
ray.cancel(a, use_force)
with pytest.raises(valid_exceptions(use_force)):
ray.get(a, 10)
with pytest.raises(valid_exceptions(use_force)):
ray.get(a2, 10)
signaler.send.remote()
with pytest.raises(valid_exceptions(use_force)):
ray.get(combo, 10)
@pytest.mark.parametrize("use_force", [True, False])
def test_stress(shutdown_only, use_force):
ray.init(num_cpus=1)
@ray.remote
def infinite_sleep(y):
if y:
while True:
time.sleep(1 / 10)
first = infinite_sleep.remote(True)
sleep_or_no = [random.randint(0, 1) for _ in range(100)]
tasks = [infinite_sleep.remote(i) for i in sleep_or_no]
cancelled = set()
for t in tasks:
if random.random() > 0.5:
ray.cancel(t, use_force)
cancelled.add(t)
ray.cancel(first, use_force)
cancelled.add(first)
for done in cancelled:
with pytest.raises(valid_exceptions(use_force)):
ray.get(done, 10)
for indx in range(len(tasks)):
t = tasks[indx]
if sleep_or_no[indx]:
ray.cancel(t, use_force)
cancelled.add(t)
if t in cancelled:
with pytest.raises(valid_exceptions(use_force)):
ray.get(t, 10)
else:
ray.get(t)
@pytest.mark.parametrize("use_force", [True, False])
def test_fast(shutdown_only, use_force):
ray.init(num_cpus=2)
@ray.remote
def fast(y):
return y
signaler = SignalActor.remote()
ids = list()
for _ in range(100):
x = fast.remote("a")
ray.cancel(x)
ids.append(x)
@ray.remote
def wait_for(y):
return y
sig = signaler.wait.remote()
for _ in range(5000):
x = wait_for.remote(sig)
ids.append(x)
for idx in range(100, 5100):
if random.random() > 0.95:
ray.cancel(ids[idx])
signaler.send.remote()
for obj_id in ids:
try:
ray.get(obj_id, 10)
except Exception as e:
assert isinstance(e, valid_exceptions(use_force))
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
+22 -1
View File
@@ -1661,12 +1661,33 @@ def kill(actor):
if not isinstance(actor, ray.actor.ActorHandle):
raise ValueError("ray.kill() only supported for actors. "
"Got: {}.".format(type(actor)))
worker = ray.worker.global_worker
worker.check_connected()
worker.core_worker.kill_actor(actor._ray_actor_id, False)
def cancel(object_id, force=False):
"""Kill a task forcefully.
This will interrupt any running tasks on the actor, causing them to fail
immediately. Any atexit handlers installed in the actor will still be run.
If this actor is reconstructable, it will be attempted to be reconstructed.
Args:
id (ActorHandle or ObjectID): Handle for the actor to kill or ObjectID
of the task to kill.
"""
worker = ray.worker.global_worker
worker.check_connected()
if not isinstance(object_id, ray.ObjectID):
raise TypeError(
"ray.cancel() only supported for non-actor object IDs. "
"Got: {}.".format(type(object_id)))
return worker.core_worker.cancel_task(object_id, force)
def _mode(worker=global_worker):
"""This is a wrapper around worker.mode.