diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index b52241f03..0748b1cac 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -223,7 +223,7 @@ For TensorFlow model training, this would look something like this `(full tensor .. code-block:: python class MyClass(Trainable): - def _setup(self): + def _setup(self, config): self.saver = tf.train.Saver() self.sess = ... self.iteration = 0 diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 65b8fbe36..450e96136 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -20,7 +20,8 @@ from ray.tune.experiment import Experiment from ray.tune.trial import Trial, Resources from ray.tune.trial_runner import TrialRunner from ray.tune.suggest import grid_search, BasicVariantGenerator -from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm +from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm, + SuggestionAlgorithm) from ray.tune.suggest.variant_generator import RecursiveDependencyError @@ -1385,6 +1386,31 @@ class TrialRunnerTest(unittest.TestCase): self.assertTrue(searcher.is_finished()) self.assertTrue(runner.is_finished()) + def testSearchAlgFinishes(self): + """SearchAlg changing state in `next_trials` does not crash.""" + + class FinishFastAlg(SuggestionAlgorithm): + def next_trials(self): + self._finished = True + return [] + + ray.init(num_cpus=4, num_gpus=2) + experiment_spec = { + "run": "__fake", + "num_samples": 3, + "stop": { + "training_iteration": 1 + } + } + searcher = FinishFastAlg() + experiments = [Experiment.from_json("test", experiment_spec)] + searcher.add_configurations(experiments) + + runner = TrialRunner(search_alg=searcher) + runner.step() # This should not fail + self.assertTrue(searcher.is_finished()) + self.assertTrue(runner.is_finished()) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 6423d6a95..8e3eb8612 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -114,6 +114,10 @@ class TrialRunner(object): self.trial_executor.start_trial(next_trial) elif self.trial_executor.get_running_trials(): self._process_events() + elif self.is_finished(): + # We check `is_finished` again here because the experiment + # may have finished while getting the next trial. + pass else: for trial in self._trials: if trial.status == Trial.PENDING: