[tune] Add Unit Test for nested PBT + Jenkins (#7324)

This commit is contained in:
Richard Liaw
2020-02-27 18:17:11 -08:00
committed by GitHub
parent 8730996682
commit 3fc162f93c
5 changed files with 52 additions and 16 deletions
@@ -117,6 +117,8 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.ray_address:
ray.init(address=args.ray_address)
else:
ray.init(num_cpus=2 if args.smoke_test else None)
sched = AsyncHyperBandScheduler(
time_attr="training_iteration", metric="mean_accuracy")
analysis = tune.run(
@@ -64,7 +64,7 @@ class TrainMNIST(tune.Trainable):
if __name__ == "__main__":
args = parser.parse_args()
ray.init(address=args.ray_address)
ray.init(address=args.ray_address, num_cpus=6 if args.smoke_test else None)
sched = ASHAScheduler(metric="mean_accuracy")
analysis = tune.run(
TrainMNIST,
+1 -1
View File
@@ -48,7 +48,7 @@ if __name__ == "__main__":
from ray.tune.schedulers import AsyncHyperBandScheduler
mnist.load_data() # we do this on the driver because it's not threadsafe
ray.init()
ray.init(num_cpus=2 if args.smoke_test else None)
sched = AsyncHyperBandScheduler(
time_attr="training_iteration",
metric="mean_accuracy",
+2 -2
View File
@@ -179,9 +179,9 @@ class PopulationBasedTraining(FIFOScheduler):
custom_explore_fn=None,
log_config=True):
for value in hyperparam_mutations.values():
if not (isinstance(value, list) or callable(value)):
if not (isinstance(value, (list, dict)) or callable(value)):
raise TypeError("`hyperparam_mutation` values must be either "
"a List or callable.")
"a List, Dict, or callable.")
if not hyperparam_mutations and not custom_explore_fn:
raise TuneError(
+46 -12
View File
@@ -711,28 +711,31 @@ class PopulationBasedTestingSuite(unittest.TestCase):
explore=None,
perturbation_interval=10,
log_config=False,
hyperparams=None,
hyperparam_mutations=None,
step_once=True):
hyperparam_mutations = hyperparam_mutations or {
"float_factor": lambda: 100.0,
"int_factor": lambda: 10,
"id_factor": [100]
}
pbt = PopulationBasedTraining(
time_attr="training_iteration",
perturbation_interval=perturbation_interval,
resample_probability=resample_prob,
quantile_fraction=0.25,
hyperparam_mutations={
"id_factor": [100],
"float_factor": lambda: 100.0,
"int_factor": lambda: 10,
},
hyperparam_mutations=hyperparam_mutations,
custom_explore_fn=explore,
log_config=log_config)
runner = _MockTrialRunner(pbt)
for i in range(5):
trial = _MockTrial(
i, {
"id_factor": i,
"float_factor": 2.0,
"const_factor": 3,
"int_factor": 10
})
trial_hyperparams = hyperparams or {
"float_factor": 2.0,
"const_factor": 3,
"int_factor": 10,
"id_factor": i
}
trial = _MockTrial(i, trial_hyperparams)
runner.add_trial(trial)
trial.status = Trial.RUNNING
if step_once:
@@ -958,6 +961,37 @@ class PopulationBasedTestingSuite(unittest.TestCase):
# Expect call count to be 100 because we call explore 100 times
self.assertEqual(custom_explore_fn.call_count, 100)
def testDictPerturbation(self):
pbt, runner = self.basicSetup(
resample_prob=1.0,
hyperparams={
"float_factor": 2.0,
"nest": {
"nest_float": 3.0
},
"int_factor": 10,
"const_factor": 3
},
hyperparam_mutations={
"float_factor": lambda: 100.0,
"nest": {
"nest_float": lambda: 101.0
},
"int_factor": lambda: 10,
})
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertEqual(trials[0].config["float_factor"], 100.0)
self.assertIsInstance(trials[0].config["float_factor"], float)
self.assertEqual(trials[0].config["int_factor"], 10)
self.assertIsInstance(trials[0].config["int_factor"], int)
self.assertEqual(trials[0].config["const_factor"], 3)
self.assertEqual(trials[0].config["nest"]["nest_float"], 101.0)
self.assertIsInstance(trials[0].config["nest"]["nest_float"], float)
def testYieldsTimeToOtherTrials(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()