[Tune] PBT hyperparam_mutations improvements (#10170)

This commit is contained in:
Amog Kamsetty
2020-08-19 16:50:19 -07:00
committed by GitHub
parent 5d265e9bd1
commit 44e254788a
2 changed files with 103 additions and 9 deletions
+61 -9
View File
@@ -9,6 +9,7 @@ import shutil
from ray.tune.error import TuneError
from ray.tune.result import TRAINING_ITERATION
from ray.tune.logger import _SafeFallbackEncoder
from ray.tune.sample import sample_from
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest.variant_generator import format_vars
from ray.tune.trial import Trial, Checkpoint
@@ -64,6 +65,8 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
len(distribution) - 1,
distribution.index(config[key]) + 1)]
else:
if isinstance(distribution, sample_from):
distribution = distribution.func(None)
if random.random() < resample_probability:
new_config[key] = distribution()
elif random.random() > 0.5:
@@ -76,8 +79,6 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
new_config = custom_explore_fn(new_config)
assert new_config is not None, \
"Custom explore fn failed to return new config"
logger.info("[explore] perturbed config from {} -> {}".format(
config, new_config))
return new_config
@@ -90,6 +91,20 @@ def make_experiment_tag(orig_tag, config, mutations):
return "{}@perturbed[{}]".format(orig_tag, format_vars(resolved_vars))
def fill_config(config, attr, search_space):
    """Populate ``config[attr]`` with a value drawn from ``search_space``.

    The search space is dispatched on its type, checked in this order:

    * callable -- invoked with no arguments; its return value is stored.
    * ``sample_from`` -- its wrapped ``func`` is invoked with ``None``.
    * list -- one element is picked with ``random.choice``.
    * dict -- ``config[attr]`` becomes a fresh dict that is filled
      recursively, one entry per key of ``search_space``.

    A search space of any other type leaves ``config`` untouched.

    Args:
        config (dict): Config that is mutated in place.
        attr: Key under which the sampled value is stored.
        search_space: Callable, ``sample_from``, list or dict describing
            how to sample the value.
    """
    if callable(search_space):
        config[attr] = search_space()
        return

    if isinstance(search_space, sample_from):
        config[attr] = search_space.func(None)
        return

    if isinstance(search_space, list):
        config[attr] = random.choice(search_space)
        return

    if isinstance(search_space, dict):
        # Build the nested config first, then sample each child into it.
        config[attr] = {}
        for child_attr, child_space in search_space.items():
            fill_config(config[attr], child_attr, child_space)
