Note: line structure restored from a whitespace-collapsed copy of this patch.
Indentation is reconstructed from convention; if context lines do not match,
apply with `git apply --ignore-whitespace`. Blob ids on the index lines are
carried over from the original patch.

diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py
index f8f7ce7dc..31259c41d 100644
--- a/python/ray/tune/schedulers/pbt.py
+++ b/python/ray/tune/schedulers/pbt.py
@@ -268,8 +268,8 @@ class PopulationBasedTraining(FIFOScheduler):
                                   "pbt_policy_" + trial_to_clone_id + ".txt")
         policy = [
             trial_name, trial_to_clone_name,
-            trial.last_result[TRAINING_ITERATION],
-            trial_to_clone.last_result[TRAINING_ITERATION],
+            trial.last_result.get(TRAINING_ITERATION, 0),
+            trial_to_clone.last_result.get(TRAINING_ITERATION, 0),
             trial_to_clone.config, new_config
         ]
         # Log to global file.
diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py
index 427ca5363..cd2122668 100644
--- a/python/ray/tune/tests/test_trial_scheduler.py
+++ b/python/ray/tune/tests/test_trial_scheduler.py
@@ -622,10 +622,15 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         ray.shutdown()
         _register_all()  # re-register the evicted objects
 
-    def basicSetup(self, resample_prob=0.0, explore=None, log_config=False):
+    def basicSetup(self,
+                   resample_prob=0.0,
+                   explore=None,
+                   log_config=False,
+                   perturbation_interval=10,
+                   step_once=True):
         pbt = PopulationBasedTraining(
             time_attr="training_iteration",
-            perturbation_interval=10,
+            perturbation_interval=perturbation_interval,
             resample_probability=resample_prob,
             quantile_fraction=0.25,
             hyperparam_mutations={
@@ -646,9 +651,10 @@ class PopulationBasedTestingSuite(unittest.TestCase):
             })
             runner.add_trial(trial)
             trial.status = Trial.RUNNING
-            self.assertEqual(
-                pbt.on_trial_result(runner, trial, result(10, 50 * i)),
-                TrialScheduler.CONTINUE)
+            if step_once:
+                self.assertEqual(
+                    pbt.on_trial_result(runner, trial, result(10, 50 * i)),
+                    TrialScheduler.CONTINUE)
         pbt.reset_stats()
         return pbt, runner
 
@@ -959,6 +965,27 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         self.assertEqual(trials[0].config["id_factor"], 42)
         self.assertEqual(trials[0].config["float_factor"], 43)
 
+    def testFastPerturb(self):
+        """Exploiting trials whose last_result is empty must not raise."""
+        pbt, runner = self.basicSetup(
+            perturbation_interval=1, step_once=False, log_config=True)
+        trials = runner.get_trials()
+
+        tmpdir = tempfile.mkdtemp()
+        # Registered up front so the directory is removed even on failure.
+        self.addCleanup(shutil.rmtree, tmpdir)
+        for trial in trials:
+            trial.local_dir = tmpdir
+            # No TRAINING_ITERATION entry is available for the policy log.
+            trial.last_result = {}
+        pbt.on_trial_result(runner, trials[0], result(1, 10))
+        self.assertEqual(
+            pbt.on_trial_result(runner, trials[2], result(1, 200)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(pbt._num_checkpoints, 1)
+
+        pbt._exploit(runner.trial_executor, trials[1], trials[2])
+
 
 class AsyncHyperBandSuite(unittest.TestCase):
     def setUp(self):