[tune] PBT perturbing after first iteration (#5097)

This commit is contained in:
Richard Liaw
2019-07-03 17:27:26 -07:00
committed by GitHub
parent 34d054ff19
commit 0dbb6c4911
2 changed files with 31 additions and 7 deletions
+2 -2
View File
@@ -268,8 +268,8 @@ class PopulationBasedTraining(FIFOScheduler):
"pbt_policy_" + trial_to_clone_id + ".txt")
policy = [
trial_name, trial_to_clone_name,
trial.last_result[TRAINING_ITERATION],
trial_to_clone.last_result[TRAINING_ITERATION],
trial.last_result.get(TRAINING_ITERATION, 0),
trial_to_clone.last_result.get(TRAINING_ITERATION, 0),
trial_to_clone.config, new_config
]
# Log to global file.
+29 -5
View File
@@ -622,10 +622,15 @@ class PopulationBasedTestingSuite(unittest.TestCase):
ray.shutdown()
_register_all() # re-register the evicted objects
def basicSetup(self, resample_prob=0.0, explore=None, log_config=False):
def basicSetup(self,
resample_prob=0.0,
explore=None,
perturbation_interval=10,
log_config=False,
step_once=True):
pbt = PopulationBasedTraining(
time_attr="training_iteration",
perturbation_interval=10,
perturbation_interval=perturbation_interval,
resample_probability=resample_prob,
quantile_fraction=0.25,
hyperparam_mutations={
@@ -646,9 +651,10 @@ class PopulationBasedTestingSuite(unittest.TestCase):
})
runner.add_trial(trial)
trial.status = Trial.RUNNING
self.assertEqual(
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
TrialScheduler.CONTINUE)
if step_once:
self.assertEqual(
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
TrialScheduler.CONTINUE)
pbt.reset_stats()
return pbt, runner
@@ -959,6 +965,24 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(trials[0].config["id_factor"], 42)
self.assertEqual(trials[0].config["float_factor"], 43)
def testFastPerturb(self):
pbt, runner = self.basicSetup(
perturbation_interval=1, step_once=False, log_config=True)
trials = runner.get_trials()
tmpdir = tempfile.mkdtemp()
for i, trial in enumerate(trials):
trial.local_dir = tmpdir
trial.last_result = {}
pbt.on_trial_result(runner, trials[0], result(1, 10))
self.assertEqual(
pbt.on_trial_result(runner, trials[2], result(1, 200)),
TrialScheduler.CONTINUE)
self.assertEqual(pbt._num_checkpoints, 1)
pbt._exploit(runner.trial_executor, trials[1], trials[2])
shutil.rmtree(tmpdir)
class AsyncHyperBandSuite(unittest.TestCase):
def setUp(self):