[tune] catch SIGINT signal and trigger experiment checkpoint (#13767)

* [tune] catch SIGINT signal and trigger experiment checkpoint

* Apply suggestions from code review

* Fix user guide docs

* Update doc/source/tune/user-guide.rst
This commit is contained in:
Kai Fricke
2021-02-02 14:52:09 +01:00
committed by GitHub
parent b9c15a2551
commit d29fcfb45c
3 changed files with 151 additions and 3 deletions
+59
View File
@@ -261,6 +261,7 @@ You can restore a single trial checkpoint by using ``tune.run(restore=<checkpoin
config={"env": "CartPole-v0"},
)
Distributed Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -277,6 +278,60 @@ disable cross-node syncing:
tune.run(func, sync_config=sync_config)
Stopping and resuming a tuning run
----------------------------------
Ray Tune periodically checkpoints the experiment state so that it can be
restarted when it fails or stops. The checkpointing period is
dynamically adjusted so that at least 95% of the time is used for handling
training results and scheduling.
If you send a SIGINT signal to the process running ``tune.run()`` (which is
usually what happens when you press Ctrl+C in the console), Ray Tune shuts
down training gracefully and saves a final experiment-level checkpoint. You
can then call ``tune.run()`` with ``resume=True`` to continue this run in
the future:
.. code-block:: python
:emphasize-lines: 14
tune.run(
train,
# ...
name="my_experiment"
)
# This is interrupted e.g. by sending a SIGINT signal
# Next time, continue the run like so:
tune.run(
train,
# ...
name="my_experiment",
resume=True
)
You will have to pass a ``name`` if you are using ``resume=True`` so that
Ray Tune can detect the experiment folder (which is usually stored at e.g.
``~/ray_results/my_experiment``). If you forgot to pass a name in the first
call, you can still pass the name when you resume the run. Please note that
in this case it is likely that your experiment name has a date suffix, so if you
ran ``tune.run(my_trainable)``, the ``name`` might look like something like this:
``my_trainable_2021-01-29_10-16-44``.
You can see which name you need to pass by taking a look at the results table
of your original tuning run:
.. code-block::
:emphasize-lines: 5
== Status ==
Memory usage on this node: 11.0/16.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 1/16 CPUs, 0/0 GPUs, 0.0/4.69 GiB heap, 0.0/1.61 GiB objects
Result logdir: /Users/ray/ray_results/my_trainable_2021-01-29_10-16-44
Number of trials: 1/1 (1 RUNNING)
Handling Large Datasets
-----------------------
@@ -682,6 +737,10 @@ These are the environment variables Ray Tune currently considers:
or a search algorithm, Tune will error
if the metric was not reported in the result. Setting this environment variable
to ``1`` will disable this check.
* **TUNE_DISABLE_SIGINT_HANDLER**: Ray Tune catches SIGINT signals (e.g. sent by
Ctrl+C) to gracefully shutdown and do a final checkpoint. Setting this variable
to ``1`` will disable signal handling and stop execution right away. Defaults to
``0``.
* **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits
for threads to finish after instructing them to complete. Defaults to ``2``.
* **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's
@@ -1,8 +1,10 @@
# coding: utf-8
import signal
from collections import Counter
import os
import shutil
import tempfile
import time
import unittest
import skopt
import numpy as np
@@ -87,6 +89,66 @@ class TuneRestoreTest(unittest.TestCase):
self.assertTrue(os.path.isfile(self.checkpoint_path))
class TuneInterruptionTest(unittest.TestCase):
def testExperimentInterrupted(self):
import multiprocessing
trainer_semaphore = multiprocessing.Semaphore()
driver_semaphore = multiprocessing.Semaphore()
class SteppingCallback(Callback):
def on_step_end(self, iteration, trials, **info):
driver_semaphore.release() # Driver should continue
trainer_semaphore.acquire() # Wait until released
def _run(local_dir):
def _train(config):
for i in range(7):
tune.report(val=i)
tune.run(
_train,
local_dir=local_dir,
name="interrupt",
callbacks=[SteppingCallback()])
local_dir = tempfile.mkdtemp()
process = multiprocessing.Process(target=_run, args=(local_dir, ))
process.daemon = False
process.start()
exp_dir = os.path.join(local_dir, "interrupt")
# Skip first five steps
for i in range(5):
driver_semaphore.acquire() # Wait for callback
trainer_semaphore.release() # Continue training
driver_semaphore.acquire()
experiment_state_file = None
for file in os.listdir(exp_dir):
if file.startswith("experiment_state"):
experiment_state_file = os.path.join(exp_dir, file)
break
self.assertTrue(experiment_state_file)
last_mtime = os.path.getmtime(experiment_state_file)
# Now send kill signal
os.kill(process.pid, signal.SIGINT)
# Release trainer. It should handle the signal and try to
# checkpoint the experiment
trainer_semaphore.release()
time.sleep(2) # Wait for checkpoint
new_mtime = os.path.getmtime(experiment_state_file)
self.assertNotEqual(last_mtime, new_mtime)
shutil.rmtree(local_dir)
class TuneFailResumeGridTest(unittest.TestCase):
class FailureInjectorCallback(Callback):
"""Adds random failure injection to the TrialExecutor."""
+30 -3
View File
@@ -3,6 +3,8 @@ from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, \
import datetime
import logging
import os
import signal
import sys
import time
@@ -112,6 +114,10 @@ def run(
) -> ExperimentAnalysis:
"""Executes training.
When a SIGINT signal is received (e.g. through Ctrl+C), the tuning run
will gracefully shut down and checkpoint the latest experiment state.
Sending SIGINT again (or SIGKILL/SIGTERM instead) will skip this step.
Examples:
.. code-block:: python
@@ -265,7 +271,6 @@ def run(
`LoggerCallback` and `SyncerCallback` callbacks are automatically
added.
Returns:
ExperimentAnalysis: Object for experiment analysis.
@@ -427,8 +432,24 @@ def run(
"`Trainable.default_resource_request` if using the "
"Trainable API.")
original_handler = signal.getsignal(signal.SIGINT)
state = {signal.SIGINT: False}
def sigint_handler(sig, frame):
logger.warning(
"SIGINT received (e.g. via Ctrl+C), ending Ray Tune run. "
"This will try to checkpoint the experiment state one last time. "
"Press CTRL+C one more time (or send SIGINT/SIGKILL/SIGTERM) "
"to skip. ")
state[signal.SIGINT] = True
# Restore original signal handler to react to future SIGINT signals
signal.signal(signal.SIGINT, original_handler)
if not int(os.getenv("TUNE_DISABLE_SIGINT_HANDLER", "0")):
signal.signal(signal.SIGINT, sigint_handler)
tune_start = time.time()
while not runner.is_finished():
while not runner.is_finished() and not state[signal.SIGINT]:
runner.step()
if has_verbosity(Verbosity.V1_EXPERIMENT):
_report_progress(runner, progress_reporter)
@@ -451,7 +472,7 @@ def run(
incomplete_trials += [trial]
if incomplete_trials:
if raise_on_failed_trial:
if raise_on_failed_trial and not state[signal.SIGINT]:
raise TuneError("Trials did not complete", incomplete_trials)
else:
logger.error("Trials did not complete: %s", incomplete_trials)
@@ -461,6 +482,12 @@ def run(
logger.info(f"Total run time: {all_taken:.2f} seconds "
f"({tune_taken:.2f} seconds for the tuning loop).")
if state[signal.SIGINT]:
logger.warning(
"Experiment has been interrupted, but the most recent state was "
"saved. You can continue running this experiment by passing "
"`resume=True` to `tune.run()`")
trials = runner.get_trials()
return ExperimentAnalysis(
runner.checkpoint_file,