mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 03:21:06 +08:00
[Tune] PBT hyperparam_mutations improvements (#10170)
This commit is contained in:
@@ -9,6 +9,7 @@ import shutil
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.logger import _SafeFallbackEncoder
|
||||
from ray.tune.sample import sample_from
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest.variant_generator import format_vars
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
@@ -64,6 +65,8 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
|
||||
len(distribution) - 1,
|
||||
distribution.index(config[key]) + 1)]
|
||||
else:
|
||||
if isinstance(distribution, sample_from):
|
||||
distribution = distribution.func(None)
|
||||
if random.random() < resample_probability:
|
||||
new_config[key] = distribution()
|
||||
elif random.random() > 0.5:
|
||||
@@ -76,8 +79,6 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
|
||||
new_config = custom_explore_fn(new_config)
|
||||
assert new_config is not None, \
|
||||
"Custom explore fn failed to return new config"
|
||||
logger.info("[explore] perturbed config from {} -> {}".format(
|
||||
config, new_config))
|
||||
return new_config
|
||||
|
||||
|
||||
@@ -90,6 +91,20 @@ def make_experiment_tag(orig_tag, config, mutations):
|
||||
return "{}@perturbed[{}]".format(orig_tag, format_vars(resolved_vars))
|
||||
|
||||
|
||||
def fill_config(config, attr, search_space):
    """Add attr to config by sampling from search_space.

    Args:
        config (dict): Config to update in place.
        attr: Key under which the sampled value is stored in ``config``.
        search_space: Source of the value. A callable is invoked with no
            arguments; a list has one element chosen with
            ``random.choice``; a dict is expanded recursively into a new
            nested dict; a ``sample_from`` object has its ``func`` called
            with ``None``.

    Any other kind of ``search_space`` leaves ``config`` untouched.
    """
    if callable(search_space):
        config[attr] = search_space()
    elif isinstance(search_space, list):
        config[attr] = random.choice(search_space)
    elif isinstance(search_space, dict):
        # Nested search space: start from an empty dict and fill each
        # sub-key with the same rules.
        nested = config[attr] = {}
        for sub_attr, sub_space in search_space.items():
            fill_config(nested, sub_attr, sub_space)
    elif isinstance(search_space, sample_from):
        # Tune search space objects wrap a sampling function that takes a
        # spec argument; no spec is available here, so None is passed.
        config[attr] = search_space.func(None)
