[tune] catch SIGINT signal and trigger experiment checkpoint (#13767)

* [tune] catch SIGINT signal and trigger experiment checkpoint

* Apply suggestions from code review

* Fix user guide docs

* Update doc/source/tune/user-guide.rst
This commit is contained in:
Kai Fricke
2021-02-02 14:52:09 +01:00
committed by GitHub
parent b9c15a2551
commit d29fcfb45c
3 changed files with 151 additions and 3 deletions
@@ -1,8 +1,10 @@
# coding: utf-8
import signal
from collections import Counter
import os
import shutil
import tempfile
import time
import unittest
import skopt
import numpy as np
@@ -87,6 +89,66 @@ class TuneRestoreTest(unittest.TestCase):
self.assertTrue(os.path.isfile(self.checkpoint_path))
class TuneInterruptionTest(unittest.TestCase):
    """Checks that interrupting a Tune run with SIGINT rewrites the
    experiment state file."""

    def testExperimentInterrupted(self):
        """Run Tune in a child process, send it SIGINT, and assert that the
        ``experiment_state*`` file in the experiment directory gets a new
        modification time afterwards.

        The child and this (driver) process hand control back and forth
        through two semaphores that are operated from the ``on_step_end``
        callback.
        """
        import multiprocessing

        # NOTE(review): multiprocessing.Semaphore() starts with a value of 1,
        # so the first acquire() on either side returns without waiting for
        # the other process. Confirm this slack in the handshake is intended.
        trainer_semaphore = multiprocessing.Semaphore()
        driver_semaphore = multiprocessing.Semaphore()

        class SteppingCallback(Callback):
            """Pauses the Tune loop after every step until the driver
            releases it."""

            def on_step_end(self, iteration, trials, **info):
                driver_semaphore.release()  # Driver should continue
                trainer_semaphore.acquire()  # Wait until released

        def _run(local_dir):
            """Entry point of the child process: runs the experiment."""

            def _train(config):
                for i in range(7):
                    tune.report(val=i)

            tune.run(
                _train,
                local_dir=local_dir,
                name="interrupt",
                callbacks=[SteppingCallback()])

        local_dir = tempfile.mkdtemp()
        # Registered first so that it runs last (cleanups run in LIFO order),
        # i.e. only after the child process has been stopped. Using a cleanup
        # instead of a trailing rmtree() removes the directory even when an
        # assertion below fails.
        self.addCleanup(shutil.rmtree, local_dir, ignore_errors=True)

        process = multiprocessing.Process(target=_run, args=(local_dir, ))
        process.daemon = False
        process.start()

        def _stop_child():
            # Give a gracefully shutting down run the chance to finish. If
            # the test failed early the child is still blocked on
            # trainer_semaphore, so fall back to terminating it; otherwise
            # the non-daemon process would keep the test runner from exiting.
            process.join(timeout=10)
            if process.is_alive():
                process.terminate()
                process.join()

        self.addCleanup(_stop_child)

        exp_dir = os.path.join(local_dir, "interrupt")

        # Skip first five steps
        for _ in range(5):
            driver_semaphore.acquire()  # Wait for callback
            trainer_semaphore.release()  # Continue training

        driver_semaphore.acquire()

        # First directory entry whose name starts with "experiment_state",
        # or None when there is no such file yet.
        experiment_state_file = next(
            (os.path.join(exp_dir, entry) for entry in os.listdir(exp_dir)
             if entry.startswith("experiment_state")), None)

        self.assertTrue(experiment_state_file)
        last_mtime = os.path.getmtime(experiment_state_file)

        # Now send kill signal
        os.kill(process.pid, signal.SIGINT)
        # Release trainer. It should handle the signal and try to
        # checkpoint the experiment
        trainer_semaphore.release()

        time.sleep(2)  # Wait for checkpoint
        new_mtime = os.path.getmtime(experiment_state_file)

        self.assertNotEqual(last_mtime, new_mtime)
class TuneFailResumeGridTest(unittest.TestCase):
class FailureInjectorCallback(Callback):
"""Adds random failure injection to the TrialExecutor."""
+30 -3
View File
@@ -3,6 +3,8 @@ from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, \
import datetime
import logging
import os
import signal
import sys
import time
@@ -112,6 +114,10 @@ def run(
) -> ExperimentAnalysis:
"""Executes training.
When a SIGINT signal is received (e.g. through Ctrl+C), the tuning run
will gracefully shut down and checkpoint the latest experiment state.
Sending SIGINT again (or SIGKILL/SIGTERM instead) will skip this step.
Examples:
.. code-block:: python
@@ -265,7 +271,6 @@ def run(
`LoggerCallback` and `SyncerCallback` callbacks are automatically
added.
Returns:
ExperimentAnalysis: Object for experiment analysis.
@@ -427,8 +432,24 @@ def run(
"`Trainable.default_resource_request` if using the "
"Trainable API.")
original_handler = signal.getsignal(signal.SIGINT)
state = {signal.SIGINT: False}
# SIGINT handler: logs a warning, records in `state` that an interrupt was
# received, and re-installs the handler that was active before this one.
# Because the previous handler is restored here, only the first SIGINT is
# handled by this function.
#
# Args (standard `signal` handler signature; neither value is read here):
#   sig: Number of the delivered signal.
#   frame: Stack frame that was executing when the signal arrived.
def sigint_handler(sig, frame):
logger.warning(
"SIGINT received (e.g. via Ctrl+C), ending Ray Tune run. "
"This will try to checkpoint the experiment state one last time. "
"Press CTRL+C one more time (or send SIGINT/SIGKILL/SIGTERM) "
"to skip. ")
# This flag is part of the tuning loop's `while` condition, so setting it
# stops further calls to runner.step().
state[signal.SIGINT] = True
# Restore original signal handler to react to future SIGINT signals
# NOTE(review): signal.getsignal() may return None when the previous handler
# was not installed from Python; signal.signal() does not accept None.
# Confirm whether that case needs to be handled here.
signal.signal(signal.SIGINT, original_handler)
if not int(os.getenv("TUNE_DISABLE_SIGINT_HANDLER", "0")):
signal.signal(signal.SIGINT, sigint_handler)
tune_start = time.time()
while not runner.is_finished():
while not runner.is_finished() and not state[signal.SIGINT]:
runner.step()
if has_verbosity(Verbosity.V1_EXPERIMENT):
_report_progress(runner, progress_reporter)
@@ -451,7 +472,7 @@ def run(
incomplete_trials += [trial]
if incomplete_trials:
if raise_on_failed_trial:
if raise_on_failed_trial and not state[signal.SIGINT]:
raise TuneError("Trials did not complete", incomplete_trials)
else:
logger.error("Trials did not complete: %s", incomplete_trials)
@@ -461,6 +482,12 @@ def run(
logger.info(f"Total run time: {all_taken:.2f} seconds "
f"({tune_taken:.2f} seconds for the tuning loop).")
if state[signal.SIGINT]:
logger.warning(
"Experiment has been interrupted, but the most recent state was "
"saved. You can continue running this experiment by passing "
"`resume=True` to `tune.run()`")
trials = runner.get_trials()
return ExperimentAnalysis(
runner.checkpoint_file,