mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 14:36:45 +08:00
[tune] Add a raise_on_failed_trial flag in run_experiments (#2961)
Adds a flag to control raising TuneError if some trial fails in `run_experiments`.
This commit is contained in:
@@ -366,6 +366,26 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||
|
||||
def testNoRaiseFlag(self):
|
||||
def train(config, reporter):
|
||||
# Finish this trial without any metric,
|
||||
# which leads to a failed trial
|
||||
return
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
[trial] = run_experiments(
|
||||
{
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}
|
||||
},
|
||||
raise_on_failed_trial=False)
|
||||
self.assertEqual(trial.status, Trial.ERROR)
|
||||
|
||||
def testReportInfinity(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
|
||||
+10
-3
@@ -39,7 +39,8 @@ def run_experiments(experiments=None,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True,
|
||||
queue_trials=False,
|
||||
trial_executor=None):
|
||||
trial_executor=None,
|
||||
raise_on_failed_trial=True):
|
||||
"""Runs and blocks until all trials finish.
|
||||
|
||||
Args:
|
||||
@@ -59,6 +60,8 @@ def run_experiments(experiments=None,
|
||||
be set to True when running on an autoscaling cluster to enable
|
||||
automatic scale-up.
|
||||
trial_executor (TrialExecutor): Manage the execution of trials.
|
||||
raise_on_failed_trial (bool): Raise TuneError if there exists failed
|
||||
trial (of ERROR state) when the experiments complete.
|
||||
|
||||
Examples:
|
||||
>>> experiment_spec = Experiment("experiment", my_func)
|
||||
@@ -109,13 +112,17 @@ def run_experiments(experiments=None,
|
||||
|
||||
logger.info(runner.debug_string(max_debug=99999))
|
||||
|
||||
wait_for_log_sync()
|
||||
|
||||
errored_trials = []
|
||||
for trial in runner.get_trials():
|
||||
if trial.status != Trial.TERMINATED:
|
||||
errored_trials += [trial]
|
||||
|
||||
if errored_trials:
|
||||
raise TuneError("Trials did not complete", errored_trials)
|
||||
if raise_on_failed_trial:
|
||||
raise TuneError("Trials did not complete", errored_trials)
|
||||
else:
|
||||
logger.error("Trials did not complete: %s", errored_trials)
|
||||
|
||||
wait_for_log_sync()
|
||||
return runner.get_trials()
|
||||
|
||||
Reference in New Issue
Block a user