|
||||
|
||||
|
||||
class PopulationBasedTraining(FIFOScheduler):
|
||||
"""Implements the Population Based Training (PBT) algorithm.
|
||||
|
||||
@@ -129,11 +144,16 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
perturbation incurs checkpoint overhead, so you shouldn't set this
|
||||
to be too frequent.
|
||||
hyperparam_mutations (dict): Hyperparams to mutate. The format is
|
||||
as follows: for each key, either a list or function can be
|
||||
provided. A list specifies an allowed set of categorical values.
|
||||
A function specifies the distribution of a continuous parameter.
|
||||
as follows: for each key, either a list, function,
|
||||
or a tune search space object (tune.sample_from, tune.uniform,
|
||||
etc.) can be provided. A list specifies an allowed set of
|
||||
categorical values. A function or tune search space object
|
||||
specifies the distribution of a continuous parameter.
|
||||
You must specify at least one of `hyperparam_mutations` or
|
||||
`custom_explore_fn`.
|
||||
Tune will use the search space provided by
|
||||
`hyperparam_mutations` for the initial samples if the
|
||||
corresponding attributes are not present in `config`.
|
||||
quantile_fraction (float): Parameters are transferred from the top
|
||||
`quantile_fraction` fraction of trials to the bottom
|
||||
`quantile_fraction` fraction. Needs to be between 0 and 0.5.
|
||||
@@ -170,9 +190,15 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
# Perturb factor1 by scaling it by 0.8 or 1.2. Resampling
|
||||
# resets it to a value sampled from the lambda function.
|
||||
"factor_1": lambda: random.uniform(0.0, 20.0),
|
||||
# Perturb factor2 by changing it to an adjacent value, e.g.
|
||||
# Alternatively, use tune search space primitives.
|
||||
# The search space for factor_1 is equivalent to factor_2.
|
||||
"factor_2": tune.uniform(0.0, 20.0),
|
||||
# Perturb factor3 by changing it to an adjacent value, e.g.
|
||||
# 10 -> 1 or 10 -> 100. Resampling will choose at random.
|
||||
"factor_2": [1, 10, 100, 1000, 10000],
|
||||
"factor_3": [1, 10, 100, 1000, 10000],
|
||||
# Using tune.choice is NOT equivalent to the above.
|
||||
# factor_4 is treated as a continuous hyperparameter.
|
||||
"factor_4": tune.choice([1, 10, 100, 1000, 10000]),
|
||||
})
|
||||
tune.run({...}, num_samples=8, scheduler=pbt)
|
||||
"""
|
||||
@@ -190,9 +216,11 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
log_config=True,
|
||||
require_attrs=True):
|
||||
for value in hyperparam_mutations.values():
|
||||
if not (isinstance(value, (list, dict)) or callable(value)):
|
||||
if not (isinstance(value,
|
||||
(list, dict, sample_from)) or callable(value)):
|
||||
raise TypeError("`hyperparam_mutation` values must be either "
|
||||
"a List, Dict, or callable.")
|
||||
"a List, Dict, a tune search space object, or "
|
||||
"callable.")
|
||||
|
||||
if not hyperparam_mutations and not custom_explore_fn:
|
||||
raise TuneError(
|
||||
@@ -237,6 +265,18 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
def on_trial_add(self, trial_runner, trial):
    """Track a newly added trial and fill in missing hyperparameters.

    A ``PBTTrialState`` is recorded for the trial. Then, for every key of
    ``hyperparam_mutations`` that is absent from ``trial.config``, a value
    is sampled from the corresponding search space and stored both in
    ``trial.config`` and in ``trial.evaluated_params``.

    Args:
        trial_runner: Not used by this method.
        trial: The trial being added; its ``config`` and
            ``evaluated_params`` may be modified in place.
    """
    self._trial_state[trial] = PBTTrialState(trial)

    for attr, search_space in self._hyperparam_mutations.items():
        if attr in trial.config:
            continue
        if log_once(attr + "-missing"):
            # The placeholder must be filled in, otherwise the message
            # never says which attribute was missing.
            logger.debug(
                "Cannot find {} in config. Using search space provided "
                "by hyperparam_mutations.".format(attr))
        # Add attr to trial's config by sampling search space from
        # hyperparam_mutations.
        fill_config(trial.config, attr, search_space)
        # Make sure this attribute is added to CLI output.
        trial.evaluated_params[attr] = trial.config[attr]
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
if self._time_attr not in result:
|
||||
time_missing_msg = "Cannot find time_attr {} " \
|
||||
@@ -352,6 +392,18 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
"{} (score {}) -> {} (score {})".format(
|
||||
trial_to_clone, new_state.last_score, trial,
|
||||
trial_state.last_score))
|
||||
# Only log mutated hyperparameters and not entire config.
|
||||
old_hparams = {
|
||||
k: v
|
||||
for k, v in trial_to_clone.config.items()
|
||||
if k in self._hyperparam_mutations
|
||||
}
|
||||
new_hparams = {
|
||||
k: v
|
||||
for k, v in new_config.items() if k in self._hyperparam_mutations
|
||||
}
|
||||
logger.info("[explore] perturbed config from {} -> {}".format(
|
||||
old_hparams, new_hparams))
|
||||
|
||||
if self._log_config:
|
||||
self._log_config_on_step(trial_state, new_state, trial,
|
||||
|
||||
@@ -52,6 +52,16 @@ def MockTrainingFunc(config, checkpoint_dir=None):
|
||||
tune.report(mean_accuracy=(a - iter) * b)
|
||||
|
||||
|
||||
def MockTrainingFunc2(config):
    """Trainable that reports a score computed from ``config`` forever.

    Reads ``a``, ``b`` and the nested ``c["c1"]`` / ``c["c2"]`` entries
    and repeatedly reports ``a * b * (c1 + c2)`` as ``mean_accuracy``.
    The loop never exits on its own.
    """
    factor_a = config["a"]
    factor_b = config["b"]
    nested = config["c"]
    term_1 = nested["c1"]
    term_2 = nested["c2"]

    while True:
        tune.report(mean_accuracy=factor_a * factor_b * (term_1 + term_2))
|
||||
|
||||
|
||||
class MockParam(object):
|
||||
def __init__(self, params):
|
||||
self._params = params
|
||||
@@ -63,6 +73,38 @@ class MockParam(object):
|
||||
return val
|
||||
|
||||
|
||||
class PopulationBasedTrainingConfigTest(unittest.TestCase):
    """Runs PBT when ``tune.run`` is given no ``config`` argument."""

    def setUp(self):
        ray.init()

    def tearDown(self):
        ray.shutdown()

    def testNoConfig(self):
        """Every hyperparameter comes only from ``hyperparam_mutations``.

        The mutations mix a tune search space object, a plain list, and a
        nested dict holding a lambda and another tune search space object.
        """
        mutations = {
            "a": tune.uniform(0, 0.3),
            "b": [1, 2, 3],
            "c": {
                # NOTE(review): a single positional argument to
                # np.random.uniform is `low`, so this draws from
                # [0.5, 1.0) -- confirm that is the intended range.
                "c1": lambda: np.random.uniform(0.5),
                "c2": tune.choice([2, 3, 4])
            }
        }
        pbt = PopulationBasedTraining(
            time_attr="training_iteration",
            metric="mean_accuracy",
            mode="max",
            perturbation_interval=1,
            hyperparam_mutations=mutations,
        )

        # No `config` is passed: MockTrainingFunc2 reads "a", "b" and "c"
        # from the trial config, and fail_fast is enabled.
        tune.run(
            MockTrainingFunc2,
            fail_fast=True,
            num_samples=4,
            scheduler=pbt,
            name="testNoConfig",
            stop={"training_iteration": 3})
|
||||
|
||||
|
||||
class PopulationBasedTrainingResumeTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
|
||||
Reference in New Issue
Block a user