mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 23:46:50 +08:00
TaskCancellation (#7669)
* Smol comment * WIP, not passing ray.init * Fixed small problem * wip * Pseudo interrupt things * Basic prototype operational * correct proc title * Mostly done * Cleanup * cleaner raylet error * Cleaning up a few loose ends * Fixing Race Conds * Prelim testing * Fixing comments and adding second_check for kill * Working_new_impl * demo_ready * Fixing my english * Fixing a few problems * Small problems * Cleaning up * Response to changes * Fixing error passing * Merged to master * fixing lock * Cleaning up print statements * Format * Fixing Unit test build failure * mock_worker fix * java_fix * Canel * Switching to Cancel * Responding to Review * FixFormatting * Lease cancellation * FInal comments? * Moving exist check to CoreWorker * Fix Actor Transport Test * Fixing task manager test * chaning clock repr * Fix build * fix white space * lint fix * Updating to medium size * Fixing Java test compilation issue * lengthen bad timeouts
This commit is contained in:
@@ -66,6 +66,7 @@ from ray.worker import (
|
||||
LOCAL_MODE,
|
||||
SCRIPT_MODE,
|
||||
WORKER_MODE,
|
||||
cancel,
|
||||
connect,
|
||||
disconnect,
|
||||
get,
|
||||
@@ -113,6 +114,7 @@ __all__ = [
|
||||
"_config",
|
||||
"_get_runtime_context",
|
||||
"actor",
|
||||
"cancel",
|
||||
"connect",
|
||||
"disconnect",
|
||||
"get",
|
||||
|
||||
+39
-6
@@ -17,6 +17,8 @@ import logging
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import _thread
|
||||
import setproctitle
|
||||
|
||||
from libc.stdint cimport (
|
||||
int32_t,
|
||||
@@ -90,6 +92,7 @@ from ray.exceptions import (
|
||||
RayTaskError,
|
||||
ObjectStoreFullError,
|
||||
RayTimeoutError,
|
||||
RayCancellationError
|
||||
)
|
||||
from ray.utils import decode
|
||||
import gc
|
||||
@@ -453,13 +456,23 @@ cdef execute_task(
|
||||
class_name, repr(args), repr(kwargs))
|
||||
core_worker.set_actor_title(actor_title.encode("utf-8"))
|
||||
# Execute the task.
|
||||
with ray.worker._changeproctitle(title, next_title):
|
||||
with core_worker.profile_event(b"task:execute"):
|
||||
task_exception = True
|
||||
outputs = function_executor(*args, **kwargs)
|
||||
with core_worker.profile_event(b"task:execute"):
|
||||
task_exception = True
|
||||
try:
|
||||
with ray.worker._changeproctitle(title, next_title):
|
||||
outputs = function_executor(*args, **kwargs)
|
||||
task_exception = False
|
||||
if c_return_ids.size() == 1:
|
||||
outputs = (outputs,)
|
||||
except KeyboardInterrupt as e:
|
||||
raise RayCancellationError(
|
||||
core_worker.get_current_task_id())
|
||||
if c_return_ids.size() == 1:
|
||||
outputs = (outputs,)
|
||||
# Check for a cancellation that was called when the function
|
||||
# was exiting and was raised after the except block.
|
||||
if not check_signals().ok():
|
||||
task_exception = True
|
||||
raise RayCancellationError(
|
||||
core_worker.get_current_task_id())
|
||||
# Store the outputs in the object store.
|
||||
with core_worker.profile_event(b"task:store_outputs"):
|
||||
core_worker.store_task_outputs(
|
||||
@@ -551,6 +564,14 @@ cdef void async_plasma_callback(CObjectID object_id,
|
||||
event_handler._loop.call_soon_threadsafe(
|
||||
event_handler._complete_future, obj_id)
|
||||
|
||||
cdef c_bool kill_main_task() nogil:
|
||||
with gil:
|
||||
if setproctitle.getproctitle() != "ray::IDLE":
|
||||
_thread.interrupt_main()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
cdef CRayStatus check_signals() nogil:
|
||||
with gil:
|
||||
try:
|
||||
@@ -658,6 +679,7 @@ cdef class CoreWorker:
|
||||
options.ref_counting_enabled = True
|
||||
options.is_local_mode = local_mode
|
||||
options.num_workers = 1
|
||||
options.kill_main = kill_main_task
|
||||
|
||||
CCoreWorkerProcess.Initialize(options)
|
||||
|
||||
@@ -953,6 +975,17 @@ cdef class CoreWorker:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().KillActor(
|
||||
c_actor_id, True, no_reconstruction))
|
||||
|
||||
def cancel_task(self, ObjectID object_id, c_bool force_kill):
|
||||
cdef:
|
||||
CObjectID c_object_id = object_id.native()
|
||||
CRayStatus status = CRayStatus.OK()
|
||||
|
||||
status = CCoreWorkerProcess.GetCoreWorker().CancelTask(
|
||||
c_object_id, force_kill)
|
||||
|
||||
if not status.ok():
|
||||
raise TypeError(status.message().decode())
|
||||
|
||||
def resource_ids(self):
|
||||
cdef:
|
||||
ResourceMappingType resource_mapping = (
|
||||
|
||||
@@ -16,6 +16,23 @@ class RayConnectionError(RayError):
|
||||
pass
|
||||
|
||||
|
||||
class RayCancellationError(RayError):
|
||||
"""Raised when this task is cancelled.
|
||||
|
||||
Attributes:
|
||||
task_id (TaskID): The TaskID of the function that was directly
|
||||
cancelled.
|
||||
"""
|
||||
|
||||
def __init__(self, task_id=None):
|
||||
self.task_id = task_id
|
||||
|
||||
def __str__(self):
|
||||
if self.task_id is None:
|
||||
return "This task or its dependency was cancelled by"
|
||||
return "Task: " + str(self.task_id) + " was cancelled"
|
||||
|
||||
|
||||
class RayTaskError(RayError):
|
||||
"""Indicates that a task threw an exception during execution.
|
||||
|
||||
|
||||
@@ -97,6 +97,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
CRayStatus KillActor(
|
||||
const CActorID &actor_id, c_bool force_kill,
|
||||
c_bool no_reconstruction)
|
||||
CRayStatus CancelTask(const CObjectID &object_id, c_bool force_kill)
|
||||
|
||||
unique_ptr[CProfileEvent] CreateProfileEvent(
|
||||
const c_string &event_type)
|
||||
@@ -214,6 +215,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
c_bool ref_counting_enabled
|
||||
c_bool is_local_mode
|
||||
int num_workers
|
||||
(c_bool() nogil) kill_main
|
||||
CCoreWorkerOptions()
|
||||
|
||||
cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess":
|
||||
|
||||
@@ -122,6 +122,8 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
|
||||
CTaskID ForNormalTask(CJobID job_id, CTaskID parent_task_id,
|
||||
int64_t parent_task_counter)
|
||||
|
||||
CActorID ActorId() const
|
||||
|
||||
cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]):
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -220,6 +220,9 @@ cdef class TaskID(BaseID):
|
||||
def is_nil(self):
|
||||
return self.data.IsNil()
|
||||
|
||||
def actor_id(self):
|
||||
return ActorID(self.data.ActorId().Binary())
|
||||
|
||||
cdef size_t hash(self):
|
||||
return self.data.Hash()
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from ray.exceptions import (
|
||||
PlasmaObjectNotAvailable,
|
||||
RayTaskError,
|
||||
RayActorError,
|
||||
RayCancellationError,
|
||||
RayWorkerError,
|
||||
UnreconstructableError,
|
||||
)
|
||||
@@ -279,6 +280,8 @@ class SerializationContext:
|
||||
return RayWorkerError()
|
||||
elif error_type == ErrorType.Value("ACTOR_DIED"):
|
||||
return RayActorError()
|
||||
elif error_type == ErrorType.Value("TASK_CANCELLED"):
|
||||
return RayCancellationError()
|
||||
elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"):
|
||||
return UnreconstructableError(ray.ObjectID(object_id.binary()))
|
||||
else:
|
||||
|
||||
@@ -414,3 +414,11 @@ py_test(
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_cancel",
|
||||
size = "medium",
|
||||
srcs = ["test_cancel.py"],
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
import pytest
|
||||
import ray
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from ray.exceptions import RayTaskError, RayTimeoutError, \
|
||||
RayCancellationError, RayWorkerError
|
||||
from ray.test_utils import SignalActor
|
||||
|
||||
|
||||
def valid_exceptions(use_force):
|
||||
if use_force:
|
||||
return (RayTaskError, RayCancellationError, RayWorkerError)
|
||||
else:
|
||||
return (RayTaskError, RayCancellationError)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
|
||||
def test_cancel_chain(ray_start_regular, use_force):
|
||||
signaler = SignalActor.remote()
|
||||
|
||||
@ray.remote
|
||||
def wait_for(t):
|
||||
return ray.get(t[0])
|
||||
|
||||
obj1 = wait_for.remote([signaler.wait.remote()])
|
||||
obj2 = wait_for.remote([obj1])
|
||||
obj3 = wait_for.remote([obj2])
|
||||
obj4 = wait_for.remote([obj3])
|
||||
|
||||
assert len(ray.wait([obj1], timeout=.1)[0]) == 0
|
||||
ray.cancel(obj1, use_force)
|
||||
for ob in [obj1, obj2, obj3, obj4]:
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(ob)
|
||||
|
||||
signaler2 = SignalActor.remote()
|
||||
obj1 = wait_for.remote([signaler2.wait.remote()])
|
||||
obj2 = wait_for.remote([obj1])
|
||||
obj3 = wait_for.remote([obj2])
|
||||
obj4 = wait_for.remote([obj3])
|
||||
|
||||
assert len(ray.wait([obj3], timeout=.1)[0]) == 0
|
||||
ray.cancel(obj3, use_force)
|
||||
for ob in [obj3, obj4]:
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(ob)
|
||||
|
||||
with pytest.raises(RayTimeoutError):
|
||||
ray.get(obj1, timeout=.1)
|
||||
|
||||
with pytest.raises(RayTimeoutError):
|
||||
ray.get(obj2, timeout=.1)
|
||||
|
||||
signaler2.send.remote()
|
||||
ray.get(obj1, timeout=10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
|
||||
def test_cancel_multiple_dependents(ray_start_regular, use_force):
|
||||
signaler = SignalActor.remote()
|
||||
|
||||
@ray.remote
|
||||
def wait_for(t):
|
||||
return ray.get(t[0])
|
||||
|
||||
head = wait_for.remote([signaler.wait.remote()])
|
||||
deps = []
|
||||
for _ in range(3):
|
||||
deps.append(wait_for.remote([head]))
|
||||
|
||||
assert len(ray.wait([head], timeout=.1)[0]) == 0
|
||||
ray.cancel(head, use_force)
|
||||
for d in deps:
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(d)
|
||||
|
||||
head2 = wait_for.remote([signaler.wait.remote()])
|
||||
|
||||
deps2 = []
|
||||
for _ in range(3):
|
||||
deps2.append(wait_for.remote([head]))
|
||||
|
||||
for d in deps2:
|
||||
ray.cancel(d, use_force)
|
||||
|
||||
for d in deps2:
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(d)
|
||||
|
||||
signaler.send.remote()
|
||||
ray.get(head2, timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
|
||||
def test_single_cpu_cancel(shutdown_only, use_force):
|
||||
ray.init(num_cpus=1)
|
||||
signaler = SignalActor.remote()
|
||||
|
||||
@ray.remote
|
||||
def wait_for(t):
|
||||
return ray.get(t[0])
|
||||
|
||||
obj1 = wait_for.remote([signaler.wait.remote()])
|
||||
obj2 = wait_for.remote([obj1])
|
||||
obj3 = wait_for.remote([obj2])
|
||||
indep = wait_for.remote([signaler.wait.remote()])
|
||||
|
||||
assert len(ray.wait([obj3], timeout=.1)[0]) == 0
|
||||
ray.cancel(obj3, use_force)
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(obj3, 10)
|
||||
|
||||
ray.cancel(obj1, use_force)
|
||||
|
||||
for d in [obj1, obj2]:
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(d)
|
||||
|
||||
signaler.send.remote()
|
||||
ray.get(indep)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
|
||||
def test_comprehensive(ray_start_regular, use_force):
|
||||
signaler = SignalActor.remote()
|
||||
|
||||
@ray.remote
|
||||
def wait_for(t):
|
||||
ray.get(t[0])
|
||||
return "Result"
|
||||
|
||||
@ray.remote
|
||||
def combine(a, b):
|
||||
return str(a) + str(b)
|
||||
|
||||
a = wait_for.remote([signaler.wait.remote()])
|
||||
b = wait_for.remote([signaler.wait.remote()])
|
||||
combo = combine.remote(a, b)
|
||||
a2 = wait_for.remote([a])
|
||||
|
||||
assert len(ray.wait([a, b, a2, combo], timeout=1)[0]) == 0
|
||||
|
||||
ray.cancel(a, use_force)
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(a, 10)
|
||||
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(a2, 10)
|
||||
|
||||
signaler.send.remote()
|
||||
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(combo, 10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
|
||||
def test_stress(shutdown_only, use_force):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
@ray.remote
|
||||
def infinite_sleep(y):
|
||||
if y:
|
||||
while True:
|
||||
time.sleep(1 / 10)
|
||||
|
||||
first = infinite_sleep.remote(True)
|
||||
|
||||
sleep_or_no = [random.randint(0, 1) for _ in range(100)]
|
||||
tasks = [infinite_sleep.remote(i) for i in sleep_or_no]
|
||||
cancelled = set()
|
||||
for t in tasks:
|
||||
if random.random() > 0.5:
|
||||
ray.cancel(t, use_force)
|
||||
cancelled.add(t)
|
||||
|
||||
ray.cancel(first, use_force)
|
||||
cancelled.add(first)
|
||||
|
||||
for done in cancelled:
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(done, 10)
|
||||
|
||||
for indx in range(len(tasks)):
|
||||
t = tasks[indx]
|
||||
if sleep_or_no[indx]:
|
||||
ray.cancel(t, use_force)
|
||||
cancelled.add(t)
|
||||
if t in cancelled:
|
||||
with pytest.raises(valid_exceptions(use_force)):
|
||||
ray.get(t, 10)
|
||||
else:
|
||||
ray.get(t)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_force", [True, False])
|
||||
def test_fast(shutdown_only, use_force):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
@ray.remote
|
||||
def fast(y):
|
||||
return y
|
||||
|
||||
signaler = SignalActor.remote()
|
||||
ids = list()
|
||||
for _ in range(100):
|
||||
x = fast.remote("a")
|
||||
ray.cancel(x)
|
||||
ids.append(x)
|
||||
|
||||
@ray.remote
|
||||
def wait_for(y):
|
||||
return y
|
||||
|
||||
sig = signaler.wait.remote()
|
||||
for _ in range(5000):
|
||||
x = wait_for.remote(sig)
|
||||
ids.append(x)
|
||||
|
||||
for idx in range(100, 5100):
|
||||
if random.random() > 0.95:
|
||||
ray.cancel(ids[idx])
|
||||
signaler.send.remote()
|
||||
for obj_id in ids:
|
||||
try:
|
||||
ray.get(obj_id, 10)
|
||||
except Exception as e:
|
||||
assert isinstance(e, valid_exceptions(use_force))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
+22
-1
@@ -1661,12 +1661,33 @@ def kill(actor):
|
||||
if not isinstance(actor, ray.actor.ActorHandle):
|
||||
raise ValueError("ray.kill() only supported for actors. "
|
||||
"Got: {}.".format(type(actor)))
|
||||
|
||||
worker = ray.worker.global_worker
|
||||
worker.check_connected()
|
||||
worker.core_worker.kill_actor(actor._ray_actor_id, False)
|
||||
|
||||
|
||||
def cancel(object_id, force=False):
|
||||
"""Kill a task forcefully.
|
||||
|
||||
This will interrupt any running tasks on the actor, causing them to fail
|
||||
immediately. Any atexit handlers installed in the actor will still be run.
|
||||
|
||||
If this actor is reconstructable, it will be attempted to be reconstructed.
|
||||
|
||||
Args:
|
||||
id (ActorHandle or ObjectID): Handle for the actor to kill or ObjectID
|
||||
of the task to kill.
|
||||
"""
|
||||
worker = ray.worker.global_worker
|
||||
worker.check_connected()
|
||||
|
||||
if not isinstance(object_id, ray.ObjectID):
|
||||
raise TypeError(
|
||||
"ray.cancel() only supported for non-actor object IDs. "
|
||||
"Got: {}.".format(type(object_id)))
|
||||
return worker.core_worker.cancel_task(object_id, force)
|
||||
|
||||
|
||||
def _mode(worker=global_worker):
|
||||
"""This is a wrapper around worker.mode.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user