mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 08:02:33 +08:00
[tune] Reduce sampling API clutter (#4739)
Adds some sugar for tune sampling API (for commonplace sampling idioms).
This commit is contained in:
@@ -7,10 +7,12 @@ from ray.tune.tune import run_experiments, run
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.suggest import grid_search, function, sample_from
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
|
||||
randn)
|
||||
|
||||
__all__ = [
|
||||
"Trainable", "TuneError", "grid_search", "register_env",
|
||||
"register_trainable", "run", "run_experiments", "Experiment", "function",
|
||||
"sample_from"
|
||||
"sample_from", "uniform", "choice", "randint", "randn"
|
||||
]
|
||||
|
||||
@@ -154,7 +154,6 @@ if __name__ == "__main__":
|
||||
datasets.MNIST("~/data", train=True, download=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
@@ -183,9 +182,7 @@ if __name__ == "__main__":
|
||||
},
|
||||
"num_samples": 1 if args.smoke_test else 10,
|
||||
"config": {
|
||||
"lr": tune.sample_from(
|
||||
lambda spec: np.random.uniform(0.001, 0.1)),
|
||||
"momentum": tune.sample_from(
|
||||
lambda spec: np.random.uniform(0.1, 0.9)),
|
||||
"lr": tune.uniform(0.001, 0.1),
|
||||
"momentum": tune.uniform(0.1, 0.9),
|
||||
}
|
||||
})
|
||||
|
||||
@@ -169,7 +169,6 @@ if __name__ == "__main__":
|
||||
datasets.MNIST("~/data", train=True, download=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import HyperBandScheduler
|
||||
@@ -193,9 +192,7 @@ if __name__ == "__main__":
|
||||
"checkpoint_at_end": True,
|
||||
"config": {
|
||||
"args": args,
|
||||
"lr": tune.sample_from(
|
||||
lambda spec: np.random.uniform(0.001, 0.1)),
|
||||
"momentum": tune.sample_from(
|
||||
lambda spec: np.random.uniform(0.1, 0.9)),
|
||||
"lr": tune.uniform(0.001, 0.1),
|
||||
"momentum": tune.uniform(0.1, 0.9),
|
||||
}
|
||||
})
|
||||
|
||||
@@ -19,7 +19,7 @@ import ray
|
||||
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.suggest.variant_generator import function as tune_function
|
||||
from ray.tune.sample import function as tune_function
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_log_sync_warned = False
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class sample_from(object):
|
||||
"""Specify that tune should sample configuration values from this function.
|
||||
|
||||
The use of function arguments in tune configs must be disambiguated by
|
||||
either wrapped the function in tune.sample_from() or tune.function().
|
||||
|
||||
Arguments:
|
||||
func: An callable function to draw a sample from.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def __str__(self):
|
||||
return "tune.sample_from({})".format(str(self.func))
|
||||
|
||||
def __repr__(self):
|
||||
return "tune.sample_from({})".format(repr(self.func))
|
||||
|
||||
|
||||
class function(object):
|
||||
"""Wraps `func` to make sure it is not expanded during resolution.
|
||||
|
||||
The use of function arguments in tune configs must be disambiguated by
|
||||
either wrapped the function in tune.sample_from() or tune.function().
|
||||
|
||||
Arguments:
|
||||
func: A function literal.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return "tune.function({})".format(str(self.func))
|
||||
|
||||
def __repr__(self):
|
||||
return "tune.function({})".format(repr(self.func))
|
||||
|
||||
|
||||
def uniform(*args, **kwargs):
|
||||
"""A wrapper around np.random.uniform."""
|
||||
return sample_from(lambda _: np.random.uniform(*args, **kwargs))
|
||||
|
||||
|
||||
def choice(*args, **kwargs):
|
||||
"""A wrapper around np.random.choice."""
|
||||
return sample_from(lambda _: np.random.choice(*args, **kwargs))
|
||||
|
||||
|
||||
def randint(*args, **kwargs):
|
||||
"""A wrapper around np.random.randint."""
|
||||
return sample_from(lambda _: np.random.randint(*args, **kwargs))
|
||||
|
||||
|
||||
def randn(*args, **kwargs):
|
||||
"""A wrapper around np.random.randn."""
|
||||
return sample_from(lambda _: np.random.randn(*args, **kwargs))
|
||||
@@ -1,16 +1,11 @@
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
from ray.tune.suggest.variant_generator import grid_search, function, \
|
||||
sample_from
|
||||
from ray.tune.suggest.variant_generator import grid_search
|
||||
|
||||
__all__ = [
|
||||
"SearchAlgorithm",
|
||||
"BasicVariantGenerator",
|
||||
"SuggestionAlgorithm",
|
||||
"grid_search",
|
||||
"function",
|
||||
"sample_from",
|
||||
"SearchAlgorithm", "BasicVariantGenerator", "SuggestionAlgorithm",
|
||||
"grid_search"
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import random
|
||||
import types
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.sample import sample_from
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -54,49 +55,6 @@ def grid_search(values):
|
||||
return {"grid_search": values}
|
||||
|
||||
|
||||
class sample_from(object):
|
||||
"""Specify that tune should sample configuration values from this function.
|
||||
|
||||
The use of function arguments in tune configs must be disambiguated by
|
||||
either wrapped the function in tune.sample_from() or tune.function().
|
||||
|
||||
Arguments:
|
||||
func: An callable function to draw a sample from.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def __str__(self):
|
||||
return "tune.sample_from({})".format(str(self.func))
|
||||
|
||||
def __repr__(self):
|
||||
return "tune.sample_from({})".format(repr(self.func))
|
||||
|
||||
|
||||
class function(object):
|
||||
"""Wraps `func` to make sure it is not expanded during resolution.
|
||||
|
||||
The use of function arguments in tune configs must be disambiguated by
|
||||
either wrapped the function in tune.sample_from() or tune.function().
|
||||
|
||||
Arguments:
|
||||
func: A function literal.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return "tune.function({})".format(str(self.func))
|
||||
|
||||
def __repr__(self):
|
||||
return "tune.function({})".format(repr(self.func))
|
||||
|
||||
|
||||
_STANDARD_IMPORTS = {
|
||||
"random": random,
|
||||
"np": numpy,
|
||||
|
||||
@@ -16,7 +16,7 @@ from ray.tune import TuneError
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.suggest import function
|
||||
from ray.tune.sample import function
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.util import warn_if_slow
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
|
||||
Reference in New Issue
Block a user