diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index 6711f0817..45d325203 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -9,6 +9,7 @@ import shutil from ray.tune.error import TuneError from ray.tune.result import TRAINING_ITERATION from ray.tune.logger import _SafeFallbackEncoder +from ray.tune.sample import sample_from from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.suggest.variant_generator import format_vars from ray.tune.trial import Trial, Checkpoint @@ -64,6 +65,8 @@ def explore(config, mutations, resample_probability, custom_explore_fn): len(distribution) - 1, distribution.index(config[key]) + 1)] else: + if isinstance(distribution, sample_from): + distribution = distribution.func(None) if random.random() < resample_probability: new_config[key] = distribution() elif random.random() > 0.5: @@ -76,8 +79,6 @@ def explore(config, mutations, resample_probability, custom_explore_fn): new_config = custom_explore_fn(new_config) assert new_config is not None, \ "Custom explore fn failed to return new config" - logger.info("[explore] perturbed config from {} -> {}".format( - config, new_config)) return new_config @@ -90,6 +91,20 @@ def make_experiment_tag(orig_tag, config, mutations): return "{}@perturbed[{}]".format(orig_tag, format_vars(resolved_vars)) +def fill_config(config, attr, search_space): + """Add attr to config by sampling from search_space.""" + if callable(search_space): + config[attr] = search_space() + elif isinstance(search_space, sample_from): + config[attr] = search_space.func(None) + elif isinstance(search_space, list): + config[attr] = random.choice(search_space) + elif isinstance(search_space, dict): + config[attr] = {} + for k, v in search_space.items(): + fill_config(config[attr], k, v) + + class PopulationBasedTraining(FIFOScheduler): """Implements the Population Based Training (PBT) algorithm. @@ -129,11 +144,16 @@ class PopulationBasedTraining(FIFOScheduler): perturbation incurs checkpoint overhead, so you shouldn't set this to be too frequent. hyperparam_mutations (dict): Hyperparams to mutate. The format is - as follows: for each key, either a list or function can be - provided. A list specifies an allowed set of categorical values. - A function specifies the distribution of a continuous parameter. + as follows: for each key, either a list, function, + or a tune search space object (tune.sample_from, tune.uniform, + etc.) can be provided. A list specifies an allowed set of + categorical values. A function or tune search space object + specifies the distribution of a continuous parameter. You must specify at least one of `hyperparam_mutations` or `custom_explore_fn`. + Tune will use the search space provided by + `hyperparam_mutations` for the initial samples if the + corresponding attributes are not present in `config`. quantile_fraction (float): Parameters are transferred from the top `quantile_fraction` fraction of trials to the bottom `quantile_fraction` fraction. Needs to be between 0 and 0.5. @@ -170,9 +190,15 @@ class PopulationBasedTraining(FIFOScheduler): # Perturb factor1 by scaling it by 0.8 or 1.2. Resampling # resets it to a value sampled from the lambda function. "factor_1": lambda: random.uniform(0.0, 20.0), - # Perturb factor2 by changing it to an adjacent value, e.g. + # Alternatively, use tune search space primitives. + # The search space for factor_1 is equivalent to factor_2. + "factor_2": tune.uniform(0.0, 20.0), + # Perturb factor3 by changing it to an adjacent value, e.g. # 10 -> 1 or 10 -> 100. Resampling will choose at random. - "factor_2": [1, 10, 100, 1000, 10000], + "factor_3": [1, 10, 100, 1000, 10000], + # Using tune.choice is NOT equivalent to the above. + # factor_4 is treated as a continuous hyperparameter. + "factor_4": tune.choice([1, 10, 100, 1000, 10000]), }) tune.run({...}, num_samples=8, scheduler=pbt) """ @@ -190,9 +216,11 @@ class PopulationBasedTraining(FIFOScheduler): log_config=True, require_attrs=True): for value in hyperparam_mutations.values(): - if not (isinstance(value, (list, dict)) or callable(value)): + if not (isinstance(value, + (list, dict, sample_from)) or callable(value)): raise TypeError("`hyperparam_mutation` values must be either " - "a List, Dict, or callable.") + "a List, Dict, a tune search space object, or " + "callable.") if not hyperparam_mutations and not custom_explore_fn: raise TuneError( @@ -237,6 +265,18 @@ class PopulationBasedTraining(FIFOScheduler): def on_trial_add(self, trial_runner, trial): self._trial_state[trial] = PBTTrialState(trial) + for attr in self._hyperparam_mutations.keys(): + if attr not in trial.config: + if log_once(attr + "-missing"): + logger.debug("Cannot find {} in config. Using search " + "space provided by hyperparam_mutations.") + # Add attr to trial's config by sampling search space from + # hyperparam_mutations. + fill_config(trial.config, attr, + self._hyperparam_mutations[attr]) + # Make sure this attribute is added to CLI output. + trial.evaluated_params[attr] = trial.config[attr] + def on_trial_result(self, trial_runner, trial, result): if self._time_attr not in result: time_missing_msg = "Cannot find time_attr {} " \ @@ -352,6 +392,18 @@ class PopulationBasedTraining(FIFOScheduler): "{} (score {}) -> {} (score {})".format( trial_to_clone, new_state.last_score, trial, trial_state.last_score)) + # Only log mutated hyperparameters and not entire config. + old_hparams = { + k: v + for k, v in trial_to_clone.config.items() + if k in self._hyperparam_mutations + } + new_hparams = { + k: v + for k, v in new_config.items() if k in self._hyperparam_mutations + } + logger.info("[explore] perturbed config from {} -> {}".format( + old_hparams, new_hparams)) if self._log_config: self._log_config_on_step(trial_state, new_state, trial, diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index 0cf4f8035..f4cfbdaad 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -52,6 +52,16 @@ def MockTrainingFunc(config, checkpoint_dir=None): tune.report(mean_accuracy=(a - iter) * b) +def MockTrainingFunc2(config): + a = config["a"] + b = config["b"] + c1 = config["c"]["c1"] + c2 = config["c"]["c2"] + + while True: + tune.report(mean_accuracy=a * b * (c1 + c2)) + + class MockParam(object): def __init__(self, params): self._params = params @@ -63,6 +73,38 @@ class MockParam(object): return val +class PopulationBasedTrainingConfigTest(unittest.TestCase): + def setUp(self): + ray.init() + + def tearDown(self): + ray.shutdown() + + def testNoConfig(self): + scheduler = PopulationBasedTraining( + time_attr="training_iteration", + metric="mean_accuracy", + mode="max", + perturbation_interval=1, + hyperparam_mutations={ + "a": tune.uniform(0, 0.3), + "b": [1, 2, 3], + "c": { + "c1": lambda: np.random.uniform(0.5), + "c2": tune.choice([2, 3, 4]) + } + }, + ) + + tune.run( + MockTrainingFunc2, + fail_fast=True, + num_samples=4, + scheduler=scheduler, + name="testNoConfig", + stop={"training_iteration": 3}) + + class PopulationBasedTrainingResumeTest(unittest.TestCase): def setUp(self): ray.init()