From 8d466749eed111c6d2852574e58fcfe2575f7ef8 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Thu, 20 Aug 2020 12:02:29 -0700 Subject: [PATCH] [Tune] PBT hyperparam_mutations fix (#10217) --- python/ray/tune/sample.py | 38 ++++++++++++------- python/ray/tune/schedulers/pbt.py | 16 +++++--- python/ray/tune/tests/test_trial_scheduler.py | 26 +++++++++++++ 3 files changed, 61 insertions(+), 19 deletions(-) diff --git a/python/ray/tune/sample.py b/python/ray/tune/sample.py index dc952f7b5..4cac8572b 100644 --- a/python/ray/tune/sample.py +++ b/python/ray/tune/sample.py @@ -30,17 +30,19 @@ def function(func): return func -def uniform(*args, **kwargs): +class uniform(sample_from): """Wraps tune.sample_from around ``np.random.uniform``. ``tune.uniform(1, 10)`` is equivalent to ``tune.sample_from(lambda _: np.random.uniform(1, 10))`` """ - return sample_from(lambda _: np.random.uniform(*args, **kwargs)) + + def __init__(self, *args, **kwargs): + super().__init__(lambda _: np.random.uniform(*args, **kwargs)) -def loguniform(min_bound, max_bound, base=10): +class loguniform(sample_from): """Sugar for sampling in different orders of magnitude. Args: @@ -48,40 +50,48 @@ def loguniform(min_bound, max_bound, base=10): max_bound (float): Upper boundary of the output interval (1e-2) base (float): Base of the log. Defaults to 10. """ - logmin = np.log(min_bound) / np.log(base) - logmax = np.log(max_bound) / np.log(base) - def apply_log(_): - return base**(np.random.uniform(logmin, logmax)) + def __init__(self, min_bound, max_bound, base=10): + logmin = np.log(min_bound) / np.log(base) + logmax = np.log(max_bound) / np.log(base) - return sample_from(apply_log) + def apply_log(_): + return base**(np.random.uniform(logmin, logmax)) + + super().__init__(apply_log) -def choice(*args, **kwargs): +class choice(sample_from): """Wraps tune.sample_from around ``random.choice``. 
``tune.choice([1, 2])`` is equivalent to ``tune.sample_from(lambda _: random.choice([1, 2]))`` """ - return sample_from(lambda _: random.choice(*args, **kwargs)) + + def __init__(self, *args, **kwargs): + super().__init__(lambda _: random.choice(*args, **kwargs)) -def randint(*args, **kwargs): +class randint(sample_from): """Wraps tune.sample_from around ``np.random.randint``. ``tune.randint(10)`` is equivalent to ``tune.sample_from(lambda _: np.random.randint(10))`` """ - return sample_from(lambda _: np.random.randint(*args, **kwargs)) + + def __init__(self, *args, **kwargs): + super().__init__(lambda _: np.random.randint(*args, **kwargs)) -def randn(*args, **kwargs): +class randn(sample_from): """Wraps tune.sample_from around ``np.random.randn``. ``tune.randn(10)`` is equivalent to ``tune.sample_from(lambda _: np.random.randn(10))`` """ - return sample_from(lambda _: np.random.randn(*args, **kwargs)) + + def __init__(self, *args, **kwargs): + super().__init__(lambda _: np.random.randn(*args, **kwargs)) diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index 45d325203..588c247bd 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -65,10 +65,9 @@ def explore(config, mutations, resample_probability, custom_explore_fn): len(distribution) - 1, distribution.index(config[key]) + 1)] else: - if isinstance(distribution, sample_from): - distribution = distribution.func(None) if random.random() < resample_probability: - new_config[key] = distribution() + new_config[key] = distribution.func(None) if isinstance( + distribution, sample_from) else distribution() elif random.random() > 0.5: new_config[key] = config[key] * 1.2 else: @@ -145,10 +144,12 @@ class PopulationBasedTraining(FIFOScheduler): to be too frequent. hyperparam_mutations (dict): Hyperparams to mutate. 
The format is as follows: for each key, either a list, function, - or a tune search space object (tune.sample_from, tune.uniform, + or a tune search space object (tune.loguniform, tune.uniform, etc.) can be provided. A list specifies an allowed set of categorical values. A function or tune search space object - specifies the distribution of a continuous parameter. + specifies the distribution of a continuous parameter. You must + use tune.choice, tune.uniform, tune.loguniform, etc. Arbitrary + tune.sample_from objects are not supported. You must specify at least one of `hyperparam_mutations` or `custom_explore_fn`. Tune will use the search space provided by @@ -221,6 +222,11 @@ class PopulationBasedTraining(FIFOScheduler): raise TypeError("`hyperparam_mutation` values must be either " "a List, Dict, a tune search space object, or " "callable.") + if type(value) is sample_from: + raise ValueError("arbitrary tune.sample_from objects are not " + "supported for `hyperparam_mutation` values." 
+ " You must use other built in primitives like " + "tune.uniform, tune.loguniform, etc.") if not hyperparam_mutations and not custom_explore_fn: raise TuneError( diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 32a5713dd..7a9b35f81 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -879,6 +879,32 @@ class PopulationBasedTestingSuite(unittest.TestCase): self.assertEqual(type(trials[0].config["int_factor"]), int) self.assertEqual(trials[0].config["const_factor"], 3) + def testTuneSamplePrimitives(self): + pbt, runner = self.basicSetup( + resample_prob=1.0, + hyperparam_mutations={ + "float_factor": lambda: 100.0, + "int_factor": lambda: 10, + "id_factor": tune.choice([100]) + }) + trials = runner.get_trials() + self.assertEqual( + pbt.on_trial_result(runner, trials[0], result(20, -100)), + TrialScheduler.CONTINUE) + self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"]) + self.assertEqual(trials[0].config["id_factor"], 100) + self.assertEqual(trials[0].config["float_factor"], 100.0) + self.assertEqual(type(trials[0].config["float_factor"]), float) + self.assertEqual(trials[0].config["int_factor"], 10) + self.assertEqual(type(trials[0].config["int_factor"]), int) + self.assertEqual(trials[0].config["const_factor"], 3) + + def testTuneSampleFromError(self): + with self.assertRaises(ValueError): + pbt, runner = self.basicSetup(hyperparam_mutations={ + "float_factor": tune.sample_from(lambda: 100.0) + }) + def testPerturbationValues(self): def assertProduces(fn, values): random.seed(0)