diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py
index b6858cea2..f05db73e5 100644
--- a/python/ray/tune/examples/mnist_pytorch.py
+++ b/python/ray/tune/examples/mnist_pytorch.py
@@ -117,6 +117,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
     if args.ray_address:
         ray.init(address=args.ray_address)
+    else:
+        ray.init(num_cpus=2 if args.smoke_test else None)
     sched = AsyncHyperBandScheduler(
         time_attr="training_iteration", metric="mean_accuracy")
     analysis = tune.run(
diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py
index f006848f0..43f51cc7b 100644
--- a/python/ray/tune/examples/mnist_pytorch_trainable.py
+++ b/python/ray/tune/examples/mnist_pytorch_trainable.py
@@ -64,7 +64,12 @@ class TrainMNIST(tune.Trainable):
 
 if __name__ == "__main__":
     args = parser.parse_args()
-    ray.init(address=args.ray_address)
+    # num_cpus may only be set when starting a local Ray instance;
+    # ray.init() rejects it when attaching to an existing cluster.
+    if args.ray_address:
+        ray.init(address=args.ray_address)
+    else:
+        ray.init(num_cpus=6 if args.smoke_test else None)
     sched = ASHAScheduler(metric="mean_accuracy")
     analysis = tune.run(
         TrainMNIST,
diff --git a/python/ray/tune/examples/tune_mnist_keras.py b/python/ray/tune/examples/tune_mnist_keras.py
index c2ac39486..1d043c74f 100644
--- a/python/ray/tune/examples/tune_mnist_keras.py
+++ b/python/ray/tune/examples/tune_mnist_keras.py
@@ -48,7 +48,7 @@ if __name__ == "__main__":
     from ray.tune.schedulers import AsyncHyperBandScheduler
     mnist.load_data()  # we do this on the driver because it's not threadsafe
 
-    ray.init()
+    ray.init(num_cpus=2 if args.smoke_test else None)
     sched = AsyncHyperBandScheduler(
         time_attr="training_iteration",
         metric="mean_accuracy",
diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py
index 344db3b2b..d868b8289 100644
--- a/python/ray/tune/schedulers/pbt.py
+++ b/python/ray/tune/schedulers/pbt.py
@@ -179,9 +179,9 @@ class PopulationBasedTraining(FIFOScheduler):
                  custom_explore_fn=None,
                  log_config=True):
         for value in hyperparam_mutations.values():
-            if not (isinstance(value, list) or callable(value)):
+            if not (isinstance(value, (list, dict)) or callable(value)):
                 raise TypeError("`hyperparam_mutation` values must be either "
-                                "a List or callable.")
+                                "a List, Dict, or callable.")
 
         if not hyperparam_mutations and not custom_explore_fn:
             raise TuneError(
diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py
index b8c12cbfb..f8e102220 100644
--- a/python/ray/tune/tests/test_trial_scheduler.py
+++ b/python/ray/tune/tests/test_trial_scheduler.py
@@ -711,28 +711,31 @@ class PopulationBasedTestingSuite(unittest.TestCase):
                    explore=None,
                    perturbation_interval=10,
                    log_config=False,
+                   hyperparams=None,
+                   hyperparam_mutations=None,
                    step_once=True):
+        hyperparam_mutations = hyperparam_mutations or {
+            "float_factor": lambda: 100.0,
+            "int_factor": lambda: 10,
+            "id_factor": [100]
+        }
         pbt = PopulationBasedTraining(
             time_attr="training_iteration",
             perturbation_interval=perturbation_interval,
             resample_probability=resample_prob,
             quantile_fraction=0.25,
-            hyperparam_mutations={
-                "id_factor": [100],
-                "float_factor": lambda: 100.0,
-                "int_factor": lambda: 10,
-            },
+            hyperparam_mutations=hyperparam_mutations,
             custom_explore_fn=explore,
             log_config=log_config)
         runner = _MockTrialRunner(pbt)
         for i in range(5):
-            trial = _MockTrial(
-                i, {
-                    "id_factor": i,
-                    "float_factor": 2.0,
-                    "const_factor": 3,
-                    "int_factor": 10
-                })
+            trial_hyperparams = hyperparams or {
+                "float_factor": 2.0,
+                "const_factor": 3,
+                "int_factor": 10,
+                "id_factor": i
+            }
+            trial = _MockTrial(i, trial_hyperparams)
             runner.add_trial(trial)
             trial.status = Trial.RUNNING
             if step_once:
@@ -958,6 +961,37 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         # Expect call count to be 100 because we call explore 100 times
         self.assertEqual(custom_explore_fn.call_count, 100)
 
+    def testDictPerturbation(self):
+        pbt, runner = self.basicSetup(
+            resample_prob=1.0,
+            hyperparams={
+                "float_factor": 2.0,
+                "nest": {
+                    "nest_float": 3.0
+                },
+                "int_factor": 10,
+                "const_factor": 3
+            },
+            hyperparam_mutations={
+                "float_factor": lambda: 100.0,
+                "nest": {
+                    "nest_float": lambda: 101.0
+                },
+                "int_factor": lambda: 10,
+            })
+        trials = runner.get_trials()
+        self.assertEqual(
+            pbt.on_trial_result(runner, trials[0], result(20, -100)),
+            TrialScheduler.CONTINUE)
+        self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
+        self.assertEqual(trials[0].config["float_factor"], 100.0)
+        self.assertIsInstance(trials[0].config["float_factor"], float)
+        self.assertEqual(trials[0].config["int_factor"], 10)
+        self.assertIsInstance(trials[0].config["int_factor"], int)
+        self.assertEqual(trials[0].config["const_factor"], 3)
+        self.assertEqual(trials[0].config["nest"]["nest_float"], 101.0)
+        self.assertIsInstance(trials[0].config["nest"]["nest_float"], float)
+
     def testYieldsTimeToOtherTrials(self):
         pbt, runner = self.basicSetup()
         trials = runner.get_trials()