[tune] Reduce sampling API clutter (#4739)

Adds some sugar for tune sampling API (for commonplace sampling idioms).
This commit is contained in:
Richard Liaw
2019-05-06 17:42:39 -07:00
committed by GitHub
parent 71b2dec3b4
commit 7f50c96adb
8 changed files with 85 additions and 65 deletions
+4 -2
View File
@@ -7,10 +7,12 @@ from ray.tune.tune import run_experiments, run
from ray.tune.experiment import Experiment
from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.suggest import grid_search, function, sample_from
from ray.tune.suggest import grid_search
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
randn)
__all__ = [
"Trainable", "TuneError", "grid_search", "register_env",
"register_trainable", "run", "run_experiments", "Experiment", "function",
"sample_from"
"sample_from", "uniform", "choice", "randint", "randn"
]
+2 -5
View File
@@ -154,7 +154,6 @@ if __name__ == "__main__":
datasets.MNIST("~/data", train=True, download=True)
args = parser.parse_args()
import numpy as np
import ray
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
@@ -183,9 +182,7 @@ if __name__ == "__main__":
},
"num_samples": 1 if args.smoke_test else 10,
"config": {
"lr": tune.sample_from(
lambda spec: np.random.uniform(0.001, 0.1)),
"momentum": tune.sample_from(
lambda spec: np.random.uniform(0.1, 0.9)),
"lr": tune.uniform(0.001, 0.1),
"momentum": tune.uniform(0.1, 0.9),
}
})
@@ -169,7 +169,6 @@ if __name__ == "__main__":
datasets.MNIST("~/data", train=True, download=True)
args = parser.parse_args()
import numpy as np
import ray
from ray import tune
from ray.tune.schedulers import HyperBandScheduler
@@ -193,9 +192,7 @@ if __name__ == "__main__":
"checkpoint_at_end": True,
"config": {
"args": args,
"lr": tune.sample_from(
lambda spec: np.random.uniform(0.001, 0.1)),
"momentum": tune.sample_from(
lambda spec: np.random.uniform(0.1, 0.9)),
"lr": tune.uniform(0.001, 0.1),
"momentum": tune.uniform(0.1, 0.9),
}
})
+1 -1
View File
@@ -19,7 +19,7 @@ import ray
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
from ray.tune.error import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.suggest.variant_generator import function as tune_function
from ray.tune.sample import function as tune_function
logger = logging.getLogger(__name__)
_log_sync_warned = False
+71
View File
@@ -0,0 +1,71 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
logger = logging.getLogger(__name__)
class sample_from(object):
"""Specify that tune should sample configuration values from this function.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: An callable function to draw a sample from.
"""
def __init__(self, func):
self.func = func
def __str__(self):
return "tune.sample_from({})".format(str(self.func))
def __repr__(self):
return "tune.sample_from({})".format(repr(self.func))
class function(object):
"""Wraps `func` to make sure it is not expanded during resolution.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: A function literal.
"""
def __init__(self, func):
self.func = func
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def __str__(self):
return "tune.function({})".format(str(self.func))
def __repr__(self):
return "tune.function({})".format(repr(self.func))
def uniform(*args, **kwargs):
"""A wrapper around np.random.uniform."""
return sample_from(lambda _: np.random.uniform(*args, **kwargs))
def choice(*args, **kwargs):
"""A wrapper around np.random.choice."""
return sample_from(lambda _: np.random.choice(*args, **kwargs))
def randint(*args, **kwargs):
"""A wrapper around np.random.randint."""
return sample_from(lambda _: np.random.randint(*args, **kwargs))
def randn(*args, **kwargs):
"""A wrapper around np.random.randn."""
return sample_from(lambda _: np.random.randn(*args, **kwargs))
+3 -8
View File
@@ -1,16 +1,11 @@
from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest.variant_generator import grid_search, function, \
sample_from
from ray.tune.suggest.variant_generator import grid_search
__all__ = [
"SearchAlgorithm",
"BasicVariantGenerator",
"SuggestionAlgorithm",
"grid_search",
"function",
"sample_from",
"SearchAlgorithm", "BasicVariantGenerator", "SuggestionAlgorithm",
"grid_search"
]
+1 -43
View File
@@ -9,6 +9,7 @@ import random
import types
from ray.tune import TuneError
from ray.tune.sample import sample_from
logger = logging.getLogger(__name__)
@@ -54,49 +55,6 @@ def grid_search(values):
return {"grid_search": values}
class sample_from(object):
"""Specify that tune should sample configuration values from this function.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: An callable function to draw a sample from.
"""
def __init__(self, func):
self.func = func
def __str__(self):
return "tune.sample_from({})".format(str(self.func))
def __repr__(self):
return "tune.sample_from({})".format(repr(self.func))
class function(object):
"""Wraps `func` to make sure it is not expanded during resolution.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: A function literal.
"""
def __init__(self, func):
self.func = func
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def __str__(self):
return "tune.function({})".format(str(self.func))
def __repr__(self):
return "tune.function({})".format(repr(self.func))
_STANDARD_IMPORTS = {
"random": random,
"np": numpy,
+1 -1
View File
@@ -16,7 +16,7 @@ from ray.tune import TuneError
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
from ray.tune.trial import Trial, Checkpoint
from ray.tune.suggest import function
from ray.tune.sample import function
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.util import warn_if_slow
from ray.utils import binary_to_hex, hex_to_binary