mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:00:10 +08:00
[Tune] PBT hyperparam_mutations fix (#10217)
This commit is contained in:
+24
-14
@@ -30,17 +30,19 @@ def function(func):
|
||||
return func
|
||||
|
||||
|
||||
def uniform(*args, **kwargs):
|
||||
class uniform(sample_from):
|
||||
"""Wraps tune.sample_from around ``np.random.uniform``.
|
||||
|
||||
``tune.uniform(1, 10)`` is equivalent to
|
||||
``tune.sample_from(lambda _: np.random.uniform(1, 10))``
|
||||
|
||||
"""
|
||||
return sample_from(lambda _: np.random.uniform(*args, **kwargs))
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(lambda _: np.random.uniform(*args, **kwargs))
|
||||
|
||||
|
||||
def loguniform(min_bound, max_bound, base=10):
|
||||
class loguniform(sample_from):
|
||||
"""Sugar for sampling in different orders of magnitude.
|
||||
|
||||
Args:
|
||||
@@ -48,40 +50,48 @@ def loguniform(min_bound, max_bound, base=10):
|
||||
max_bound (float): Upper boundary of the output interval (1e-2)
|
||||
base (float): Base of the log. Defaults to 10.
|
||||
"""
|
||||
logmin = np.log(min_bound) / np.log(base)
|
||||
logmax = np.log(max_bound) / np.log(base)
|
||||
|
||||
def apply_log(_):
|
||||
return base**(np.random.uniform(logmin, logmax))
|
||||
def __init__(self, min_bound, max_bound, base=10):
|
||||
logmin = np.log(min_bound) / np.log(base)
|
||||
logmax = np.log(max_bound) / np.log(base)
|
||||
|
||||
return sample_from(apply_log)
|
||||
def apply_log(_):
|
||||
return base**(np.random.uniform(logmin, logmax))
|
||||
|
||||
super().__init__(apply_log)
|
||||
|
||||
|
||||
def choice(*args, **kwargs):
|
||||
class choice(sample_from):
|
||||
"""Wraps tune.sample_from around ``random.choice``.
|
||||
|
||||
``tune.choice([1, 2])`` is equivalent to
|
||||
``tune.sample_from(lambda _: random.choice([1, 2]))``
|
||||
|
||||
"""
|
||||
return sample_from(lambda _: random.choice(*args, **kwargs))
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(lambda _: random.choice(*args, **kwargs))
|
||||
|
||||
|
||||
def randint(*args, **kwargs):
|
||||
class randint(sample_from):
|
||||
"""Wraps tune.sample_from around ``np.random.randint``.
|
||||
|
||||
``tune.randint(10)`` is equivalent to
|
||||
``tune.sample_from(lambda _: np.random.randint(10))``
|
||||
|
||||
"""
|
||||
return sample_from(lambda _: np.random.randint(*args, **kwargs))
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(lambda _: np.random.randint(*args, **kwargs))
|
||||
|
||||
|
||||
def randn(*args, **kwargs):
|
||||
class randn(sample_from):
|
||||
"""Wraps tune.sample_from around ``np.random.randn``.
|
||||
|
||||
``tune.randn(10)`` is equivalent to
|
||||
``tune.sample_from(lambda _: np.random.randn(10))``
|
||||
|
||||
"""
|
||||
return sample_from(lambda _: np.random.randn(*args, **kwargs))
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(lambda _: np.random.randn(*args, **kwargs))
|
||||
|
||||
@@ -65,10 +65,9 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
|
||||
len(distribution) - 1,
|
||||
distribution.index(config[key]) + 1)]
|
||||
else:
|
||||
if isinstance(distribution, sample_from):
|
||||
distribution = distribution.func(None)
|
||||
if random.random() < resample_probability:
|
||||
new_config[key] = distribution()
|
||||
new_config[key] = distribution.func(None) if isinstance(
|
||||
distribution, sample_from) else distribution()
|
||||
elif random.random() > 0.5:
|
||||
new_config[key] = config[key] * 1.2
|
||||
else:
|
||||
@@ -145,10 +144,12 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
to be too frequent.
|
||||
hyperparam_mutations (dict): Hyperparams to mutate. The format is
|
||||
as follows: for each key, either a list, function,
|
||||
or a tune search space object (tune.sample_from, tune.uniform,
|
||||
or a tune search space object (tune.loguniform, tune.uniform,
|
||||
etc.) can be provided. A list specifies an allowed set of
|
||||
categorical values. A function or tune search space object
|
||||
specifies the distribution of a continuous parameter.
|
||||
specifies the distribution of a continuous parameter. You must
|
||||
use tune.choice, tune.uniform, tune.loguniform, etc.. Arbitrary
|
||||
tune.sample_from objects are not supported.
|
||||
You must specify at least one of `hyperparam_mutations` or
|
||||
`custom_explore_fn`.
|
||||
Tune will use the search space provided by
|
||||
@@ -221,6 +222,11 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
raise TypeError("`hyperparam_mutation` values must be either "
|
||||
"a List, Dict, a tune search space object, or "
|
||||
"callable.")
|
||||
if type(value) is sample_from:
|
||||
raise ValueError("arbitrary tune.sample_from objects are not "
|
||||
"supported for `hyperparam_mutation` values."
|
||||
"You must use other built in primitives like"
|
||||
"tune.uniform, tune.loguniform, etc.")
|
||||
|
||||
if not hyperparam_mutations and not custom_explore_fn:
|
||||
raise TuneError(
|
||||
|
||||
@@ -879,6 +879,32 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
self.assertEqual(type(trials[0].config["int_factor"]), int)
|
||||
self.assertEqual(trials[0].config["const_factor"], 3)
|
||||
|
||||
def testTuneSamplePrimitives(self):
|
||||
pbt, runner = self.basicSetup(
|
||||
resample_prob=1.0,
|
||||
hyperparam_mutations={
|
||||
"float_factor": lambda: 100.0,
|
||||
"int_factor": lambda: 10,
|
||||
"id_factor": tune.choice([100])
|
||||
})
|
||||
trials = runner.get_trials()
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(20, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertEqual(trials[0].config["id_factor"], 100)
|
||||
self.assertEqual(trials[0].config["float_factor"], 100.0)
|
||||
self.assertEqual(type(trials[0].config["float_factor"]), float)
|
||||
self.assertEqual(trials[0].config["int_factor"], 10)
|
||||
self.assertEqual(type(trials[0].config["int_factor"]), int)
|
||||
self.assertEqual(trials[0].config["const_factor"], 3)
|
||||
|
||||
def testTuneSampleFromError(self):
|
||||
with self.assertRaises(ValueError):
|
||||
pbt, runner = self.basicSetup(hyperparam_mutations={
|
||||
"float_factor": tune.sample_from(lambda: 100.0)
|
||||
})
|
||||
|
||||
def testPerturbationValues(self):
|
||||
def assertProduces(fn, values):
|
||||
random.seed(0)
|
||||
|
||||
Reference in New Issue
Block a user