diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index 13deb6e7b..05cf507d2 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -89,6 +89,7 @@ class TrialRunnerCallbacks(unittest.TestCase): for t in trials: self.trial_runner.add_trial(t) + self.executor.next_trial = trials[0] self.trial_runner.step() # Trial 1 has been started @@ -103,6 +104,7 @@ class TrialRunnerCallbacks(unittest.TestCase): "trial_complete", "trial_fail" ])) + self.executor.next_trial = trials[1] self.trial_runner.step() # Iteration not increased yet @@ -120,6 +122,7 @@ class TrialRunnerCallbacks(unittest.TestCase): {TRAINING_ITERATION: 0}) # Let the first trial save a checkpoint + self.executor.next_trial = trials[0] trials[0].saving_to = cp self.trial_runner.step() self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)