[tune] Raise Error when overstepping (#3235)

This commit is contained in:
Richard Liaw
2018-11-07 14:27:09 -08:00
committed by GitHub
parent 29e3362905
commit cf9e838326
2 changed files with 68 additions and 15 deletions
+62 -8
View File
@@ -970,6 +970,30 @@ class TrialRunnerTest(unittest.TestCase):
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.RUNNING)
def testMultiStepRun2(self):
    """Checks that runner.step throws when overstepping.

    A single "__fake" trial that stops at ``training_iteration == 2`` is
    added to the runner. The trial is expected to report RUNNING after
    the first and second calls to ``step()``, TERMINATED after the third,
    and a fourth call is expected to raise ``TuneError``.
    """
    ray.init(num_cpus=1)
    runner = TrialRunner(BasicVariantGenerator())

    trial_spec = dict(
        stopping_criterion={"training_iteration": 2},
        resources=Resources(cpu=1, gpu=0),
    )
    trial = Trial("__fake", **trial_spec)
    runner.add_trial(trial)

    # One step() per expected status, checked immediately after the step.
    expected_statuses = (Trial.RUNNING, Trial.RUNNING, Trial.TERMINATED)
    for expected in expected_statuses:
        runner.step()
        self.assertEqual(trial.status, expected)

    # Nothing is left to run, so one more step must be rejected.
    with self.assertRaises(TuneError):
        runner.step()
def testErrorHandling(self):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner(BasicVariantGenerator())
@@ -992,6 +1016,12 @@ class TrialRunnerTest(unittest.TestCase):
self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[1].status, Trial.RUNNING)
def testThrowOnOverstep(self):
    """Checks that stepping a runner with nothing left to do raises.

    The runner is built on a ``BasicVariantGenerator`` to which this test
    adds no experiments. The first ``step()`` is expected to return
    normally; the second is expected to raise ``TuneError``.
    """
    ray.init(num_cpus=1, num_gpus=1)
    runner = TrialRunner(BasicVariantGenerator())

    # First call must not raise.
    runner.step()

    # Second call oversteps and must be rejected.
    with self.assertRaises(TuneError):
        runner.step()
def testFailureRecoveryDisabled(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner(BasicVariantGenerator())
@@ -1390,17 +1420,30 @@ class TrialRunnerTest(unittest.TestCase):
self.assertTrue(runner.is_finished())
def testSearchAlgFinishes(self):
"""SearchAlg changing state in `next_trials` does not crash."""
"""Empty SearchAlg changing state in `next_trials` does not crash."""
class FinishFastAlg(SuggestionAlgorithm):
def next_trials(self):
self._finished = True
return []
_index = 0
ray.init(num_cpus=4, num_gpus=2)
def next_trials(self):
trials = []
self._index += 1
for trial in self._trial_generator:
trials += [trial]
break
if self._index > 4:
self._finished = True
return trials
def _suggest(self, trial_id):
return {}
ray.init(num_cpus=2)
experiment_spec = {
"run": "__fake",
"num_samples": 3,
"num_samples": 2,
"stop": {
"training_iteration": 1
}
@@ -1410,9 +1453,20 @@ class TrialRunnerTest(unittest.TestCase):
searcher.add_configurations(experiments)
runner = TrialRunner(search_alg=searcher)
runner.step() # This should not fail
self.assertFalse(runner.is_finished())
runner.step() # This launches a new run
runner.step() # This launches a 2nd run
self.assertFalse(searcher.is_finished())
self.assertFalse(runner.is_finished())
runner.step() # This kills the first run
self.assertFalse(searcher.is_finished())
self.assertFalse(runner.is_finished())
runner.step() # This kills the 2nd run
self.assertFalse(searcher.is_finished())
self.assertFalse(runner.is_finished())
runner.step() # this converts self._finished to True
self.assertTrue(searcher.is_finished())
self.assertTrue(runner.is_finished())
self.assertRaises(TuneError, runner.step)
if __name__ == "__main__":
+6 -7
View File
@@ -108,16 +108,14 @@ class TrialRunner(object):
Callers should typically run this method repeatedly in a loop. They
may inspect or modify the runner's state in between calls to step().
"""
if self.is_finished():
raise TuneError("Called step when all trials finished?")
self.trial_executor.on_step_begin()
next_trial = self._get_next_trial()
if next_trial is not None:
self.trial_executor.start_trial(next_trial)
elif self.trial_executor.get_running_trials():
self._process_events()
elif self.is_finished():
# We check `is_finished` again here because the experiment
# may have finished while getting the next trial.
pass
else:
for trial in self._trials:
if trial.status == Trial.PENDING:
@@ -137,7 +135,6 @@ class TrialRunner(object):
raise TuneError(
"There are paused trials, but no more pending "
"trials with sufficient resources.")
raise TuneError("Called step when all trials finished?")
if self._server:
self._process_requests()
@@ -306,13 +303,15 @@ class TrialRunner(object):
Args:
blocking (bool): Blocks until either a trial is available
or the Runner finishes (i.e., timeout or search algorithm
finishes).
or is_finished (timeout or search algorithm finishes).
timeout (int): Seconds before blocking times out.
"""
trials = self._search_alg.next_trials()
if blocking and not trials:
start = time.time()
# Checking `is_finished` instead of _search_alg.is_finished
# is fine because blocking only occurs if all trials are
# finished and search_algorithm is not yet finished
while (not trials and not self.is_finished()
and time.time() - start < timeout):
logger.info("Blocking for next trial...")