diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py
index f3926e71f..1e4c0509d 100644
--- a/python/ray/tune/test/trial_runner_test.py
+++ b/python/ray/tune/test/trial_runner_test.py
@@ -366,6 +366,26 @@ class TrainableFunctionApiTest(unittest.TestCase):
         self.assertEqual(trial.status, Trial.TERMINATED)
         self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
 
+    def testNoRaiseFlag(self):
+        def train(config, reporter):
+            # Finish this trial without any metric,
+            # which leads to a failed trial
+            return
+
+        register_trainable("f1", train)
+
+        [trial] = run_experiments(
+            {
+                "foo": {
+                    "run": "f1",
+                    "config": {
+                        "script_min_iter_time_s": 0,
+                    },
+                }
+            },
+            raise_on_failed_trial=False)
+        self.assertEqual(trial.status, Trial.ERROR)
+
     def testReportInfinity(self):
         def train(config, reporter):
             for i in range(100):
diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py
index e84304624..335660ecb 100644
--- a/python/ray/tune/tune.py
+++ b/python/ray/tune/tune.py
@@ -39,7 +39,8 @@ def run_experiments(experiments=None,
                     server_port=TuneServer.DEFAULT_PORT,
                     verbose=True,
                     queue_trials=False,
-                    trial_executor=None):
+                    trial_executor=None,
+                    raise_on_failed_trial=True):
     """Runs and blocks until all trials finish.
 
     Args:
@@ -59,6 +60,8 @@ def run_experiments(experiments=None,
             be set to True when running on an autoscaling cluster to enable
             automatic scale-up.
         trial_executor (TrialExecutor): Manage the execution of trials.
+        raise_on_failed_trial (bool): Raise TuneError if any trial is not
+            TERMINATED (e.g. ERROR) on completion; if False, only log it.
 
     Examples:
         >>> experiment_spec = Experiment("experiment", my_func)
@@ -109,13 +112,17 @@ def run_experiments(experiments=None,
 
     logger.info(runner.debug_string(max_debug=99999))
 
+    wait_for_log_sync()
+
     errored_trials = []
     for trial in runner.get_trials():
         if trial.status != Trial.TERMINATED:
             errored_trials += [trial]
 
     if errored_trials:
-        raise TuneError("Trials did not complete", errored_trials)
+        if raise_on_failed_trial:
+            raise TuneError("Trials did not complete", errored_trials)
+        else:
+            logger.error("Trials did not complete: %s", errored_trials)
 
-    wait_for_log_sync()
     return runner.get_trials()