mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:46:49 +08:00
[tune] Add Unit Test for nested PBT + Jenkins (#7324)
This commit is contained in:
@@ -117,6 +117,8 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.ray_address:
|
||||
ray.init(address=args.ray_address)
|
||||
else:
|
||||
ray.init(num_cpus=2 if args.smoke_test else None)
|
||||
sched = AsyncHyperBandScheduler(
|
||||
time_attr="training_iteration", metric="mean_accuracy")
|
||||
analysis = tune.run(
|
||||
|
||||
@@ -64,7 +64,7 @@ class TrainMNIST(tune.Trainable):
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.init(address=args.ray_address)
|
||||
ray.init(address=args.ray_address, num_cpus=6 if args.smoke_test else None)
|
||||
sched = ASHAScheduler(metric="mean_accuracy")
|
||||
analysis = tune.run(
|
||||
TrainMNIST,
|
||||
|
||||
@@ -48,7 +48,7 @@ if __name__ == "__main__":
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
mnist.load_data() # we do this on the driver because it's not threadsafe
|
||||
|
||||
ray.init()
|
||||
ray.init(num_cpus=2 if args.smoke_test else None)
|
||||
sched = AsyncHyperBandScheduler(
|
||||
time_attr="training_iteration",
|
||||
metric="mean_accuracy",
|
||||
|
||||
@@ -179,9 +179,9 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
custom_explore_fn=None,
|
||||
log_config=True):
|
||||
for value in hyperparam_mutations.values():
|
||||
if not (isinstance(value, list) or callable(value)):
|
||||
if not (isinstance(value, (list, dict)) or callable(value)):
|
||||
raise TypeError("`hyperparam_mutation` values must be either "
|
||||
"a List or callable.")
|
||||
"a List, Dict, or callable.")
|
||||
|
||||
if not hyperparam_mutations and not custom_explore_fn:
|
||||
raise TuneError(
|
||||
|
||||
@@ -711,28 +711,31 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
explore=None,
|
||||
perturbation_interval=10,
|
||||
log_config=False,
|
||||
hyperparams=None,
|
||||
hyperparam_mutations=None,
|
||||
step_once=True):
|
||||
hyperparam_mutations = hyperparam_mutations or {
|
||||
"float_factor": lambda: 100.0,
|
||||
"int_factor": lambda: 10,
|
||||
"id_factor": [100]
|
||||
}
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
perturbation_interval=perturbation_interval,
|
||||
resample_probability=resample_prob,
|
||||
quantile_fraction=0.25,
|
||||
hyperparam_mutations={
|
||||
"id_factor": [100],
|
||||
"float_factor": lambda: 100.0,
|
||||
"int_factor": lambda: 10,
|
||||
},
|
||||
hyperparam_mutations=hyperparam_mutations,
|
||||
custom_explore_fn=explore,
|
||||
log_config=log_config)
|
||||
runner = _MockTrialRunner(pbt)
|
||||
for i in range(5):
|
||||
trial = _MockTrial(
|
||||
i, {
|
||||
"id_factor": i,
|
||||
"float_factor": 2.0,
|
||||
"const_factor": 3,
|
||||
"int_factor": 10
|
||||
})
|
||||
trial_hyperparams = hyperparams or {
|
||||
"float_factor": 2.0,
|
||||
"const_factor": 3,
|
||||
"int_factor": 10,
|
||||
"id_factor": i
|
||||
}
|
||||
trial = _MockTrial(i, trial_hyperparams)
|
||||
runner.add_trial(trial)
|
||||
trial.status = Trial.RUNNING
|
||||
if step_once:
|
||||
@@ -958,6 +961,37 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
# Expect call count to be 100 because we call explore 100 times
|
||||
self.assertEqual(custom_explore_fn.call_count, 100)
|
||||
|
||||
def testDictPerturbation(self):
|
||||
pbt, runner = self.basicSetup(
|
||||
resample_prob=1.0,
|
||||
hyperparams={
|
||||
"float_factor": 2.0,
|
||||
"nest": {
|
||||
"nest_float": 3.0
|
||||
},
|
||||
"int_factor": 10,
|
||||
"const_factor": 3
|
||||
},
|
||||
hyperparam_mutations={
|
||||
"float_factor": lambda: 100.0,
|
||||
"nest": {
|
||||
"nest_float": lambda: 101.0
|
||||
},
|
||||
"int_factor": lambda: 10,
|
||||
})
|
||||
trials = runner.get_trials()
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(20, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertEqual(trials[0].config["float_factor"], 100.0)
|
||||
self.assertIsInstance(trials[0].config["float_factor"], float)
|
||||
self.assertEqual(trials[0].config["int_factor"], 10)
|
||||
self.assertIsInstance(trials[0].config["int_factor"], int)
|
||||
self.assertEqual(trials[0].config["const_factor"], 3)
|
||||
self.assertEqual(trials[0].config["nest"]["nest_float"], 101.0)
|
||||
self.assertIsInstance(trials[0].config["nest"]["nest_float"], float)
|
||||
|
||||
def testYieldsTimeToOtherTrials(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
|
||||
Reference in New Issue
Block a user