diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index b1d865338..810256e07 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -7,10 +7,12 @@ from ray.tune.tune import run_experiments, run from ray.tune.experiment import Experiment from ray.tune.registry import register_env, register_trainable from ray.tune.trainable import Trainable -from ray.tune.suggest import grid_search, function, sample_from +from ray.tune.suggest import grid_search +from ray.tune.sample import (function, sample_from, uniform, choice, randint, + randn) __all__ = [ "Trainable", "TuneError", "grid_search", "register_env", "register_trainable", "run", "run_experiments", "Experiment", "function", - "sample_from" + "sample_from", "uniform", "choice", "randint", "randn" ] diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py index 6a17ffa39..188a8baec 100644 --- a/python/ray/tune/examples/mnist_pytorch.py +++ b/python/ray/tune/examples/mnist_pytorch.py @@ -154,7 +154,6 @@ if __name__ == "__main__": datasets.MNIST("~/data", train=True, download=True) args = parser.parse_args() - import numpy as np import ray from ray import tune from ray.tune.schedulers import AsyncHyperBandScheduler @@ -183,9 +182,7 @@ if __name__ == "__main__": }, "num_samples": 1 if args.smoke_test else 10, "config": { - "lr": tune.sample_from( - lambda spec: np.random.uniform(0.001, 0.1)), - "momentum": tune.sample_from( - lambda spec: np.random.uniform(0.1, 0.9)), + "lr": tune.uniform(0.001, 0.1), + "momentum": tune.uniform(0.1, 0.9), } }) diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py index 8a8715dd5..ac26d0353 100644 --- a/python/ray/tune/examples/mnist_pytorch_trainable.py +++ b/python/ray/tune/examples/mnist_pytorch_trainable.py @@ -169,7 +169,6 @@ if __name__ == "__main__": datasets.MNIST("~/data", train=True, download=True) args = parser.parse_args() - import numpy as np import ray from ray import tune from ray.tune.schedulers import HyperBandScheduler @@ -193,9 +192,7 @@ if __name__ == "__main__": "checkpoint_at_end": True, "config": { "args": args, - "lr": tune.sample_from( - lambda spec: np.random.uniform(0.001, 0.1)), - "momentum": tune.sample_from( - lambda spec: np.random.uniform(0.1, 0.9)), + "lr": tune.uniform(0.001, 0.1), + "momentum": tune.uniform(0.1, 0.9), } }) diff --git a/python/ray/tune/log_sync.py b/python/ray/tune/log_sync.py index a2662af6d..0bb079b48 100644 --- a/python/ray/tune/log_sync.py +++ b/python/ray/tune/log_sync.py @@ -19,7 +19,7 @@ import ray from ray.tune.cluster_info import get_ssh_key, get_ssh_user from ray.tune.error import TuneError from ray.tune.result import DEFAULT_RESULTS_DIR -from ray.tune.suggest.variant_generator import function as tune_function +from ray.tune.sample import function as tune_function logger = logging.getLogger(__name__) _log_sync_warned = False diff --git a/python/ray/tune/sample.py b/python/ray/tune/sample.py new file mode 100644 index 000000000..d40990a27 --- /dev/null +++ b/python/ray/tune/sample.py @@ -0,0 +1,71 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import numpy as np + +logger = logging.getLogger(__name__) + + +class sample_from(object): + """Specify that tune should sample configuration values from this function. + + The use of function arguments in tune configs must be disambiguated by + either wrapped the function in tune.sample_from() or tune.function(). + + Arguments: + func: An callable function to draw a sample from. + """ + + def __init__(self, func): + self.func = func + + def __str__(self): + return "tune.sample_from({})".format(str(self.func)) + + def __repr__(self): + return "tune.sample_from({})".format(repr(self.func)) + + +class function(object): + """Wraps `func` to make sure it is not expanded during resolution. + + The use of function arguments in tune configs must be disambiguated by + either wrapped the function in tune.sample_from() or tune.function(). + + Arguments: + func: A function literal. + """ + + def __init__(self, func): + self.func = func + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + def __str__(self): + return "tune.function({})".format(str(self.func)) + + def __repr__(self): + return "tune.function({})".format(repr(self.func)) + + +def uniform(*args, **kwargs): + """A wrapper around np.random.uniform.""" + return sample_from(lambda _: np.random.uniform(*args, **kwargs)) + + +def choice(*args, **kwargs): + """A wrapper around np.random.choice.""" + return sample_from(lambda _: np.random.choice(*args, **kwargs)) + + +def randint(*args, **kwargs): + """A wrapper around np.random.randint.""" + return sample_from(lambda _: np.random.randint(*args, **kwargs)) + + +def randn(*args, **kwargs): + """A wrapper around np.random.randn.""" + return sample_from(lambda _: np.random.randn(*args, **kwargs)) diff --git a/python/ray/tune/suggest/__init__.py b/python/ray/tune/suggest/__init__.py index 50d9588e7..69f289720 100644 --- a/python/ray/tune/suggest/__init__.py +++ b/python/ray/tune/suggest/__init__.py @@ -1,16 +1,11 @@ from ray.tune.suggest.search import SearchAlgorithm from ray.tune.suggest.basic_variant import BasicVariantGenerator from ray.tune.suggest.suggestion import SuggestionAlgorithm -from ray.tune.suggest.variant_generator import grid_search, function, \ - sample_from +from ray.tune.suggest.variant_generator import grid_search __all__ = [ - "SearchAlgorithm", - "BasicVariantGenerator", - "SuggestionAlgorithm", - "grid_search", - "function", - "sample_from", + "SearchAlgorithm", "BasicVariantGenerator", "SuggestionAlgorithm", + "grid_search" ] diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index 72b867c8f..cf2b01eea 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -9,6 +9,7 @@ import random import types from ray.tune import TuneError +from ray.tune.sample import sample_from logger = logging.getLogger(__name__) @@ -54,49 +55,6 @@ def grid_search(values): return {"grid_search": values} -class sample_from(object): - """Specify that tune should sample configuration values from this function. - - The use of function arguments in tune configs must be disambiguated by - either wrapped the function in tune.sample_from() or tune.function(). - - Arguments: - func: An callable function to draw a sample from. - """ - - def __init__(self, func): - self.func = func - - def __str__(self): - return "tune.sample_from({})".format(str(self.func)) - - def __repr__(self): - return "tune.sample_from({})".format(repr(self.func)) - - -class function(object): - """Wraps `func` to make sure it is not expanded during resolution. - - The use of function arguments in tune configs must be disambiguated by - either wrapped the function in tune.sample_from() or tune.function(). - - Arguments: - func: A function literal. - """ - - def __init__(self, func): - self.func = func - - def __call__(self, *args, **kwargs): - return self.func(*args, **kwargs) - - def __str__(self): - return "tune.function({})".format(str(self.func)) - - def __repr__(self): - return "tune.function({})".format(repr(self.func)) - - _STANDARD_IMPORTS = { "random": random, "np": numpy, diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 166b248d1..8ffcb6e31 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -16,7 +16,7 @@ from ray.tune import TuneError from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE from ray.tune.trial import Trial, Checkpoint -from ray.tune.suggest import function +from ray.tune.sample import function from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.util import warn_if_slow from ray.utils import binary_to_hex, hex_to_binary