mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 02:47:10 +08:00
[tune] Fail Fast (#7528)
* pytest
* init cancel
* testing
* Update python/ray/tune/tests/test_tune_server.py (Co-Authored-By: Richard Liaw <rliaw@berkeley.edu>)
* change-test
* Apply suggestions from code review
* Apply suggestions from code review
* finished
* set_finished
* tune
* fix
Co-authored-by: ijrsvt <ian.rodney@gmail.com>
This commit is contained in:
@@ -8,7 +8,7 @@ import random
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, run, Experiment, sample_from
|
||||
from ray.tune import Trainable, run, sample_from
|
||||
from ray.tune.schedulers import HyperBandScheduler
|
||||
|
||||
|
||||
@@ -58,14 +58,13 @@ if __name__ == "__main__":
|
||||
mode="max",
|
||||
max_t=100)
|
||||
|
||||
exp = Experiment(
|
||||
run(MyTrainableClass,
|
||||
name="hyperband_test",
|
||||
run=MyTrainableClass,
|
||||
num_samples=20,
|
||||
stop={"training_iteration": 1 if args.smoke_test else 99999},
|
||||
config={
|
||||
"width": sample_from(lambda spec: 10 + int(90 * random.random())),
|
||||
"height": sample_from(lambda spec: int(100 * random.random()))
|
||||
})
|
||||
|
||||
run(exp, scheduler=hyperband)
|
||||
},
|
||||
scheduler=hyperband,
|
||||
fail_fast=True)
|
||||
|
||||
@@ -175,6 +175,11 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
action = self._process_bracket(trial_runner, bracket)
|
||||
logger.info("{action} for {trial} on {metric}={metric_val}".format(
|
||||
action=action,
|
||||
trial=trial,
|
||||
metric=self._time_attr,
|
||||
metric_val=result.get(self._time_attr)))
|
||||
return action
|
||||
|
||||
def _process_bracket(self, trial_runner, bracket):
|
||||
@@ -379,7 +384,7 @@ class Bracket:
|
||||
|
||||
delta = self._get_result_time(result) - \
|
||||
self._get_result_time(self._live_trials[trial])
|
||||
assert delta >= 0
|
||||
assert delta >= 0, (result, self._live_trials[trial])
|
||||
self._completed_progress += delta
|
||||
self._live_trials[trial] = result
|
||||
|
||||
|
||||
@@ -191,6 +191,31 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
self.assertEqual(trials[0].status, Trial.ERROR)
|
||||
self.assertEqual(trials[0].num_failures, 3)
|
||||
|
||||
def testFailFast(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner(fail_fast=True)
|
||||
kwargs = {
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
"checkpoint_freq": 1,
|
||||
"max_failures": 0,
|
||||
"config": {
|
||||
"mock_error": True,
|
||||
"persistent_error": True,
|
||||
},
|
||||
}
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
trials = runner.get_trials()
|
||||
|
||||
runner.step() # Start trial
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
runner.step() # Process result, dispatch save
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
runner.step() # Process save
|
||||
runner.step() # Error
|
||||
self.assertEqual(trials[0].status, Trial.ERROR)
|
||||
self.assertRaises(TuneError, lambda: runner.step())
|
||||
|
||||
def testCheckpointing(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
|
||||
@@ -104,6 +104,7 @@ class TrialRunner:
|
||||
resume (str|False): see `tune.py:run`.
|
||||
sync_to_cloud (func|str): See `tune.py:run`.
|
||||
server_port (int): Port number for launching TuneServer.
|
||||
fail_fast (bool): Finishes as soon as a trial fails if True.
|
||||
verbose (bool): Flag for verbosity. If False, trial results
|
||||
will not be output.
|
||||
checkpoint_period (int): Trial runner checkpoint periodicity in
|
||||
@@ -124,6 +125,7 @@ class TrialRunner:
|
||||
stopper=None,
|
||||
resume=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
fail_fast=False,
|
||||
verbose=True,
|
||||
checkpoint_period=10,
|
||||
trial_executor=None):
|
||||
@@ -137,6 +139,8 @@ class TrialRunner:
|
||||
os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float("inf")))
|
||||
self._total_time = 0
|
||||
self._iteration = 0
|
||||
self._has_errored = False
|
||||
self._fail_fast = fail_fast
|
||||
self._verbose = verbose
|
||||
|
||||
self._server = None
|
||||
@@ -392,12 +396,15 @@ class TrialRunner:
|
||||
return self.trial_executor.has_resources(resources)
|
||||
|
||||
def _stop_experiment_if_needed(self):
|
||||
"""Stops all trials if the user condition is satisfied."""
|
||||
|
||||
if self._stopper.stop_all() or self._should_stop_experiment:
|
||||
"""Stops all trials."""
|
||||
fail_fast = self._fail_fast and self._has_errored
|
||||
if (self._stopper.stop_all() or fail_fast
|
||||
or self._should_stop_experiment):
|
||||
self._search_alg.set_finished()
|
||||
[self.trial_executor.stop_trial(t) for t in self._trials]
|
||||
logger.info("All trials stopped due to ``stopper.stop_all``.")
|
||||
[
|
||||
self.trial_executor.stop_trial(t) for t in self._trials
|
||||
if t.status is not Trial.ERROR
|
||||
]
|
||||
|
||||
def _get_next_trial(self):
|
||||
"""Replenishes queue.
|
||||
@@ -571,6 +578,7 @@ class TrialRunner:
|
||||
trial (Trial): Failed trial.
|
||||
error_msg (str): Error message prior to invoking this method.
|
||||
"""
|
||||
self._has_errored = True
|
||||
if trial.status == Trial.RUNNING:
|
||||
if trial.should_recover():
|
||||
self._try_recover(trial, error_msg)
|
||||
|
||||
+11
-5
@@ -84,6 +84,7 @@ def run(run_or_experiment,
|
||||
global_checkpoint_period=10,
|
||||
export_formats=None,
|
||||
max_failures=0,
|
||||
fail_fast=False,
|
||||
restore=None,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
@@ -172,6 +173,7 @@ def run(run_or_experiment,
|
||||
Ray will recover from the latest checkpoint if present.
|
||||
Setting to -1 will lead to infinite recovery retries.
|
||||
Setting to 0 will disable retries. Defaults to 0.
|
||||
fail_fast (bool): Whether to fail upon the first error. Requires max_failures to be 0.
|
||||
restore (str): Path to checkpoint. Only makes sense to set if
|
||||
running 1 trial. Defaults to None.
|
||||
search_alg (SearchAlgorithm): Search Algorithm. Defaults to
|
||||
@@ -270,6 +272,9 @@ def run(run_or_experiment,
|
||||
assert exp.remote_checkpoint_dir, (
|
||||
"Need `upload_dir` if `sync_to_cloud` given.")
|
||||
|
||||
if fail_fast and max_failures != 0:
|
||||
raise ValueError("max_failures must be 0 if fail_fast=True.")
|
||||
|
||||
runner = TrialRunner(
|
||||
search_alg=search_alg or BasicVariantGenerator(),
|
||||
scheduler=scheduler or FIFOScheduler(),
|
||||
@@ -282,6 +287,7 @@ def run(run_or_experiment,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=bool(verbose > 1),
|
||||
fail_fast=fail_fast,
|
||||
trial_executor=trial_executor)
|
||||
|
||||
for exp in experiments:
|
||||
@@ -326,16 +332,16 @@ def run(run_or_experiment,
|
||||
|
||||
wait_for_sync()
|
||||
|
||||
errored_trials = []
|
||||
incomplete_trials = []
|
||||
for trial in runner.get_trials():
|
||||
if trial.status != Trial.TERMINATED:
|
||||
errored_trials += [trial]
|
||||
incomplete_trials += [trial]
|
||||
|
||||
if errored_trials:
|
||||
if incomplete_trials:
|
||||
if raise_on_failed_trial:
|
||||
raise TuneError("Trials did not complete", errored_trials)
|
||||
raise TuneError("Trials did not complete", incomplete_trials)
|
||||
else:
|
||||
logger.error("Trials did not complete: %s", errored_trials)
|
||||
logger.error("Trials did not complete: %s", incomplete_trials)
|
||||
|
||||
trials = runner.get_trials()
|
||||
if return_trials:
|
||||
|
||||
Reference in New Issue
Block a user