mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[tune] Fix SearchAlg finishing early (#3081)
* Fix trial search alg finishing early * Fix lint * fix lint * nit fix
This commit is contained in:
@@ -20,7 +20,8 @@ from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.suggest import grid_search, BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm
|
||||
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
|
||||
SuggestionAlgorithm)
|
||||
from ray.tune.suggest.variant_generator import RecursiveDependencyError
|
||||
|
||||
|
||||
@@ -1385,6 +1386,31 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertTrue(searcher.is_finished())
|
||||
self.assertTrue(runner.is_finished())
|
||||
|
||||
def testSearchAlgFinishes(self):
|
||||
"""SearchAlg changing state in `next_trials` does not crash."""
|
||||
|
||||
class FinishFastAlg(SuggestionAlgorithm):
|
||||
def next_trials(self):
|
||||
self._finished = True
|
||||
return []
|
||||
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
experiment_spec = {
|
||||
"run": "__fake",
|
||||
"num_samples": 3,
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
}
|
||||
}
|
||||
searcher = FinishFastAlg()
|
||||
experiments = [Experiment.from_json("test", experiment_spec)]
|
||||
searcher.add_configurations(experiments)
|
||||
|
||||
runner = TrialRunner(search_alg=searcher)
|
||||
runner.step() # This should not fail
|
||||
self.assertTrue(searcher.is_finished())
|
||||
self.assertTrue(runner.is_finished())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -114,6 +114,10 @@ class TrialRunner(object):
|
||||
self.trial_executor.start_trial(next_trial)
|
||||
elif self.trial_executor.get_running_trials():
|
||||
self._process_events()
|
||||
elif self.is_finished():
|
||||
# We check `is_finished` again here because the experiment
|
||||
# may have finished while getting the next trial.
|
||||
pass
|
||||
else:
|
||||
for trial in self._trials:
|
||||
if trial.status == Trial.PENDING:
|
||||
|
||||
Reference in New Issue
Block a user