mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 19:33:06 +08:00
[tune] PBT perturbing after first iteration (#5097)
This commit is contained in:
@@ -268,8 +268,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
"pbt_policy_" + trial_to_clone_id + ".txt")
|
||||
policy = [
|
||||
trial_name, trial_to_clone_name,
|
||||
trial.last_result[TRAINING_ITERATION],
|
||||
trial_to_clone.last_result[TRAINING_ITERATION],
|
||||
trial.last_result.get(TRAINING_ITERATION, 0),
|
||||
trial_to_clone.last_result.get(TRAINING_ITERATION, 0),
|
||||
trial_to_clone.config, new_config
|
||||
]
|
||||
# Log to global file.
|
||||
|
||||
@@ -622,10 +622,15 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def basicSetup(self, resample_prob=0.0, explore=None, log_config=False):
|
||||
def basicSetup(self,
|
||||
resample_prob=0.0,
|
||||
explore=None,
|
||||
perturbation_interval=10,
|
||||
log_config=False,
|
||||
step_once=True):
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
perturbation_interval=10,
|
||||
perturbation_interval=perturbation_interval,
|
||||
resample_probability=resample_prob,
|
||||
quantile_fraction=0.25,
|
||||
hyperparam_mutations={
|
||||
@@ -646,9 +651,10 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
})
|
||||
runner.add_trial(trial)
|
||||
trial.status = Trial.RUNNING
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
|
||||
TrialScheduler.CONTINUE)
|
||||
if step_once:
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
|
||||
TrialScheduler.CONTINUE)
|
||||
pbt.reset_stats()
|
||||
return pbt, runner
|
||||
|
||||
@@ -959,6 +965,24 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
self.assertEqual(trials[0].config["id_factor"], 42)
|
||||
self.assertEqual(trials[0].config["float_factor"], 43)
|
||||
|
||||
def testFastPerturb(self):
|
||||
pbt, runner = self.basicSetup(
|
||||
perturbation_interval=1, step_once=False, log_config=True)
|
||||
trials = runner.get_trials()
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
for i, trial in enumerate(trials):
|
||||
trial.local_dir = tmpdir
|
||||
trial.last_result = {}
|
||||
pbt.on_trial_result(runner, trials[0], result(1, 10))
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[2], result(1, 200)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(pbt._num_checkpoints, 1)
|
||||
|
||||
pbt._exploit(runner.trial_executor, trials[1], trials[2])
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
class AsyncHyperBandSuite(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user