[tune] Fail Fast (#7528)

* pytest

* init cancel

* testing

* Update python/ray/tune/tests/test_tune_server.py

Co-Authored-By: Richard Liaw <rliaw@berkeley.edu>

* change-test

* Apply suggestions from code review

* Apply suggestions from code review

* finished

* set_finished

* tune

* fix

Co-authored-by: ijrsvt <ian.rodney@gmail.com>
This commit is contained in:
Richard Liaw
2020-03-26 00:04:09 -07:00
committed by GitHub
parent 3d0a8662b3
commit ca6eabc9cb
5 changed files with 60 additions and 17 deletions
@@ -8,7 +8,7 @@ import random
import numpy as np
import ray
from ray.tune import Trainable, run, Experiment, sample_from
from ray.tune import Trainable, run, sample_from
from ray.tune.schedulers import HyperBandScheduler
@@ -58,14 +58,13 @@ if __name__ == "__main__":
mode="max",
max_t=100)
exp = Experiment(
run(MyTrainableClass,
name="hyperband_test",
run=MyTrainableClass,
num_samples=20,
stop={"training_iteration": 1 if args.smoke_test else 99999},
config={
"width": sample_from(lambda spec: 10 + int(90 * random.random())),
"height": sample_from(lambda spec: int(100 * random.random()))
})
run(exp, scheduler=hyperband)
},
scheduler=hyperband,
fail_fast=True)
+6 -1
View File
@@ -175,6 +175,11 @@ class HyperBandScheduler(FIFOScheduler):
return TrialScheduler.CONTINUE
action = self._process_bracket(trial_runner, bracket)
logger.info("{action} for {trial} on {metric}={metric_val}".format(
action=action,
trial=trial,
metric=self._time_attr,
metric_val=result.get(self._time_attr)))
return action
def _process_bracket(self, trial_runner, bracket):
@@ -379,7 +384,7 @@ class Bracket:
delta = self._get_result_time(result) - \
self._get_result_time(self._live_trials[trial])
assert delta >= 0
assert delta >= 0, (result, self._live_trials[trial])
self._completed_progress += delta
self._live_trials[trial] = result
@@ -191,6 +191,31 @@ class TrialRunnerTest2(unittest.TestCase):
self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[0].num_failures, 3)
def testFailFast(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner(fail_fast=True)
kwargs = {
"resources": Resources(cpu=1, gpu=1),
"checkpoint_freq": 1,
"max_failures": 0,
"config": {
"mock_error": True,
"persistent_error": True,
},
}
runner.add_trial(Trial("__fake", **kwargs))
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process save
runner.step() # Error
self.assertEqual(trials[0].status, Trial.ERROR)
self.assertRaises(TuneError, lambda: runner.step())
def testCheckpointing(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
+13 -5
View File
@@ -104,6 +104,7 @@ class TrialRunner:
resume (str|False): see `tune.py:run`.
sync_to_cloud (func|str): See `tune.py:run`.
server_port (int): Port number for launching TuneServer.
fail_fast (bool): Finishes as soon as a trial fails if True.
verbose (bool): Flag for verbosity. If False, trial results
will not be output.
checkpoint_period (int): Trial runner checkpoint periodicity in
@@ -124,6 +125,7 @@ class TrialRunner:
stopper=None,
resume=False,
server_port=TuneServer.DEFAULT_PORT,
fail_fast=False,
verbose=True,
checkpoint_period=10,
trial_executor=None):
@@ -137,6 +139,8 @@ class TrialRunner:
os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float("inf")))
self._total_time = 0
self._iteration = 0
self._has_errored = False
self._fail_fast = fail_fast
self._verbose = verbose
self._server = None
@@ -392,12 +396,15 @@ class TrialRunner:
return self.trial_executor.has_resources(resources)
def _stop_experiment_if_needed(self):
"""Stops all trials if the user condition is satisfied."""
if self._stopper.stop_all() or self._should_stop_experiment:
"""Stops all trials."""
fail_fast = self._fail_fast and self._has_errored
if (self._stopper.stop_all() or fail_fast
or self._should_stop_experiment):
self._search_alg.set_finished()
[self.trial_executor.stop_trial(t) for t in self._trials]
logger.info("All trials stopped due to ``stopper.stop_all``.")
[
self.trial_executor.stop_trial(t) for t in self._trials
if t.status is not Trial.ERROR
]
def _get_next_trial(self):
"""Replenishes queue.
@@ -571,6 +578,7 @@ class TrialRunner:
trial (Trial): Failed trial.
error_msg (str): Error message prior to invoking this method.
"""
self._has_errored = True
if trial.status == Trial.RUNNING:
if trial.should_recover():
self._try_recover(trial, error_msg)
+11 -5
View File
@@ -84,6 +84,7 @@ def run(run_or_experiment,
global_checkpoint_period=10,
export_formats=None,
max_failures=0,
fail_fast=False,
restore=None,
search_alg=None,
scheduler=None,
@@ -172,6 +173,7 @@ def run(run_or_experiment,
Ray will recover from the latest checkpoint if present.
Setting to -1 will lead to infinite recovery retries.
Setting to 0 will disable retries. Defaults to 3.
fail_fast (bool): Whether to fail upon the first error.
restore (str): Path to checkpoint. Only makes sense to set if
running 1 trial. Defaults to None.
search_alg (SearchAlgorithm): Search Algorithm. Defaults to
@@ -270,6 +272,9 @@ def run(run_or_experiment,
assert exp.remote_checkpoint_dir, (
"Need `upload_dir` if `sync_to_cloud` given.")
if fail_fast and max_failures != 0:
raise ValueError("max_failures must be 0 if fail_fast=True.")
runner = TrialRunner(
search_alg=search_alg or BasicVariantGenerator(),
scheduler=scheduler or FIFOScheduler(),
@@ -282,6 +287,7 @@ def run(run_or_experiment,
launch_web_server=with_server,
server_port=server_port,
verbose=bool(verbose > 1),
fail_fast=fail_fast,
trial_executor=trial_executor)
for exp in experiments:
@@ -326,16 +332,16 @@ def run(run_or_experiment,
wait_for_sync()
errored_trials = []
incomplete_trials = []
for trial in runner.get_trials():
if trial.status != Trial.TERMINATED:
errored_trials += [trial]
incomplete_trials += [trial]
if errored_trials:
if incomplete_trials:
if raise_on_failed_trial:
raise TuneError("Trials did not complete", errored_trials)
raise TuneError("Trials did not complete", incomplete_trials)
else:
logger.error("Trials did not complete: %s", errored_trials)
logger.error("Trials did not complete: %s", incomplete_trials)
trials = runner.get_trials()
if return_trials: