mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 19:32:38 +08:00
[tune] Raise Error when overstepping (#3235)
This commit is contained in:
@@ -970,6 +970,30 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||
|
||||
def testMultiStepRun2(self):
|
||||
"""Checks that runner.step throws when overstepping."""
|
||||
ray.init(num_cpus=1)
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 2
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=0),
|
||||
}
|
||||
trials = [Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertRaises(TuneError, runner.step)
|
||||
|
||||
def testErrorHandling(self):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
@@ -992,6 +1016,12 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertEqual(trials[0].status, Trial.ERROR)
|
||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||
|
||||
def testThrowOnOverstep(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
runner.step()
|
||||
self.assertRaises(TuneError, runner.step)
|
||||
|
||||
def testFailureRecoveryDisabled(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
@@ -1390,17 +1420,30 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertTrue(runner.is_finished())
|
||||
|
||||
def testSearchAlgFinishes(self):
|
||||
"""SearchAlg changing state in `next_trials` does not crash."""
|
||||
"""Empty SearchAlg changing state in `next_trials` does not crash."""
|
||||
|
||||
class FinishFastAlg(SuggestionAlgorithm):
|
||||
def next_trials(self):
|
||||
self._finished = True
|
||||
return []
|
||||
_index = 0
|
||||
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
def next_trials(self):
|
||||
trials = []
|
||||
self._index += 1
|
||||
|
||||
for trial in self._trial_generator:
|
||||
trials += [trial]
|
||||
break
|
||||
|
||||
if self._index > 4:
|
||||
self._finished = True
|
||||
return trials
|
||||
|
||||
def _suggest(self, trial_id):
|
||||
return {}
|
||||
|
||||
ray.init(num_cpus=2)
|
||||
experiment_spec = {
|
||||
"run": "__fake",
|
||||
"num_samples": 3,
|
||||
"num_samples": 2,
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
}
|
||||
@@ -1410,9 +1453,20 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
searcher.add_configurations(experiments)
|
||||
|
||||
runner = TrialRunner(search_alg=searcher)
|
||||
runner.step() # This should not fail
|
||||
self.assertFalse(runner.is_finished())
|
||||
runner.step() # This launches a new run
|
||||
runner.step() # This launches a 2nd run
|
||||
self.assertFalse(searcher.is_finished())
|
||||
self.assertFalse(runner.is_finished())
|
||||
runner.step() # This kills the first run
|
||||
self.assertFalse(searcher.is_finished())
|
||||
self.assertFalse(runner.is_finished())
|
||||
runner.step() # This kills the 2nd run
|
||||
self.assertFalse(searcher.is_finished())
|
||||
self.assertFalse(runner.is_finished())
|
||||
runner.step() # this converts self._finished to True
|
||||
self.assertTrue(searcher.is_finished())
|
||||
self.assertTrue(runner.is_finished())
|
||||
self.assertRaises(TuneError, runner.step)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -108,16 +108,14 @@ class TrialRunner(object):
|
||||
Callers should typically run this method repeatedly in a loop. They
|
||||
may inspect or modify the runner's state in between calls to step().
|
||||
"""
|
||||
if self.is_finished():
|
||||
raise TuneError("Called step when all trials finished?")
|
||||
self.trial_executor.on_step_begin()
|
||||
next_trial = self._get_next_trial()
|
||||
if next_trial is not None:
|
||||
self.trial_executor.start_trial(next_trial)
|
||||
elif self.trial_executor.get_running_trials():
|
||||
self._process_events()
|
||||
elif self.is_finished():
|
||||
# We check `is_finished` again here because the experiment
|
||||
# may have finished while getting the next trial.
|
||||
pass
|
||||
else:
|
||||
for trial in self._trials:
|
||||
if trial.status == Trial.PENDING:
|
||||
@@ -137,7 +135,6 @@ class TrialRunner(object):
|
||||
raise TuneError(
|
||||
"There are paused trials, but no more pending "
|
||||
"trials with sufficient resources.")
|
||||
raise TuneError("Called step when all trials finished?")
|
||||
|
||||
if self._server:
|
||||
self._process_requests()
|
||||
@@ -306,13 +303,15 @@ class TrialRunner(object):
|
||||
|
||||
Args:
|
||||
blocking (bool): Blocks until either a trial is available
|
||||
or the Runner finishes (i.e., timeout or search algorithm
|
||||
finishes).
|
||||
or is_finished (timeout or search algorithm finishes).
|
||||
timeout (int): Seconds before blocking times out.
|
||||
"""
|
||||
trials = self._search_alg.next_trials()
|
||||
if blocking and not trials:
|
||||
start = time.time()
|
||||
# Checking `is_finished` instead of _search_alg.is_finished
|
||||
# is fine because blocking only occurs if all trials are
|
||||
# finished and search_algorithm is not yet finished
|
||||
while (not trials and not self.is_finished()
|
||||
and time.time() - start < timeout):
|
||||
logger.info("Blocking for next trial...")
|
||||
|
||||
Reference in New Issue
Block a user