[tune] Add a raise_on_failed_trial flag in run_experiments (#2961)

Adds a flag to control raising TuneError if some trial fails in `run_experiments`.
This commit is contained in:
old-bear
2018-09-30 02:29:46 +08:00
committed by Richard Liaw
parent a879302355
commit b3f0dcf20b
2 changed files with 30 additions and 3 deletions
+20
View File
@@ -366,6 +366,26 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
def testNoRaiseFlag(self):
def train(config, reporter):
# Finish this trial without any metric,
# which leads to a failed trial
return
register_trainable("f1", train)
[trial] = run_experiments(
{
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
},
raise_on_failed_trial=False)
self.assertEqual(trial.status, Trial.ERROR)
def testReportInfinity(self):
def train(config, reporter):
for i in range(100):
+10 -3
View File
@@ -39,7 +39,8 @@ def run_experiments(experiments=None,
server_port=TuneServer.DEFAULT_PORT,
verbose=True,
queue_trials=False,
trial_executor=None):
trial_executor=None,
raise_on_failed_trial=True):
"""Runs and blocks until all trials finish.
Args:
@@ -59,6 +60,8 @@ def run_experiments(experiments=None,
be set to True when running on an autoscaling cluster to enable
automatic scale-up.
trial_executor (TrialExecutor): Manage the execution of trials.
raise_on_failed_trial (bool): Raise TuneError if there exists failed
trial (of ERROR state) when the experiments complete.
Examples:
>>> experiment_spec = Experiment("experiment", my_func)
@@ -109,13 +112,17 @@ def run_experiments(experiments=None,
logger.info(runner.debug_string(max_debug=99999))
wait_for_log_sync()
errored_trials = []
for trial in runner.get_trials():
if trial.status != Trial.TERMINATED:
errored_trials += [trial]
if errored_trials:
raise TuneError("Trials did not complete", errored_trials)
if raise_on_failed_trial:
raise TuneError("Trials did not complete", errored_trials)
else:
logger.error("Trials did not complete: %s", errored_trials)
wait_for_log_sync()
return runner.get_trials()