From ca6eabc9cb728a829d5305a4dd19216ed925f2e4 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 26 Mar 2020 00:04:09 -0700 Subject: [PATCH] [tune] Fail Fast (#7528) * pytest * init cancel * testing * Update python/ray/tune/tests/test_tune_server.py Co-Authored-By: Richard Liaw * change-test * Apply suggestions from code review * Apply suggestions from code review * finished * set_finished * tune * fix Co-authored-by: ijrsvt --- python/ray/tune/examples/hyperband_example.py | 11 ++++---- python/ray/tune/schedulers/hyperband.py | 7 +++++- python/ray/tune/tests/test_trial_runner_2.py | 25 +++++++++++++++++++ python/ray/tune/trial_runner.py | 18 +++++++++---- python/ray/tune/tune.py | 16 ++++++++---- 5 files changed, 60 insertions(+), 17 deletions(-) diff --git a/python/ray/tune/examples/hyperband_example.py b/python/ray/tune/examples/hyperband_example.py index e1ed629e7..1da764a54 100755 --- a/python/ray/tune/examples/hyperband_example.py +++ b/python/ray/tune/examples/hyperband_example.py @@ -8,7 +8,7 @@ import random import numpy as np import ray -from ray.tune import Trainable, run, Experiment, sample_from +from ray.tune import Trainable, run, sample_from from ray.tune.schedulers import HyperBandScheduler @@ -58,14 +58,13 @@ if __name__ == "__main__": mode="max", max_t=100) - exp = Experiment( + run(MyTrainableClass, name="hyperband_test", - run=MyTrainableClass, num_samples=20, stop={"training_iteration": 1 if args.smoke_test else 99999}, config={ "width": sample_from(lambda spec: 10 + int(90 * random.random())), "height": sample_from(lambda spec: int(100 * random.random())) - }) - - run(exp, scheduler=hyperband) + }, + scheduler=hyperband, + fail_fast=True) diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index c55c1c5ba..654cde12d 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -175,6 +175,11 @@ class HyperBandScheduler(FIFOScheduler): return TrialScheduler.CONTINUE action = self._process_bracket(trial_runner, bracket) + logger.info("{action} for {trial} on {metric}={metric_val}".format( + action=action, + trial=trial, + metric=self._time_attr, + metric_val=result.get(self._time_attr))) return action def _process_bracket(self, trial_runner, bracket): @@ -379,7 +384,7 @@ class Bracket: delta = self._get_result_time(result) - \ self._get_result_time(self._live_trials[trial]) - assert delta >= 0 + assert delta >= 0, (result, self._live_trials[trial]) self._completed_progress += delta self._live_trials[trial] = result diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index b72830385..0cf5d434f 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -191,6 +191,31 @@ class TrialRunnerTest2(unittest.TestCase): self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[0].num_failures, 3) + def testFailFast(self): + ray.init(num_cpus=1, num_gpus=1) + runner = TrialRunner(fail_fast=True) + kwargs = { + "resources": Resources(cpu=1, gpu=1), + "checkpoint_freq": 1, + "max_failures": 0, + "config": { + "mock_error": True, + "persistent_error": True, + }, + } + runner.add_trial(Trial("__fake", **kwargs)) + runner.add_trial(Trial("__fake", **kwargs)) + trials = runner.get_trials() + + runner.step() # Start trial + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() # Process result, dispatch save + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() # Process save + runner.step() # Error + self.assertEqual(trials[0].status, Trial.ERROR) + self.assertRaises(TuneError, lambda: runner.step()) + def testCheckpointing(self): ray.init(num_cpus=1, num_gpus=1) runner = TrialRunner() diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 1bae88d2e..7a1c950a5 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -104,6 +104,7 @@ class TrialRunner: resume (str|False): see `tune.py:run`. sync_to_cloud (func|str): See `tune.py:run`. server_port (int): Port number for launching TuneServer. + fail_fast (bool): Finishes as soon as a trial fails if True. verbose (bool): Flag for verbosity. If False, trial results will not be output. checkpoint_period (int): Trial runner checkpoint periodicity in @@ -124,6 +125,7 @@ class TrialRunner: stopper=None, resume=False, server_port=TuneServer.DEFAULT_PORT, + fail_fast=False, verbose=True, checkpoint_period=10, trial_executor=None): @@ -137,6 +139,8 @@ class TrialRunner: os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float("inf"))) self._total_time = 0 self._iteration = 0 + self._has_errored = False + self._fail_fast = fail_fast self._verbose = verbose self._server = None @@ -392,12 +396,15 @@ class TrialRunner: return self.trial_executor.has_resources(resources) def _stop_experiment_if_needed(self): - """Stops all trials if the user condition is satisfied.""" - - if self._stopper.stop_all() or self._should_stop_experiment: + """Stops all trials.""" + fail_fast = self._fail_fast and self._has_errored + if (self._stopper.stop_all() or fail_fast + or self._should_stop_experiment): self._search_alg.set_finished() - [self.trial_executor.stop_trial(t) for t in self._trials] - logger.info("All trials stopped due to ``stopper.stop_all``.") + [ + self.trial_executor.stop_trial(t) for t in self._trials + if t.status is not Trial.ERROR + ] def _get_next_trial(self): """Replenishes queue. @@ -571,6 +578,7 @@ class TrialRunner: trial (Trial): Failed trial. error_msg (str): Error message prior to invoking this method. """ + self._has_errored = True if trial.status == Trial.RUNNING: if trial.should_recover(): self._try_recover(trial, error_msg) diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index a96be49ad..4cf1f9d03 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -84,6 +84,7 @@ def run(run_or_experiment, global_checkpoint_period=10, export_formats=None, max_failures=0, + fail_fast=False, restore=None, search_alg=None, scheduler=None, @@ -172,6 +173,7 @@ def run(run_or_experiment, Ray will recover from the latest checkpoint if present. Setting to -1 will lead to infinite recovery retries. Setting to 0 will disable retries. Defaults to 3. + fail_fast (bool): Whether to fail upon the first error. restore (str): Path to checkpoint. Only makes sense to set if running 1 trial. Defaults to None. search_alg (SearchAlgorithm): Search Algorithm. Defaults to @@ -270,6 +272,9 @@ def run(run_or_experiment, assert exp.remote_checkpoint_dir, ( "Need `upload_dir` if `sync_to_cloud` given.") + if fail_fast and max_failures != 0: + raise ValueError("max_failures must be 0 if fail_fast=True.") + runner = TrialRunner( search_alg=search_alg or BasicVariantGenerator(), scheduler=scheduler or FIFOScheduler(), @@ -282,6 +287,7 @@ def run(run_or_experiment, launch_web_server=with_server, server_port=server_port, verbose=bool(verbose > 1), + fail_fast=fail_fast, trial_executor=trial_executor) for exp in experiments: @@ -326,16 +332,16 @@ def run(run_or_experiment, wait_for_sync() - errored_trials = [] + incomplete_trials = [] for trial in runner.get_trials(): if trial.status != Trial.TERMINATED: - errored_trials += [trial] + incomplete_trials += [trial] - if errored_trials: + if incomplete_trials: if raise_on_failed_trial: - raise TuneError("Trials did not complete", errored_trials) + raise TuneError("Trials did not complete", incomplete_trials) else: - logger.error("Trials did not complete: %s", errored_trials) + logger.error("Trials did not complete: %s", incomplete_trials) trials = runner.get_trials() if return_trials: