mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:38:19 +08:00
[tune] catch SIGINT signal and trigger experiment checkpoint (#13767)
* [tune] catch SIGINT signal and trigger experiment checkpoint
* Apply suggestions from code review
* Fix user guide docs
* Update doc/source/tune/user-guide.rst
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
# coding: utf-8
|
||||
import signal
|
||||
from collections import Counter
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
import skopt
|
||||
import numpy as np
|
||||
@@ -87,6 +89,66 @@ class TuneRestoreTest(unittest.TestCase):
|
||||
self.assertTrue(os.path.isfile(self.checkpoint_path))
|
||||
|
||||
|
||||
class TuneInterruptionTest(unittest.TestCase):
    """Checks that interrupting ``tune.run()`` with SIGINT rewrites the
    experiment state file."""

    def testExperimentInterrupted(self):
        """Run Tune in a child process, send it SIGINT, and verify that the
        experiment state file's mtime changes afterwards."""
        import multiprocessing

        # Two semaphores hand control back and forth between this (driver)
        # process and the child process that runs `tune.run()`.
        # NOTE(review): `multiprocessing.Semaphore()` starts with a value of
        # 1, so the first `acquire()` on each side does not block - confirm
        # that this head start is intended.
        trainer_semaphore = multiprocessing.Semaphore()
        driver_semaphore = multiprocessing.Semaphore()

        class SteppingCallback(Callback):
            """Pauses the Tune loop after each step until the driver
            releases it."""

            def on_step_end(self, iteration, trials, **info):
                driver_semaphore.release()  # Driver should continue
                trainer_semaphore.acquire()  # Wait until released

        def _run(local_dir):
            def _train(config):
                for i in range(7):
                    tune.report(val=i)

            tune.run(
                _train,
                local_dir=local_dir,
                name="interrupt",
                callbacks=[SteppingCallback()])

        local_dir = tempfile.mkdtemp()
        # Registered first so it runs last (cleanups are LIFO): the directory
        # is removed even when an assertion below fails.
        self.addCleanup(shutil.rmtree, local_dir, ignore_errors=True)

        process = multiprocessing.Process(target=_run, args=(local_dir, ))
        process.daemon = False
        process.start()

        def _stop_child():
            # Give the child a chance to finish on its own; if it is still
            # alive (e.g. blocked on the semaphore after a failed assertion),
            # terminate it so the non-daemon process cannot hang the test
            # run at interpreter exit.
            process.join(timeout=10)
            if process.is_alive():
                process.terminate()
                process.join()

        self.addCleanup(_stop_child)

        exp_dir = os.path.join(local_dir, "interrupt")

        # Skip first five steps
        for _ in range(5):
            driver_semaphore.acquire()  # Wait for callback
            trainer_semaphore.release()  # Continue training

        driver_semaphore.acquire()

        experiment_state_file = None
        for fname in os.listdir(exp_dir):
            if fname.startswith("experiment_state"):
                experiment_state_file = os.path.join(exp_dir, fname)
                break

        self.assertIsNotNone(
            experiment_state_file,
            "No experiment_state file found in experiment directory")
        last_mtime = os.path.getmtime(experiment_state_file)

        # Now send kill signal
        os.kill(process.pid, signal.SIGINT)
        # Release trainer. It should handle the signal and try to
        # checkpoint the experiment
        trainer_semaphore.release()

        time.sleep(2)  # Wait for checkpoint
        new_mtime = os.path.getmtime(experiment_state_file)

        self.assertNotEqual(last_mtime, new_mtime)
|
||||
|
||||
|
||||
class TuneFailResumeGridTest(unittest.TestCase):
|
||||
class FailureInjectorCallback(Callback):
|
||||
"""Adds random failure injection to the TrialExecutor."""
|
||||
|
||||
+30
-3
@@ -3,6 +3,8 @@ from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, \
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
@@ -112,6 +114,10 @@ def run(
|
||||
) -> ExperimentAnalysis:
|
||||
"""Executes training.
|
||||
|
||||
When a SIGINT signal is received (e.g. through Ctrl+C), the tuning run
|
||||
will gracefully shut down and checkpoint the latest experiment state.
|
||||
Sending SIGINT again (or SIGKILL/SIGTERM instead) will skip this step.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
@@ -265,7 +271,6 @@ def run(
|
||||
`LoggerCallback` and `SyncerCallback` callbacks are automatically
|
||||
added.
|
||||
|
||||
|
||||
Returns:
|
||||
ExperimentAnalysis: Object for experiment analysis.
|
||||
|
||||
@@ -427,8 +432,24 @@ def run(
|
||||
"`Trainable.default_resource_request` if using the "
|
||||
"Trainable API.")
|
||||
|
||||
original_handler = signal.getsignal(signal.SIGINT)
|
||||
state = {signal.SIGINT: False}
|
||||
|
||||
def sigint_handler(sig, frame):
|
||||
logger.warning(
|
||||
"SIGINT received (e.g. via Ctrl+C), ending Ray Tune run. "
|
||||
"This will try to checkpoint the experiment state one last time. "
|
||||
"Press CTRL+C one more time (or send SIGINT/SIGKILL/SIGTERM) "
|
||||
"to skip. ")
|
||||
state[signal.SIGINT] = True
|
||||
# Restore original signal handler to react to future SIGINT signals
|
||||
signal.signal(signal.SIGINT, original_handler)
|
||||
|
||||
if not int(os.getenv("TUNE_DISABLE_SIGINT_HANDLER", "0")):
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
tune_start = time.time()
|
||||
while not runner.is_finished():
|
||||
while not runner.is_finished() and not state[signal.SIGINT]:
|
||||
runner.step()
|
||||
if has_verbosity(Verbosity.V1_EXPERIMENT):
|
||||
_report_progress(runner, progress_reporter)
|
||||
@@ -451,7 +472,7 @@ def run(
|
||||
incomplete_trials += [trial]
|
||||
|
||||
if incomplete_trials:
|
||||
if raise_on_failed_trial:
|
||||
if raise_on_failed_trial and not state[signal.SIGINT]:
|
||||
raise TuneError("Trials did not complete", incomplete_trials)
|
||||
else:
|
||||
logger.error("Trials did not complete: %s", incomplete_trials)
|
||||
@@ -461,6 +482,12 @@ def run(
|
||||
logger.info(f"Total run time: {all_taken:.2f} seconds "
|
||||
f"({tune_taken:.2f} seconds for the tuning loop).")
|
||||
|
||||
if state[signal.SIGINT]:
|
||||
logger.warning(
|
||||
"Experiment has been interrupted, but the most recent state was "
|
||||
"saved. You can continue running this experiment by passing "
|
||||
"`resume=True` to `tune.run()`")
|
||||
|
||||
trials = runner.get_trials()
|
||||
return ExperimentAnalysis(
|
||||
runner.checkpoint_file,
|
||||
|
||||
Reference in New Issue
Block a user