class PopulationBasedTraining(FIFOScheduler):
"""Implements the Population Based Training (PBT) algorithm.
@@ -129,11 +144,16 @@ class PopulationBasedTraining(FIFOScheduler):
perturbation incurs checkpoint overhead, so you shouldn't set this
to be too frequent.
hyperparam_mutations (dict): Hyperparams to mutate. The format is
as follows: for each key, either a list or function can be
provided. A list specifies an allowed set of categorical values.
A function specifies the distribution of a continuous parameter.
as follows: for each key, either a list, function,
or a tune search space object (tune.sample_from, tune.uniform,
etc.) can be provided. A list specifies an allowed set of
categorical values. A function or tune search space object
specifies the distribution of a continuous parameter.
You must specify at least one of `hyperparam_mutations` or
`custom_explore_fn`.
Tune will use the search space provided by
`hyperparam_mutations` for the initial samples if the
corresponding attributes are not present in `config`.
quantile_fraction (float): Parameters are transferred from the top
`quantile_fraction` fraction of trials to the bottom
`quantile_fraction` fraction. Needs to be between 0 and 0.5.
@@ -170,9 +190,15 @@ class PopulationBasedTraining(FIFOScheduler):
# Perturb factor1 by scaling it by 0.8 or 1.2. Resampling
# resets it to a value sampled from the lambda function.
"factor_1": lambda: random.uniform(0.0, 20.0),
# Perturb factor2 by changing it to an adjacent value, e.g.
# Alternatively, use tune search space primitives.
# The search space for factor_1 is equivalent to factor_2.
"factor_2": tune.uniform(0.0, 20.0),
# Perturb factor3 by changing it to an adjacent value, e.g.
# 10 -> 1 or 10 -> 100. Resampling will choose at random.
"factor_2": [1, 10, 100, 1000, 10000],
"factor_3": [1, 10, 100, 1000, 10000],
# Using tune.choice is NOT equivalent to the above.
# factor_4 is treated as a continuous hyperparameter.
"factor_4": tune.choice([1, 10, 100, 1000, 10000]),
})
tune.run({...}, num_samples=8, scheduler=pbt)
"""
@@ -190,9 +216,11 @@ class PopulationBasedTraining(FIFOScheduler):
log_config=True,
require_attrs=True):
for value in hyperparam_mutations.values():
if not (isinstance(value, (list, dict)) or callable(value)):
if not (isinstance(value,
(list, dict, sample_from)) or callable(value)):
raise TypeError("`hyperparam_mutation` values must be either "
"a List, Dict, or callable.")
"a List, Dict, a tune search space object, or "
"callable.")
if not hyperparam_mutations and not custom_explore_fn:
raise TuneError(
@@ -237,6 +265,18 @@ class PopulationBasedTraining(FIFOScheduler):
def on_trial_add(self, trial_runner, trial):
    """Register a new trial and fill in hyperparameters missing from it.

    A fresh ``PBTTrialState`` is stored for the trial. Then, for every key
    in ``self._hyperparam_mutations`` that is absent from ``trial.config``,
    a value is sampled from the corresponding search space via
    ``fill_config`` and mirrored into ``trial.evaluated_params``.

    Args:
        trial_runner: Not used by this method.
        trial: The trial being added; its ``config`` and
            ``evaluated_params`` may be mutated in place.
    """
    self._trial_state[trial] = PBTTrialState(trial)

    for attr, search_space in self._hyperparam_mutations.items():
        if attr in trial.config:
            continue

        if log_once(attr + "-missing"):
            # Interpolate the attribute name; without the format call the
            # message would be emitted with a literal "{}" placeholder.
            logger.debug("Cannot find {} in config. Using search "
                         "space provided by hyperparam_mutations.".format(
                             attr))
        # Add attr to trial's config by sampling search space from
        # hyperparam_mutations.
        fill_config(trial.config, attr, search_space)
        # Make sure this attribute is added to CLI output.
        trial.evaluated_params[attr] = trial.config[attr]
def on_trial_result(self, trial_runner, trial, result):
if self._time_attr not in result:
time_missing_msg = "Cannot find time_attr {} " \
@@ -352,6 +392,18 @@ class PopulationBasedTraining(FIFOScheduler):
"{} (score {}) -> {} (score {})".format(
trial_to_clone, new_state.last_score, trial,
trial_state.last_score))
# Only log mutated hyperparameters and not entire config.
old_hparams = {
k: v
for k, v in trial_to_clone.config.items()
if k in self._hyperparam_mutations
}
new_hparams = {
k: v
for k, v in new_config.items() if k in self._hyperparam_mutations
}
logger.info("[explore] perturbed config from {} -> {}".format(
old_hparams, new_hparams))
if self._log_config:
self._log_config_on_step(trial_state, new_state, trial,
@@ -52,6 +52,16 @@ def MockTrainingFunc(config, checkpoint_dir=None):
tune.report(mean_accuracy=(a - iter) * b)
def MockTrainingFunc2(config):
    """Mock trainable that endlessly reports a score derived from ``config``.

    Reads ``config["a"]``, ``config["b"]`` and the nested ``config["c"]``
    entries ``"c1"`` and ``"c2"``, then reports ``a * b * (c1 + c2)`` as
    ``mean_accuracy`` on every iteration. The loop never exits on its own.
    """
    factor_a = config["a"]
    factor_b = config["b"]
    nested = config["c"]
    first, second = nested["c1"], nested["c2"]
    while True:
        tune.report(mean_accuracy=factor_a * factor_b * (first + second))
class MockParam(object):
def __init__(self, params):
self._params = params
@@ -63,6 +73,38 @@ class MockParam(object):
return val
class PopulationBasedTrainingConfigTest(unittest.TestCase):
    """Runs PBT with hyperparameters supplied only via the scheduler."""

    def setUp(self):
        ray.init()

    def tearDown(self):
        ray.shutdown()

    def testNoConfig(self):
        """Run PBT without passing a ``config`` to ``tune.run``.

        The mutations mix a tune search space object, a list, and a nested
        dict holding a plain callable and another tune search space object.
        There are no explicit assertions; the run uses ``fail_fast=True``.
        """
        mutations = {
            "a": tune.uniform(0, 0.3),
            "b": [1, 2, 3],
            "c": {
                "c1": lambda: np.random.uniform(0.5),
                "c2": tune.choice([2, 3, 4]),
            },
        }
        pbt = PopulationBasedTraining(
            time_attr="training_iteration",
            metric="mean_accuracy",
            mode="max",
            perturbation_interval=1,
            hyperparam_mutations=mutations,
        )

        stop_criteria = {"training_iteration": 3}
        tune.run(
            MockTrainingFunc2,
            fail_fast=True,
            num_samples=4,
            scheduler=pbt,
            name="testNoConfig",
            stop=stop_criteria)
class PopulationBasedTrainingResumeTest(unittest.TestCase):
def setUp(self):
ray.init()