diff --git a/python/ray/tune/examples/pbt_example.py b/python/ray/tune/examples/pbt_example.py index 70fbfcb62..089f181b1 100755 --- a/python/ray/tune/examples/pbt_example.py +++ b/python/ray/tune/examples/pbt_example.py @@ -68,8 +68,8 @@ if __name__ == "__main__": hyperparam_mutations={ # Allow for scaling-based perturbations, with a uniform backing # distribution for resampling. - "factor_1": lambda config: random.uniform(0.0, 20.0), - # Only allows resampling from this list as a perturbation. + "factor_1": lambda: random.uniform(0.0, 20.0), + # Allow perturbations within this set of categorical values. "factor_2": [1, 2], }) diff --git a/python/ray/tune/examples/pbt_ppo_example.py b/python/ray/tune/examples/pbt_ppo_example.py index c612dd136..02843e7a9 100755 --- a/python/ray/tune/examples/pbt_ppo_example.py +++ b/python/ray/tune/examples/pbt_ppo_example.py @@ -33,15 +33,14 @@ if __name__ == "__main__": time_attr="time_total_s", reward_attr="episode_reward_mean", perturbation_interval=120, resample_probability=0.25, - # Specifies the resampling distributions of these hyperparams + # Specifies the mutations of these hyperparams hyperparam_mutations={ - "lambda": lambda config: random.uniform(0.9, 1.0), - "clip_param": lambda config: random.uniform(0.01, 0.5), - "sgd_stepsize": lambda config: random.uniform(.00001, .001), - "num_sgd_iter": lambda config: random.randint(1, 30), - "sgd_batchsize": lambda config: random.randint(128, 16384), - "timesteps_per_batch": - lambda config: random.randint(2000, 160000), + "lambda": lambda: random.uniform(0.9, 1.0), + "clip_param": lambda: random.uniform(0.01, 0.5), + "sgd_stepsize": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5], + "num_sgd_iter": lambda: random.randint(1, 30), + "sgd_batchsize": lambda: random.randint(128, 16384), + "timesteps_per_batch": lambda: random.randint(2000, 160000), }, custom_explore_fn=explore) @@ -57,11 +56,11 @@ if __name__ == "__main__": "num_workers": 8, "devices": ["/gpu:0"], "model": {"free_log_std": True}, 
- # These params are tuned from their starting value + # These params are tuned from a fixed starting value. "lambda": 0.95, "clip_param": 0.2, - # Start off with several random variations - "sgd_stepsize": lambda spec: random.uniform(.00001, .001), + "sgd_stepsize": 1e-4, + # These params start off randomly drawn from a set. "num_sgd_iter": lambda spec: random.choice([10, 20, 30]), "sgd_batchsize": lambda spec: random.choice([128, 512, 2048]), "timesteps_per_batch": diff --git a/python/ray/tune/pbt.py b/python/ray/tune/pbt.py index d7745fc20..5c266de5a 100644 --- a/python/ray/tune/pbt.py +++ b/python/ray/tune/pbt.py @@ -47,11 +47,19 @@ def explore(config, mutations, resample_probability, custom_explore_fn): new_config = copy.deepcopy(config) for key, distribution in mutations.items(): if isinstance(distribution, list): - if random.random() < resample_probability: + if random.random() < resample_probability or \ + config[key] not in distribution: new_config[key] = random.choice(distribution) + elif random.random() > 0.5: + new_config[key] = distribution[ + max(0, distribution.index(config[key]) - 1)] + else: + new_config[key] = distribution[ + min(len(distribution) - 1, + distribution.index(config[key]) + 1)] else: if random.random() < resample_probability: - new_config[key] = distribution(config) + new_config[key] = distribution() elif random.random() > 0.5: new_config[key] = config[key] * 1.2 else: @@ -109,14 +117,14 @@ class PopulationBasedTraining(FIFOScheduler): to be too frequent. hyperparam_mutations (dict): Hyperparams to mutate. The format is as follows: for each key, either a list or function can be - provided. A list specifies values for a discrete parameter. + provided. A list specifies an allowed set of categorical values. A function specifies the distribution of a continuous parameter. You must specify at least one of `hyperparam_mutations` or `custom_explore_fn`. 
resample_probability (float): The probability of resampling from the original distribution when applying `hyperparam_mutations`. If not resampled, the value will be perturbed by a factor of 1.2 or 0.8 - if continuous, or left unchanged if discrete. + if continuous, or changed to an adjacent value if discrete. custom_explore_fn (func): You can also specify a custom exploration function. This function is invoked as `f(config)` after built-in perturbations from `hyperparam_mutations` are applied, and should @@ -130,11 +138,12 @@ class PopulationBasedTraining(FIFOScheduler): >>> perturbation_interval=10, # every 10 `time_attr` units >>> # (training_iterations in this case) >>> hyperparam_mutations={ - >>> # Allow for scaling-based perturbations, with a uniform - >>> # backing distribution for resampling. - >>> "factor_1": lambda config: random.uniform(0.0, 20.0), - >>> # Only allows resampling from this list as a perturbation. - >>> "factor_2": [1, 2], + >>> # Perturb factor_1 by scaling it by 0.8 or 1.2. Resampling + >>> # resets it to a value sampled from the lambda function. + >>> "factor_1": lambda: random.uniform(0.0, 20.0), + >>> # Perturb factor_2 by moving it to an adjacent value, e.g. + >>> # 10 -> 1 or 10 -> 100. Resampling will choose at random.
+ >>> "factor_2": [1, 10, 100, 1000, 10000], >>> }) >>> run_experiments({...}, scheduler=pbt) """ diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py index 3f8db69ff..a73647c4c 100644 --- a/python/ray/tune/test/trial_scheduler_test.py +++ b/python/ray/tune/test/trial_scheduler_test.py @@ -3,11 +3,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import random import unittest import numpy as np from ray.tune.hyperband import HyperBandScheduler -from ray.tune.pbt import PopulationBasedTraining +from ray.tune.pbt import PopulationBasedTraining, explore from ray.tune.median_stopping_rule import MedianStoppingRule from ray.tune.result import TrainingResult from ray.tune.trial import Trial, Resources @@ -551,8 +552,8 @@ class PopulationBasedTestingSuite(unittest.TestCase): resample_probability=resample_prob, hyperparam_mutations={ "id_factor": [100], - "float_factor": lambda c: 100.0, - "int_factor": lambda c: 10, + "float_factor": lambda: 100.0, + "int_factor": lambda: 10, }, custom_explore_fn=explore) runner = _MockTrialRunner(pbt) @@ -644,7 +645,7 @@ class PopulationBasedTestingSuite(unittest.TestCase): pbt.on_trial_result(runner, trials[0], result(20, -100)), TrialScheduler.CONTINUE) self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"]) - self.assertIn(trials[0].config["id_factor"], [3, 4]) + self.assertIn(trials[0].config["id_factor"], [100]) self.assertIn(trials[0].config["float_factor"], [2.4, 1.6]) self.assertEqual(type(trials[0].config["float_factor"]), float) self.assertIn(trials[0].config["int_factor"], [8, 12]) @@ -665,6 +666,49 @@ class PopulationBasedTestingSuite(unittest.TestCase): self.assertEqual(type(trials[0].config["int_factor"]), int) self.assertEqual(trials[0].config["const_factor"], 3) + def testPerturbationValues(self): + + def assertProduces(fn, values): + random.seed(0) + seen = set() + for _ in range(100): + 
seen.add(fn()["v"]) + self.assertEqual(seen, values) + + # Categorical case + assertProduces( + lambda: explore({"v": 4}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), + set([3, 8])) + assertProduces( + lambda: explore({"v": 3}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), + set([3, 4])) + assertProduces( + lambda: explore({"v": 10}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), + set([8, 10])) + assertProduces( + lambda: explore({"v": 7}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x), + set([3, 4, 8, 10])) + assertProduces( + lambda: explore({"v": 4}, {"v": [3, 4, 8, 10]}, 1.0, lambda x: x), + set([3, 4, 8, 10])) + + # Continuous case + assertProduces( + lambda: explore( + {"v": 100}, {"v": lambda: random.choice([10, 100])}, 0.0, + lambda x: x), + set([80, 120])) + assertProduces( + lambda: explore( + {"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 0.0, + lambda x: x), + set([80.0, 120.0])) + assertProduces( + lambda: explore( + {"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 1.0, + lambda x: x), + set([10.0, 100.0])) + def testYieldsTimeToOtherTrials(self): pbt, runner = self.basicSetup() trials = runner.get_trials()