diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index f846467f0..79c9b9d28 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -970,6 +970,30 @@ class TrialRunnerTest(unittest.TestCase): self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(trials[1].status, Trial.RUNNING) + def testMultiStepRun2(self): + """Checks that runner.step throws when overstepping.""" + ray.init(num_cpus=1) + runner = TrialRunner(BasicVariantGenerator()) + kwargs = { + "stopping_criterion": { + "training_iteration": 2 + }, + "resources": Resources(cpu=1, gpu=0), + } + trials = [Trial("__fake", **kwargs)] + for t in trials: + runner.add_trial(t) + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + + runner.step() + self.assertEqual(trials[0].status, Trial.TERMINATED) + self.assertRaises(TuneError, runner.step) + def testErrorHandling(self): ray.init(num_cpus=4, num_gpus=2) runner = TrialRunner(BasicVariantGenerator()) @@ -992,6 +1016,12 @@ class TrialRunnerTest(unittest.TestCase): self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[1].status, Trial.RUNNING) + def testThrowOnOverstep(self): + ray.init(num_cpus=1, num_gpus=1) + runner = TrialRunner(BasicVariantGenerator()) + runner.step() + self.assertRaises(TuneError, runner.step) + def testFailureRecoveryDisabled(self): ray.init(num_cpus=1, num_gpus=1) runner = TrialRunner(BasicVariantGenerator()) @@ -1390,17 +1420,30 @@ class TrialRunnerTest(unittest.TestCase): self.assertTrue(runner.is_finished()) def testSearchAlgFinishes(self): - """SearchAlg changing state in `next_trials` does not crash.""" + """Empty SearchAlg changing state in `next_trials` does not crash.""" class FinishFastAlg(SuggestionAlgorithm): - def next_trials(self): - self._finished = True - return [] + _index = 0 - ray.init(num_cpus=4, num_gpus=2) + def next_trials(self): + trials = [] + self._index += 1 + + for trial in self._trial_generator: + trials += [trial] + break + + if self._index > 4: + self._finished = True + return trials + + def _suggest(self, trial_id): + return {} + + ray.init(num_cpus=2) experiment_spec = { "run": "__fake", - "num_samples": 3, + "num_samples": 2, "stop": { "training_iteration": 1 } @@ -1410,9 +1453,20 @@ class TrialRunnerTest(unittest.TestCase): searcher.add_configurations(experiments) runner = TrialRunner(search_alg=searcher) - runner.step() # This should not fail + self.assertFalse(runner.is_finished()) + runner.step() # This launches a new run + runner.step() # This launches a 2nd run + self.assertFalse(searcher.is_finished()) + self.assertFalse(runner.is_finished()) + runner.step() # This kills the first run + self.assertFalse(searcher.is_finished()) + self.assertFalse(runner.is_finished()) + runner.step() # This kills the 2nd run + self.assertFalse(searcher.is_finished()) + self.assertFalse(runner.is_finished()) + runner.step() # this converts self._finished to True self.assertTrue(searcher.is_finished()) - self.assertTrue(runner.is_finished()) + self.assertRaises(TuneError, runner.step) if __name__ == "__main__": diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 8e3eb8612..bb31345c8 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -108,16 +108,14 @@ class TrialRunner(object): Callers should typically run this method repeatedly in a loop. They may inspect or modify the runner's state in between calls to step(). """ + if self.is_finished(): + raise TuneError("Called step when all trials finished?") self.trial_executor.on_step_begin() next_trial = self._get_next_trial() if next_trial is not None: self.trial_executor.start_trial(next_trial) elif self.trial_executor.get_running_trials(): self._process_events() - elif self.is_finished(): - # We check `is_finished` again here because the experiment - # may have finished while getting the next trial. - pass else: for trial in self._trials: if trial.status == Trial.PENDING: @@ -137,7 +135,6 @@ class TrialRunner(object): raise TuneError( "There are paused trials, but no more pending " "trials with sufficient resources.") - raise TuneError("Called step when all trials finished?") if self._server: self._process_requests() @@ -306,13 +303,15 @@ class TrialRunner(object): Args: blocking (bool): Blocks until either a trial is available - or the Runner finishes (i.e., timeout or search algorithm - finishes). + or is_finished (timeout or search algorithm finishes). timeout (int): Seconds before blocking times out. """ trials = self._search_alg.next_trials() if blocking and not trials: start = time.time() + # Checking `is_finished` instead of _search_alg.is_finished + # is fine because blocking only occurs if all trials are + # finished and search_algorithm is not yet finished while (not trials and not self.is_finished() and time.time() - start < timeout): logger.info("Blocking for next trial...")