[tune] Refactor search algorithms (#7037)

* start refactoring of search algorithms

* format

* needs tests

* fix

* suggestions

* Fix PBT

* lint

* refactoring

* hyperopt_working

* dragonfly

* hyperopt

* change_half_of_algs

* save

* code-removed

* remove_lots_of_unneccessary

* changes

* formatting

* suggest

* reset

* rm

* tests

* search-change

* exception

* refactor-doc

* search

* py

* moredocs

* Update doc/source/tune-searchalg.rst

* concurrency

* max

* tune

* betterwarning

* bohb

* tests

* test-change

Co-authored-by: ujvl <misraujval@gmail.com>
This commit is contained in:
Richard Liaw
2020-04-27 08:51:13 -07:00
committed by GitHub
parent 1d5bceddf0
commit 87557a00fa
31 changed files with 527 additions and 611 deletions
+1 -5
View File
@@ -130,11 +130,7 @@ class AutoMLSearcher(SearchAlgorithm):
> self.best_trial.best_result[self.reward_attr]):
self.best_trial = self._running_trials[trial_id]
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
self.on_trial_result(trial_id, result)
self._unfinished_count -= 1
if self._unfinished_count == 0:
@@ -40,7 +40,6 @@ if __name__ == "__main__":
}
algo = BayesOptSearch(
space,
max_concurrent=4,
metric="mean_loss",
mode="min",
utility_kwargs={
@@ -73,8 +73,7 @@ if __name__ == "__main__":
func_caller = EuclideanFunctionCaller(
None, domain_config.domain.list_of_domains[0])
optimizer = EuclideanGPBandit(func_caller, ask_tell_mode=True)
algo = DragonflySearch(
optimizer, max_concurrent=4, metric="objective", mode="max")
algo = DragonflySearch(optimizer, metric="objective", mode="max")
scheduler = AsyncHyperBandScheduler(metric="objective", mode="max")
run(objective,
name="dragonfly_search",
+1 -2
View File
@@ -28,7 +28,7 @@ if __name__ == "__main__":
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
ray.init(configure_logging=False)
space = {
"width": hp.uniform("width", 0, 20),
@@ -60,7 +60,6 @@ if __name__ == "__main__":
}
algo = HyperOptSearch(
space,
max_concurrent=4,
metric="mean_loss",
mode="min",
points_to_evaluate=current_best_params)
@@ -47,11 +47,7 @@ if __name__ == "__main__":
# parameter_names = None # names are provided by the instrumentation
optimizer = optimizerlib.OnePlusOne(instrumentation)
algo = NevergradSearch(
optimizer,
parameter_names,
max_concurrent=4,
metric="mean_loss",
mode="min")
optimizer, parameter_names, metric="mean_loss", mode="min")
scheduler = AsyncHyperBandScheduler(metric="mean_loss", mode="min")
run(easy_objective,
name="nevergrad",
@@ -42,7 +42,6 @@ if __name__ == "__main__":
known_rewards = [-189, -1144]
algo = SkOptSearch(
optimizer, ["width", "height"],
max_concurrent=4,
metric="mean_loss",
mode="min",
points_to_evaluate=previously_run_params,
@@ -58,7 +57,6 @@ if __name__ == "__main__":
algo = SkOptSearch(
optimizer, ["width", "height"],
max_concurrent=4,
metric="mean_loss",
mode="min",
points_to_evaluate=previously_run_params)
@@ -51,7 +51,6 @@ if __name__ == "__main__":
algo="Asracos", # only support ASRacos currently
budget=config["num_samples"],
dim_dict=dim_dict,
max_concurrent=4,
metric="mean_loss",
mode="min")
+2 -2
View File
@@ -84,14 +84,14 @@ class HyperBandForBOHB(HyperBandScheduler):
if not bracket.filled() or any(status != Trial.PAUSED
for t, status in statuses
if t is not trial):
trial_runner._search_alg.on_pause(trial.trial_id)
trial_runner._search_alg.searcher.on_pause(trial.trial_id)
return TrialScheduler.PAUSE
action = self._process_bracket(trial_runner, bracket)
return action
def _unpause_trial(self, trial_runner, trial):
trial_runner.trial_executor.unpause_trial(trial)
trial_runner._search_alg.on_unpause(trial.trial_id)
trial_runner._search_alg.searcher.on_unpause(trial.trial_id)
def choose_trial_to_run(self, trial_runner):
"""Fair scheduling within iteration by completion percentage.
+4 -4
View File
@@ -1,13 +1,13 @@
from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest.suggestion import (SearchGenerator, Searcher,
ConcurrencyLimiter)
from ray.tune.suggest.variant_generator import grid_search
from ray.tune.suggest.repeater import Repeater
from ray.tune.suggest.bohb import TuneBOHB
__all__ = [
"SearchAlgorithm", "BasicVariantGenerator", "SuggestionAlgorithm",
"grid_search", "TuneBOHB", "Repeater"
"SearchAlgorithm", "Searcher", "BasicVariantGenerator", "SearchGenerator",
"grid_search", "Repeater", "ConcurrencyLimiter"
]
+17 -25
View File
@@ -4,18 +4,20 @@ except ImportError:
ax = None
import logging
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest import Searcher
logger = logging.getLogger(__name__)
class AxSearch(SuggestionAlgorithm):
class AxSearch(Searcher):
"""A wrapper around Ax to provide trial suggestions.
Requires Ax to be installed. Ax is an open source tool from
Facebook for configuring and optimizing experiments. More information
can be found in https://ax.dev/.
This module manages its own concurrency.
Parameters:
parameters (list[dict]): Parameters in the experiment search space.
Required elements in the dictionaries are: "name" (name of
@@ -36,9 +38,7 @@ class AxSearch(SuggestionAlgorithm):
"x3 >= x4" or "x3 + x4 >= 2".
outcome_constraints (list[str]): Outcome constraints of form
"metric_name >= bound", like "m1 <= 3."
use_early_stopped_trials (bool): Whether to use early terminated
trial results in the optimization process.
use_early_stopped_trials: Deprecated.
.. code-block:: python
@@ -56,7 +56,11 @@ class AxSearch(SuggestionAlgorithm):
"""
def __init__(self, ax_client, max_concurrent=10, mode="max", **kwargs):
def __init__(self,
ax_client,
max_concurrent=10,
mode="max",
use_early_stopped_trials=None):
assert ax is not None, "Ax must be installed!"
assert type(max_concurrent) is int and max_concurrent > 0
self._ax = ax_client
@@ -66,38 +70,29 @@ class AxSearch(SuggestionAlgorithm):
logger.warning("Detected sequential enforcement. Setting max "
"concurrency to 1.")
max_concurrent = 1
self._max_concurrent = max_concurrent
self._parameters = list(exp.parameters)
self._live_index_mapping = {}
super(AxSearch, self).__init__(
metric=self._objective_name, mode=mode, **kwargs)
metric=self._objective_name,
mode=mode,
max_concurrent=max_concurrent,
use_early_stopped_trials=use_early_stopped_trials)
def suggest(self, trial_id):
if self._num_live_trials() >= self._max_concurrent:
return None
parameters, trial_index = self._ax.get_next_trial()
self._live_index_mapping[trial_id] = trial_index
return parameters
def on_trial_result(self, trial_id, result):
pass
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
"""Notification for the completion of trial.
Data of form key value dictionary of metric names and values.
"""
if result:
self._process_result(trial_id, result, early_terminated)
self._process_result(trial_id, result)
self._live_index_mapping.pop(trial_id)
def _process_result(self, trial_id, result, early_terminated=False):
if early_terminated and self._use_early_stopped is False:
return
def _process_result(self, trial_id, result):
ax_trial_index = self._live_index_mapping[trial_id]
metric_dict = {
self._objective_name: (result[self._objective_name], 0.0)
@@ -109,6 +104,3 @@ class AxSearch(SuggestionAlgorithm):
metric_dict.update({on: (result[on], 0.0) for on in outcome_names})
self._ax.complete_trial(
trial_index=ax_trial_index, raw_data=metric_dict)
def _num_live_trials(self):
return len(self._live_index_mapping)
+17 -45
View File
@@ -6,12 +6,12 @@ try: # Python 3 only -- needed for lint test.
except ImportError:
byo = None
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest import Searcher
logger = logging.getLogger(__name__)
class BayesOptSearch(SuggestionAlgorithm):
class BayesOptSearch(Searcher):
"""A wrapper around BayesOpt to provide trial suggestions.
Requires BayesOpt to be installed. You can install BayesOpt with the
@@ -20,8 +20,6 @@ class BayesOptSearch(SuggestionAlgorithm):
Parameters:
space (dict): Continuous search space. Parameters will be sampled from
this space which will be used to run trials.
max_concurrent (int): Number of maximum concurrent trials. Defaults
to 10.
metric (str): The training result objective value attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
@@ -29,8 +27,8 @@ class BayesOptSearch(SuggestionAlgorithm):
provide values for the keys `kind`, `kappa`, and `xi`.
random_state (int): Used to initialize BayesOpt.
verbose (int): Sets verbosity level for BayesOpt packages.
use_early_stopped_trials (bool): Whether to use early terminated
trial results in the optimization process.
max_concurrent: Deprecated.
use_early_stopped_trials: Deprecated.
.. code-block:: python
@@ -41,9 +39,7 @@ class BayesOptSearch(SuggestionAlgorithm):
'width': (0, 20),
'height': (-100, 100),
}
algo = BayesOptSearch(
space, max_concurrent=4, metric="mean_loss", mode="min")
algo = BayesOptSearch(space, metric="mean_loss", mode="min")
tune.run(my_func, algo=algo)
"""
# bayes_opt.BayesianOptimization: Optimization object
@@ -51,32 +47,26 @@ class BayesOptSearch(SuggestionAlgorithm):
def __init__(self,
space,
max_concurrent=10,
reward_attr=None,
metric="episode_reward_mean",
mode="max",
utility_kwargs=None,
random_state=1,
verbose=0,
**kwargs):
max_concurrent=None,
use_early_stopped_trials=None):
assert byo is not None, (
"BayesOpt must be installed!. You can install BayesOpt with"
" the command: `pip install bayesian-optimization`.")
assert type(max_concurrent) is int and max_concurrent > 0
assert utility_kwargs is not None, (
"Must define arguments for the utiliy function!")
"Must define arguments for the utility function!")
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if reward_attr is not None:
mode = "max"
metric = reward_attr
logger.warning(
"`reward_attr` is deprecated and will be removed in a future "
"version of Tune. "
"Setting `metric={}` and `mode=max`.".format(reward_attr))
super(BayesOptSearch, self).__init__(
metric=metric,
mode=mode,
max_concurrent=max_concurrent,
use_early_stopped_trials=use_early_stopped_trials)
self._max_concurrent = max_concurrent
self._metric = metric
if mode == "max":
self._metric_op = 1.
elif mode == "min":
@@ -88,41 +78,23 @@ class BayesOptSearch(SuggestionAlgorithm):
self.utility = byo.UtilityFunction(**utility_kwargs)
super(BayesOptSearch, self).__init__(
metric=self._metric, mode=mode, **kwargs)
def suggest(self, trial_id):
if self._num_live_trials() >= self._max_concurrent:
return None
new_trial = self.optimizer.suggest(self.utility)
self._live_trial_mapping[trial_id] = new_trial
return copy.deepcopy(new_trial)
def on_trial_result(self, trial_id, result):
pass
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
"""Notification for the completion of trial."""
if result:
self._process_result(trial_id, result, early_terminated)
self._process_result(trial_id, result)
del self._live_trial_mapping[trial_id]
def _process_result(self, trial_id, result, early_terminated=False):
if early_terminated and self._use_early_stopped is False:
return
def _process_result(self, trial_id, result):
self.optimizer.register(
params=self._live_trial_mapping[trial_id],
target=self._metric_op * result[self._metric])
def _num_live_trials(self):
return len(self._live_trial_mapping)
target=self._metric_op * result[self.metric])
def save(self, checkpoint_dir):
trials_object = self.optimizer
+3 -7
View File
@@ -3,7 +3,7 @@
import copy
import logging
from ray.tune.suggest import SuggestionAlgorithm
from ray.tune.suggest import Searcher
logger = logging.getLogger(__name__)
@@ -17,7 +17,7 @@ class _BOHBJobWrapper():
self.exception = None
class TuneBOHB(SuggestionAlgorithm):
class TuneBOHB(Searcher):
"""BOHB suggestion component.
@@ -104,11 +104,7 @@ class TuneBOHB(SuggestionAlgorithm):
hbs_wrapper = self.to_wrapper(trial_id, result)
self.bohber.new_result(hbs_wrapper)
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
del self.trial_to_params[trial_id]
if trial_id in self.paused:
self.paused.remove(trial_id)
+13 -34
View File
@@ -10,12 +10,12 @@ try: # Python 3 only -- needed for lint test.
except ImportError:
dragonfly = None
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest.suggestion import Searcher
logger = logging.getLogger(__name__)
class DragonflySearch(SuggestionAlgorithm):
class DragonflySearch(Searcher):
"""A wrapper around Dragonfly to provide trial suggestions.
Requires Dragonfly to be installed via ``pip install dragonfly-opt``.
@@ -23,8 +23,6 @@ class DragonflySearch(SuggestionAlgorithm):
Parameters:
optimizer (dragonfly.opt.BlackboxOptimiser): Optimizer provided
from dragonfly. Choose an optimiser that extends BlackboxOptimiser.
max_concurrent (int): Number of maximum concurrent trials. Defaults
to 10.
metric (str): The training result objective value attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
@@ -69,16 +67,13 @@ class DragonflySearch(SuggestionAlgorithm):
domain_config.domain.list_of_domains[0])
optimizer = EuclideanGPBandit(func_caller, ask_tell_mode=True)
algo = DragonflySearch(optimizer, max_concurrent=4,
metric="objective", mode="max")
algo = DragonflySearch(optimizer, metric="objective", mode="max")
tune.run(my_func, algo=algo)
"""
def __init__(self,
optimizer,
max_concurrent=10,
reward_attr=None,
metric="episode_reward_mean",
mode="max",
points_to_evaluate=None,
@@ -87,17 +82,8 @@ class DragonflySearch(SuggestionAlgorithm):
assert dragonfly is not None, """dragonfly must be installed!
You can install Dragonfly with the command:
`pip install dragonfly`."""
assert type(max_concurrent) is int and max_concurrent > 0
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if reward_attr is not None:
mode = "max"
metric = reward_attr
logger.warning(
"`reward_attr` is deprecated and will be removed in a future "
"version of Tune. "
"Setting `metric={}` and `mode=max`.".format(reward_attr))
self._initial_points = []
self._opt = optimizer
self._opt.initialise()
@@ -105,8 +91,6 @@ class DragonflySearch(SuggestionAlgorithm):
self._opt.tell([(points_to_evaluate, evaluated_rewards)])
elif points_to_evaluate:
self._initial_points = points_to_evaluate
self._max_concurrent = max_concurrent
self._metric = metric
# Dragonfly internally maximizes, so "min" => -1
if mode == "min":
self._metric_op = -1.
@@ -114,36 +98,31 @@ class DragonflySearch(SuggestionAlgorithm):
self._metric_op = 1.
self._live_trial_mapping = {}
super(DragonflySearch, self).__init__(
metric=self._metric, mode=mode, **kwargs)
metric=metric, mode=mode, **kwargs)
def suggest(self, trial_id):
if self._num_live_trials() >= self._max_concurrent:
return None
if self._initial_points:
suggested_config = self._initial_points[0]
del self._initial_points[0]
else:
suggested_config = self._opt.ask()
try:
suggested_config = self._opt.ask()
except Exception as exc:
logger.warning(
"Dragonfly errored when querying. This may be due to a "
"higher level of parallelism than supported. Try reducing "
"parallelism in the experiment: %s", str(exc))
return None
self._live_trial_mapping[trial_id] = suggested_config
return {"point": suggested_config}
def on_trial_result(self, trial_id, result):
pass
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
"""Passes result to Dragonfly unless early terminated or errored."""
trial_info = self._live_trial_mapping.pop(trial_id)
if result:
self._opt.tell([(trial_info,
self._metric_op * result[self._metric])])
def _num_live_trials(self):
return len(self._live_trial_mapping)
def save(self, checkpoint_dir):
trials_object = (self._initial_points, self._opt)
with open(checkpoint_dir, "wb") as outputFile:
+34 -57
View File
@@ -11,12 +11,12 @@ except ImportError:
hpo = None
from ray.tune.error import TuneError
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest import Searcher
logger = logging.getLogger(__name__)
class HyperOptSearch(SuggestionAlgorithm):
class HyperOptSearch(Searcher):
"""A wrapper around HyperOpt to provide trial suggestions.
Requires HyperOpt to be installed from source.
@@ -30,8 +30,6 @@ class HyperOptSearch(SuggestionAlgorithm):
space (dict): HyperOpt configuration. Parameters will be sampled
from this configuration and will be used to override
parameters generated in the variant generation process.
max_concurrent (int): Number of maximum concurrent trials. Defaults
to 10.
metric (str): The training result objective value attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
@@ -49,8 +47,8 @@ class HyperOptSearch(SuggestionAlgorithm):
results. Defaults to None.
gamma (float in range (0,1)): parameter governing the tree parzen
estimators suggestion algorithm. Defaults to 0.25.
use_early_stopped_trials (bool): Whether to use early terminated
trial results in the optimization process.
max_concurrent: Deprecated.
use_early_stopped_trials: Deprecated.
.. code-block:: python
@@ -65,42 +63,37 @@ class HyperOptSearch(SuggestionAlgorithm):
'activation': 0, # The index of "relu"
}]
algo = HyperOptSearch(
space, max_concurrent=4, metric="mean_loss", mode="min",
space, metric="mean_loss", mode="min",
points_to_evaluate=current_best_params)
"""
def __init__(self,
space,
max_concurrent=10,
reward_attr=None,
metric="episode_reward_mean",
mode="max",
points_to_evaluate=None,
n_initial_points=20,
random_state_seed=None,
gamma=0.25,
**kwargs):
assert hpo is not None, "HyperOpt must be installed!"
def __init__(
self,
space,
metric="episode_reward_mean",
mode="max",
points_to_evaluate=None,
n_initial_points=20,
random_state_seed=None,
gamma=0.25,
max_concurrent=None,
use_early_stopped_trials=None,
):
assert hpo is not None, (
"HyperOpt must be installed! Run `pip install hyperopt`.")
from hyperopt.fmin import generate_trials_to_calculate
assert type(max_concurrent) is int and max_concurrent > 0
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if reward_attr is not None:
mode = "max"
metric = reward_attr
logger.warning(
"`reward_attr` is deprecated and will be removed in a future "
"version of Tune. "
"Setting `metric={}` and `mode=max`.".format(reward_attr))
self._max_concurrent = max_concurrent
self._metric = metric
super(HyperOptSearch, self).__init__(
metric=metric,
mode=mode,
max_concurrent=max_concurrent,
use_early_stopped_trials=use_early_stopped_trials)
# hyperopt internally minimizes, so "max" => -1
if mode == "max":
self._metric_op = -1.
self.metric_op = -1.
elif mode == "min":
self._metric_op = 1.
self.metric_op = 1.
if n_initial_points is None:
self.algo = hpo.tpe.suggest
else:
@@ -124,13 +117,7 @@ class HyperOptSearch(SuggestionAlgorithm):
else:
self.rstate = np.random.RandomState(random_state_seed)
super(HyperOptSearch, self).__init__(
metric=self._metric, mode=mode, **kwargs)
def suggest(self, trial_id):
if self._num_live_trials() >= self._max_concurrent:
return None
if self._points_to_evaluate > 0:
new_trial = self._hpopt_trials.trials[self._points_to_evaluate - 1]
self._points_to_evaluate -= 1
@@ -167,11 +154,7 @@ class HyperOptSearch(SuggestionAlgorithm):
ho_trial["book_time"] = now
ho_trial["refresh_time"] = now
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
"""Notification for the completion of trial.
The result is internally negated when interacting with HyperOpt
@@ -185,18 +168,15 @@ class HyperOptSearch(SuggestionAlgorithm):
ho_trial["state"] = hpo.base.JOB_STATE_ERROR
ho_trial["misc"]["error"] = (str(TuneError), "Tune Error")
self._hpopt_trials.refresh()
else:
self._process_result(trial_id, result, early_terminated)
elif result:
self._process_result(trial_id, result)
del self._live_trial_mapping[trial_id]
def _process_result(self, trial_id, result, early_terminated=False):
def _process_result(self, trial_id, result):
ho_trial = self._get_hyperopt_trial(trial_id)
ho_trial["refresh_time"] = hpo.utils.coarse_utcnow()
if early_terminated and self._use_early_stopped is False:
ho_trial["state"] = hpo.base.JOB_STATE_ERROR
ho_trial["misc"]["error"] = (str(TuneError), "Tune Removed")
if not ho_trial:
return
ho_trial["refresh_time"] = hpo.utils.coarse_utcnow()
ho_trial["state"] = hpo.base.JOB_STATE_DONE
hp_result = self._to_hyperopt_result(result)
@@ -204,7 +184,7 @@ class HyperOptSearch(SuggestionAlgorithm):
self._hpopt_trials.refresh()
def _to_hyperopt_result(self, result):
return {"loss": self._metric_op * result[self._metric], "status": "ok"}
return {"loss": self.metric_op * result[self.metric], "status": "ok"}
def _get_hyperopt_trial(self, trial_id):
if trial_id not in self._live_trial_mapping:
@@ -214,9 +194,6 @@ class HyperOptSearch(SuggestionAlgorithm):
t for t in self._hpopt_trials.trials if t["tid"] == hyperopt_tid
][0]
def _num_live_trials(self):
return len(self._live_trial_mapping)
def save(self, checkpoint_dir):
trials_object = (self._hpopt_trials, self.rstate.get_state())
with open(checkpoint_dir, "wb") as outputFile:
+18 -44
View File
@@ -5,15 +5,16 @@ try:
except ImportError:
ng = None
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest import Searcher
logger = logging.getLogger(__name__)
class NevergradSearch(SuggestionAlgorithm):
class NevergradSearch(Searcher):
"""A wrapper around Nevergrad to provide trial suggestions.
Requires Nevergrad to be installed.
Nevergrad is an open source tool from Facebook for derivative free
optimization of parameters and/or hyperparameters. It features a wide
range of optimizers in a standard ask and tell interface. More information
@@ -26,20 +27,20 @@ class NevergradSearch(SuggestionAlgorithm):
the dimension of the optimizer output. Alternatively, set to None
if the optimizer is already instrumented with kwargs
(see nevergrad v0.2.0+).
max_concurrent (int): Number of maximum concurrent trials. Defaults
to 10.
metric (str): The training result objective value attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
use_early_stopped_trials (bool): Whether to use early terminated
trial results in the optimization process.
use_early_stopped_trials: Deprecated.
max_concurrent: Deprecated.
Example:
>>> from nevergrad.optimization import optimizerlib
>>> instrumentation = 1
>>> optimizer = optimizerlib.OnePlusOne(instrumentation, budget=100)
>>> algo = NevergradSearch(optimizer, ["lr"], max_concurrent=4,
>>> metric="mean_loss", mode="min")
.. code-block:: python
from nevergrad.optimization import optimizerlib
instrumentation = 1
optimizer = optimizerlib.OnePlusOne(instrumentation, budget=100)
algo = NevergradSearch(
optimizer, ["lr"], metric="mean_loss", mode="min")
Note:
In nevergrad v0.2.0+, optimizers can be instrumented.
@@ -51,34 +52,21 @@ class NevergradSearch(SuggestionAlgorithm):
>>> lr = inst.var.Array(1).bounded(1, 2).asfloat()
>>> instrumentation = inst.Instrumentation(lr=lr)
>>> optimizer = optimizerlib.OnePlusOne(instrumentation, budget=100)
>>> algo = NevergradSearch(optimizer, None, max_concurrent=4,
>>> metric="mean_loss", mode="min")
>>> algo = NevergradSearch(
optimizer, None, metric="mean_loss", mode="min")
"""
def __init__(self,
optimizer,
parameter_names,
max_concurrent=10,
reward_attr=None,
metric="episode_reward_mean",
mode="max",
**kwargs):
assert ng is not None, "Nevergrad must be installed!"
assert type(max_concurrent) is int and max_concurrent > 0
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if reward_attr is not None:
mode = "max"
metric = reward_attr
logger.warning(
"`reward_attr` is deprecated and will be removed in a future "
"version of Tune. "
"Setting `metric={}` and `mode=max`.".format(reward_attr))
self._max_concurrent = max_concurrent
self._parameters = parameter_names
self._metric = metric
# nevergrad.tell internally minimizes, so "max" => -1
if mode == "max":
self._metric_op = -1.
@@ -110,8 +98,6 @@ class NevergradSearch(SuggestionAlgorithm):
"dimension for non-instrumented optimizers")
def suggest(self, trial_id):
if self._num_live_trials() >= self._max_concurrent:
return None
suggested_config = self._nevergrad_opt.ask()
self._live_trial_mapping[trial_id] = suggested_config
# in v0.2.0+, output of ask() is a Candidate,
@@ -122,14 +108,7 @@ class NevergradSearch(SuggestionAlgorithm):
else:
return suggested_config.kwargs
def on_trial_result(self, trial_id, result):
pass
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
"""Notification for the completion of trial.
The result is internally negated when interacting with Nevergrad
@@ -137,20 +116,15 @@ class NevergradSearch(SuggestionAlgorithm):
as it minimizes on default.
"""
if result:
self._process_result(trial_id, result, early_terminated)
self._process_result(trial_id, result)
self._live_trial_mapping.pop(trial_id)
def _process_result(self, trial_id, result, early_terminated=False):
if early_terminated and self._use_early_stopped is False:
return
def _process_result(self, trial_id, result):
ng_trial_info = self._live_trial_mapping[trial_id]
self._nevergrad_opt.tell(ng_trial_info,
self._metric_op * result[self._metric])
def _num_live_trials(self):
return len(self._live_trial_mapping)
def save(self, checkpoint_dir):
trials_object = (self._nevergrad_opt, self._parameters)
with open(checkpoint_dir, "wb") as outputFile:
+36 -35
View File
@@ -1,10 +1,8 @@
import copy
import itertools
import logging
import numpy as np
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.experiment import convert_to_experiment_list
from ray.tune.suggest.suggestion import Searcher
logger = logging.getLogger(__name__)
@@ -12,6 +10,16 @@ TRIAL_INDEX = "__trial_index__"
"""str: A constant value representing the repeat index of the trial."""
def _warn_num_samples(searcher, num_samples):
if isinstance(searcher, Repeater) and num_samples % searcher.repeat:
logger.warning(
"`num_samples` is now expected to be the total number of trials, "
"including the repeat trials. For example, set num_samples=15 if "
"you intend to obtain 3 search algorithm suggestions and repeat "
"each suggestion 5 times. Any leftover trials "
"(num_samples mod repeat) will be ignored.")
class _TrialGroup:
"""Internal class for grouping trials of same parameters.
@@ -57,17 +65,22 @@ class _TrialGroup:
return len(self._trials)
class Repeater(SuggestionAlgorithm):
class Repeater(Searcher):
"""A wrapper algorithm for repeating trials of same parameters.
Set tune.run(num_samples=...) to be a multiple of `repeat`. For example,
set num_samples=15 if you intend to obtain 3 search algorithm suggestions
and repeat each suggestion 5 times. Any leftover trials
(num_samples mod repeat) will be ignored.
It is recommended that you do not run an early-stopping TrialScheduler
simultaneously.
Args:
search_alg (SearchAlgorithm): SearchAlgorithm object that the
Repeater will optimize. Note that the SearchAlgorithm
searcher (Searcher): Searcher object that the
Repeater will optimize. Note that the Searcher
will only see 1 trial among multiple repeated trials.
The result/metric passed to the SearchAlgorithm upon
The result/metric passed to the Searcher upon
trial completion will be averaged among all repeats.
repeat (int): Number of times to generate a trial with a repeated
configuration. Defaults to 1.
@@ -77,41 +90,23 @@ class Repeater(SuggestionAlgorithm):
"""
def __init__(self, search_alg, repeat=1, set_index=True):
self.search_alg = search_alg
self._repeat = repeat
def __init__(self, searcher, repeat=1, set_index=True):
self.searcher = searcher
self.repeat = repeat
self._set_index = set_index
self._groups = []
self._trial_id_to_group = {}
self._current_group = None
super(Repeater, self).__init__(
metric=self.search_alg.metric,
mode=self.search_alg.mode,
use_early_stopped_trials=self.search_alg._use_early_stopped)
def add_configurations(self, experiments):
"""Chains generator given experiment specifications.
Multiplies the number of trials by the repeat factor.
Arguments:
experiments (Experiment | list | dict): Experiments to run.
"""
experiment_list = convert_to_experiment_list(experiments)
for experiment in experiment_list:
self._trial_generator = itertools.chain(
self._trial_generator,
self._generate_trials(
experiment.spec.get("num_samples", 1) * self._repeat,
experiment.spec, experiment.name))
metric=self.searcher.metric, mode=self.searcher.mode)
def suggest(self, trial_id):
if self._current_group is None or self._current_group.full():
config = self.search_alg.suggest(trial_id)
config = self.searcher.suggest(trial_id)
if config is None:
return config
self._current_group = _TrialGroup(
trial_id, copy.deepcopy(config), max_trials=self._repeat)
trial_id, copy.deepcopy(config), max_trials=self.repeat)
self._groups.append(self._current_group)
index_in_group = 0
else:
@@ -139,15 +134,21 @@ class Repeater(SuggestionAlgorithm):
"Seen trials: {}".format(
trial_id, list(self._trial_id_to_group)))
trial_group = self._trial_id_to_group[trial_id]
if not result or self.search_alg.metric not in result:
if not result or self.searcher.metric not in result:
score = np.nan
else:
score = result[self.search_alg.metric]
score = result[self.searcher.metric]
trial_group.report(trial_id, score)
if trial_group.finished_reporting():
scores = trial_group.scores()
self.search_alg.on_trial_complete(
self.searcher.on_trial_complete(
trial_group.primary_trial_id,
result={self.search_alg.metric: np.nanmean(scores)},
result={self.searcher.metric: np.nanmean(scores)},
**kwargs)
def save(self, path):
self.searcher.save(path)
def restore(self, path):
self.searcher.restore(path)
+1 -7
View File
@@ -38,11 +38,7 @@ class SearchAlgorithm:
"""
pass
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
"""Notification for the completion of trial.
Arguments:
@@ -53,8 +49,6 @@ class SearchAlgorithm:
by manual termination.
error (bool): Defaults to False. True if the trial is in
the RUNNING state and errors.
early_terminated (bool): Defaults to False. True if the trial
is stopped while in PAUSED or PENDING state.
"""
pass
+6 -28
View File
@@ -7,17 +7,19 @@ try:
except ImportError:
sgo = None
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest import Searcher
logger = logging.getLogger(__name__)
class SigOptSearch(SuggestionAlgorithm):
class SigOptSearch(Searcher):
"""A wrapper around SigOpt to provide trial suggestions.
Requires SigOpt to be installed. Requires user to store their SigOpt
API key locally as an environment variable at `SIGOPT_KEY`.
This module manages its own concurrency.
Parameters:
space (list of dict): SigOpt configuration. Parameters will be sampled
from this configuration and will be used to override
@@ -70,17 +72,6 @@ class SigOptSearch(SuggestionAlgorithm):
"SigOpt API key must be stored as environ variable at SIGOPT_KEY"
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if reward_attr is not None:
mode = "max"
metric = reward_attr
logger.warning(
"`reward_attr` is deprecated and will be removed in a future "
"version of Tune. "
"Setting `metric={}` and `mode=max`.".format(reward_attr))
if "use_early_stopped_trials" in kwargs:
logger.warning(
"`use_early_stopped_trials` is not used in SigOptSearch.")
self._max_concurrent = max_concurrent
self._metric = metric
if mode == "max":
@@ -101,9 +92,6 @@ class SigOptSearch(SuggestionAlgorithm):
super(SigOptSearch, self).__init__(metric=metric, mode=mode, **kwargs)
def suggest(self, trial_id):
if self._num_live_trials() >= self._max_concurrent:
return None
# Get new suggestion from SigOpt
suggestion = self.conn.experiments(
self.experiment.id).suggestions().create()
@@ -112,14 +100,7 @@ class SigOptSearch(SuggestionAlgorithm):
return copy.deepcopy(suggestion.assignments)
def on_trial_result(self, trial_id, result):
pass
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
"""Notification for the completion of trial.
If a trial fails, it will be reported as a failed Observation, telling
@@ -135,15 +116,12 @@ class SigOptSearch(SuggestionAlgorithm):
)
# Update the experiment object
self.experiment = self.conn.experiments(self.experiment.id).fetch()
elif error or early_terminated:
elif error:
# Reports a failed Observation
self.conn.experiments(self.experiment.id).observations().create(
failed=True, suggestion=self._live_trial_mapping[trial_id].id)
del self._live_trial_mapping[trial_id]
def _num_live_trials(self):
return len(self._live_trial_mapping)
def save(self, checkpoint_dir):
trials_object = (self.conn, self.experiment)
with open(checkpoint_dir, "wb") as outputFile:
+14 -40
View File
@@ -5,7 +5,7 @@ try:
except ImportError:
sko = None
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest import Searcher
logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ def _validate_warmstart(parameter_names, points_to_evaluate,
" do not match.")
class SkOptSearch(SuggestionAlgorithm):
class SkOptSearch(Searcher):
"""A wrapper around skopt to provide trial suggestions.
Requires skopt to be installed.
@@ -50,8 +50,6 @@ class SkOptSearch(SuggestionAlgorithm):
from skopt.
parameter_names (list): List of parameter names. Should match
the dimension of the optimizer output.
max_concurrent (int): Number of maximum concurrent trials. Defaults
to 10.
metric (str): The training result objective value attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
@@ -66,8 +64,8 @@ class SkOptSearch(SuggestionAlgorithm):
as a list so the optimiser can be told the results without
needing to re-compute the trial. Must be the same length as
points_to_evaluate. (See tune/examples/skopt_example.py)
use_early_stopped_trials (bool): Whether to use early terminated
trial results in the optimization process.
max_concurrent: Deprecated.
use_early_stopped_trials: Deprecated.
Example:
>>> from skopt import Optimizer
@@ -75,7 +73,6 @@ class SkOptSearch(SuggestionAlgorithm):
>>> current_best_params = [[10, 0], [15, -20]]
>>> algo = SkOptSearch(optimizer,
>>> ["width", "height"],
>>> max_concurrent=4,
>>> metric="mean_loss",
>>> mode="min",
>>> points_to_evaluate=current_best_params)
@@ -84,37 +81,30 @@ class SkOptSearch(SuggestionAlgorithm):
def __init__(self,
optimizer,
parameter_names,
max_concurrent=10,
reward_attr=None,
metric="episode_reward_mean",
mode="max",
points_to_evaluate=None,
evaluated_rewards=None,
**kwargs):
max_concurrent=None,
use_early_stopped_trials=None):
assert sko is not None, """skopt must be installed!
You can install Skopt with the command:
`pip install scikit-optimize`."""
assert type(max_concurrent) is int and max_concurrent > 0
_validate_warmstart(parameter_names, points_to_evaluate,
evaluated_rewards)
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if reward_attr is not None:
mode = "max"
metric = reward_attr
logger.warning(
"`reward_attr` is deprecated and will be removed in a future "
"version of Tune. "
"Setting `metric={}` and `mode=max`.".format(reward_attr))
super(SkOptSearch, self).__init__(
metric=metric,
mode=mode,
max_concurrent=max_concurrent,
use_early_stopped_trials=use_early_stopped_trials)
self._initial_points = []
if points_to_evaluate and evaluated_rewards:
optimizer.tell(points_to_evaluate, evaluated_rewards)
elif points_to_evaluate:
self._initial_points = points_to_evaluate
self._max_concurrent = max_concurrent
self._parameters = parameter_names
self._metric = metric
# Skopt internally minimizes, so "max" => -1
if mode == "max":
self._metric_op = -1.
@@ -122,12 +112,8 @@ class SkOptSearch(SuggestionAlgorithm):
self._metric_op = 1.
self._skopt_opt = optimizer
self._live_trial_mapping = {}
super(SkOptSearch, self).__init__(
metric=self._metric, mode=mode, **kwargs)
def suggest(self, trial_id):
if self._num_live_trials() >= self._max_concurrent:
return None
if self._initial_points:
suggested_config = self._initial_points[0]
del self._initial_points[0]
@@ -136,14 +122,7 @@ class SkOptSearch(SuggestionAlgorithm):
self._live_trial_mapping[trial_id] = suggested_config
return dict(zip(self._parameters, suggested_config))
def on_trial_result(self, trial_id, result):
pass
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
"""Notification for the completion of trial.
The result is internally negated when interacting with Skopt
@@ -152,19 +131,14 @@ class SkOptSearch(SuggestionAlgorithm):
"""
if result:
self._process_result(trial_id, result, early_terminated)
self._process_result(trial_id, result)
self._live_trial_mapping.pop(trial_id)
def _process_result(self, trial_id, result, early_terminated=False):
if early_terminated and self._use_early_stopped is False:
return
def _process_result(self, trial_id, result):
skopt_trial_info = self._live_trial_mapping[trial_id]
self._skopt_opt.tell(skopt_trial_info,
self._metric_op * result[self._metric])
def _num_live_trials(self):
return len(self._live_trial_mapping)
def save(self, checkpoint_dir):
trials_object = (self._initial_points, self._skopt_opt)
with open(checkpoint_dir, "wb") as outputFile:
+251 -102
View File
@@ -1,5 +1,5 @@
import itertools
import copy
import logging
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list
@@ -9,129 +9,127 @@ from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict
from ray.tune.trial import Trial
from ray.tune.utils import merge_dicts, flatten_dict
logger = logging.getLogger(__name__)
class SuggestionAlgorithm(SearchAlgorithm):
"""Abstract class for suggestion-based algorithms.
Custom search algorithms can extend this class easily by overriding the
def _warn_on_repeater(searcher, total_samples):
from ray.tune.suggest.repeater import _warn_num_samples
_warn_num_samples(searcher, total_samples)
class Searcher:
"""Abstract class for wrapping suggesting algorithms.
Custom algorithms can extend this class easily by overriding the
`suggest` method provide generated parameters for the trials.
Any subclass that implements ``__init__`` must also call the
constructor of this class: ``super(Subclass, self).__init__(...)``.
To track suggestions and their corresponding evaluations, the method
`suggest` will be passed a trial_id, which will be used in
subsequent notifications.
Args:
metric (str): The training result objective value attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
.. code-block:: python
suggester = SuggestionAlgorithm()
suggester.add_configurations({ ... })
new_parameters = suggester.suggest()
suggester.on_trial_complete(trial_id, result)
better_parameters = suggester.suggest()
class ExampleSearch(Searcher):
def __init__(self, metric="mean_loss", mode="min", **kwargs):
super(ExampleSearch, self).__init__(
metric=metric, mode=mode, **kwargs)
self.optimizer = Optimizer()
self.configurations = {}
def suggest(self, trial_id):
configuration = self.optimizer.query()
self.configurations[trial_id] = configuration
def on_trial_complete(self, trial_id, result, **kwargs):
configuration = self.configurations[trial_id]
if result and self.metric in result:
self.optimizer.update(configuration, result[self.metric])
tune.run(trainable_function, search_alg=ExampleSearch())
"""
def __init__(self, metric=None, mode="max", use_early_stopped_trials=True):
"""Constructs a generator given experiment specifications."""
self._parser = make_parser()
self._trial_generator = []
self._counter = 0
def __init__(self,
metric="episode_reward_mean",
mode="max",
max_concurrent=None,
use_early_stopped_trials=None):
if use_early_stopped_trials is False:
raise DeprecationWarning(
"Early stopped trials are now always used. If this is a "
"problem, file an issue: https://github.com/ray-project/ray.")
if max_concurrent is not None:
raise DeprecationWarning(
"max_concurrent is now deprecated for this search algorithm. "
"Please use tune.suggest.ConcurrencyLimiter instead.")
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
self._metric = metric
assert mode in ["min", "max"]
self._mode = mode
self._use_early_stopped = use_early_stopped_trials
self._finished = False
def add_configurations(self, experiments):
"""Chains generator given experiment specifications.
def on_trial_result(self, trial_id, result):
"""Optional notification for result during training.
Arguments:
experiments (Experiment | list | dict): Experiments to run.
Note that by default, the result dict may include NaNs or
may not include the optimization metric. It is up to the
subclass implementation to preprocess the result to
avoid breaking the optimization process.
Args:
trial_id (str): A unique string ID for the trial.
result (dict): Dictionary of metrics for current training progress.
Note that the result dict may include NaNs or
may not include the optimization metric. It is up to the
subclass implementation to preprocess the result to
avoid breaking the optimization process.
"""
experiment_list = convert_to_experiment_list(experiments)
for experiment in experiment_list:
self._trial_generator = itertools.chain(
self._trial_generator,
self._generate_trials(
experiment.spec.get("num_samples", 1), experiment.spec,
experiment.name))
pass
def next_trials(self):
"""Provides a batch of Trial objects to be queued into the TrialRunner.
def on_trial_complete(self, trial_id, result=None, error=False):
"""Notification for the completion of trial.
A batch ends when self._trial_generator returns None.
Typically, this method is used for notifying the underlying
optimizer of the result.
Args:
trial_id (str): A unique string ID for the trial.
result (dict): Dictionary of metrics for current training progress.
Note that the result dict may include NaNs or
may not include the optimization metric. It is up to the
subclass implementation to preprocess the result to
avoid breaking the optimization process. Upon errors, this
may also be None.
error (bool): True if the training process raised an error.
Returns:
trials (list): Returns a list of trials.
"""
trials = []
for trial in self._trial_generator:
if trial is None:
return trials
trials += [trial]
self.set_finished()
return trials
def _generate_trials(self, num_samples, experiment_spec, output_path=""):
"""Generates trials with configurations from `suggest`.
Creates a trial_id that is passed into `suggest`.
Yields:
Trial objects constructed according to `spec`
"""
if "run" not in experiment_spec:
raise TuneError("Must specify `run` in {}".format(experiment_spec))
for _ in range(num_samples):
trial_id = Trial.generate_id()
while True:
suggested_config = self.suggest(trial_id)
if suggested_config is None:
yield None
else:
break
spec = copy.deepcopy(experiment_spec)
spec["config"] = merge_dicts(spec["config"],
copy.deepcopy(suggested_config))
flattened_config = resolve_nested_dict(spec["config"])
self._counter += 1
tag = "{0}_{1}".format(
str(self._counter), format_vars(flattened_config))
yield create_trial_from_spec(
spec,
output_path,
self._parser,
evaluated_params=flatten_dict(suggested_config),
experiment_tag=tag,
trial_id=trial_id)
raise NotImplementedError
def suggest(self, trial_id):
"""Queries the algorithm to retrieve the next set of parameters.
Arguments:
trial_id: Trial ID used for subsequent notifications.
trial_id (str): Trial ID used for subsequent notifications.
Returns:
dict|None: Configuration for a trial, if possible.
Else, returns None, which will temporarily stop the
TrialRunner from querying.
Example:
>>> suggester = SuggestionAlgorithm(max_concurrent=1)
>>> suggester.add_configurations({ ... })
>>> parameters_1 = suggester.suggest()
>>> parameters_2 = suggester.suggest()
>>> parameters_2 is None
>>> suggester.on_trial_complete(trial_id, result)
>>> parameters_2 = suggester.suggest()
>>> parameters_2 is not None
"""
raise NotImplementedError
def save(self, checkpoint_dir):
"""Save function for this object."""
raise NotImplementedError
def restore(self, checkpoint_dir):
"""Restore function for this object."""
raise NotImplementedError
@property
@@ -145,18 +143,156 @@ class SuggestionAlgorithm(SearchAlgorithm):
return self._mode
class _MockSuggestionAlgorithm(SuggestionAlgorithm):
def __init__(self, max_concurrent=2, **kwargs):
self._max_concurrent = max_concurrent
class ConcurrencyLimiter(Searcher):
"""A wrapper algorithm for limiting the number of concurrent trials.
Args:
searcher (Searcher): Searcher object that the
ConcurrencyLimiter will manage.
Example:
.. code-block:: python
from ray.tune.suggest import ConcurrencyLimiter
search_alg = HyperOptSearch(metric="accuracy")
search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
tune.run(trainable, search_alg=search_alg)
"""
def __init__(self, searcher, max_concurrent):
assert type(max_concurrent) is int and max_concurrent > 0
self.searcher = searcher
self.max_concurrent = max_concurrent
self.live_trials = set()
super(ConcurrencyLimiter, self).__init__(
metric=self.searcher.metric, mode=self.searcher.mode)
def suggest(self, trial_id):
if len(self.live_trials) >= self.max_concurrent:
return
self.live_trials.add(trial_id)
return self.searcher.suggest(trial_id)
def on_trial_complete(self, trial_id, result=None, error=False):
if trial_id not in self.live_trials:
return
else:
self.searcher.on_trial_complete(
trial_id, result=result, error=error)
self.live_trials.remove(trial_id)
def save(self, checkpoint_dir):
self.searcher.save(checkpoint_dir)
def restore(self, checkpoint_dir):
self.searcher.restore(checkpoint_dir)
class SearchGenerator(SearchAlgorithm):
"""Generates trials to be passed to the TrialRunner.
Uses the provided ``searcher`` object to generate trials. This class
transparently handles repeating trials with score aggregation
without embedding logic into the Searcher.
Args:
searcher: Search object that subclasses the Searcher base class. This
is then used for generating new hyperparameter samples.
"""
def __init__(self, searcher):
assert issubclass(
type(searcher),
Searcher), ("Searcher should be subclassing Searcher.")
self.searcher = searcher
self._parser = make_parser()
self._experiment = None
self._counter = 0 # Keeps track of number of trials created.
self._total_samples = 0 # int: total samples to evaluate.
self._finished = False
def add_configurations(self, experiments):
"""Registers experiment specifications.
Arguments:
experiments (Experiment | list | dict): Experiments to run.
"""
logger.debug("added configurations")
experiment_list = convert_to_experiment_list(experiments)
assert len(experiment_list) == 1, (
"SearchAlgorithms can only support 1 experiment at a time.")
self._experiment = experiment_list[0]
experiment_spec = self._experiment.spec
self._total_samples = experiment_spec.get("num_samples", 1)
_warn_on_repeater(self.searcher, self._total_samples)
if "run" not in experiment_spec:
raise TuneError("Must specify `run` in {}".format(experiment_spec))
def next_trials(self):
"""Provides a batch of Trial objects to be queued into the TrialRunner.
Returns:
List[Trial]: A list of trials for the Runner to consume.
"""
trials = []
while not self.is_finished():
trial = self.create_trial_if_possible(self._experiment.spec,
self._experiment.name)
if trial is None:
break
trials.append(trial)
return trials
def create_trial_if_possible(self, experiment_spec, output_path):
logger.debug("creating trial")
trial_id = Trial.generate_id()
suggested_config = self.searcher.suggest(trial_id)
if suggested_config is None:
return
spec = copy.deepcopy(experiment_spec)
spec["config"] = merge_dicts(spec["config"],
copy.deepcopy(suggested_config))
# Create a new trial_id if duplicate trial is created
flattened_config = resolve_nested_dict(spec["config"])
self._counter += 1
tag = "{0}_{1}".format(
str(self._counter), format_vars(flattened_config))
trial = create_trial_from_spec(
spec,
output_path,
self._parser,
evaluated_params=flatten_dict(suggested_config),
experiment_tag=tag,
trial_id=trial_id)
return trial
def on_trial_result(self, trial_id, result):
"""Notifies the underlying searcher."""
self.searcher.on_trial_result(trial_id, result)
def on_trial_complete(self, trial_id, result=None, error=False):
self.searcher.on_trial_complete(
trial_id=trial_id, result=result, error=error)
def is_finished(self):
return self._counter >= self._total_samples or self._finished
class _MockSearcher(Searcher):
def __init__(self, **kwargs):
self.live_trials = {}
self.counter = {"result": 0, "complete": 0}
self.final_results = []
self.stall = False
self.results = []
super(_MockSuggestionAlgorithm, self).__init__(**kwargs)
super(_MockSearcher, self).__init__(**kwargs)
def suggest(self, trial_id):
if len(self.live_trials) < self._max_concurrent and not self.stall:
if not self.stall:
self.live_trials[trial_id] = 1
return {"test_variable": 2}
return None
@@ -165,16 +301,29 @@ class _MockSuggestionAlgorithm(SuggestionAlgorithm):
self.counter["result"] += 1
self.results += [result]
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
self.counter["complete"] += 1
if result:
self._process_result(result, early_terminated)
del self.live_trials[trial_id]
self._process_result(result)
if trial_id in self.live_trials:
del self.live_trials[trial_id]
def _process_result(self, result, early_terminated):
if early_terminated and self._use_early_stopped:
self.final_results += [result]
def _process_result(self, result):
self.final_results += [result]
class _MockSuggestionAlgorithm(SearchGenerator):
def __init__(self, max_concurrent=None, **kwargs):
self.searcher = _MockSearcher(**kwargs)
if max_concurrent:
self.searcher = ConcurrencyLimiter(
self.searcher, max_concurrent=max_concurrent)
super(_MockSuggestionAlgorithm, self).__init__(self.searcher)
@property
def live_trials(self):
return self.searcher.live_trials
@property
def results(self):
return self.searcher.results
+3 -27
View File
@@ -7,12 +7,12 @@ try:
except ImportError:
zoopt = None
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest import Searcher
logger = logging.getLogger(__name__)
class ZOOptSearch(SuggestionAlgorithm):
class ZOOptSearch(Searcher):
"""A wrapper around ZOOpt to provide trial suggestions.
Requires zoopt package (>=0.4.0) to be installed. You can install it
@@ -26,8 +26,6 @@ class ZOOptSearch(SuggestionAlgorithm):
For continuous dimensions: (continuous, search_range, precision);
For discrete dimensions: (discrete, search_range, has_order).
More details can be found in zoopt package.
max_concurrent (int): Number of maximum concurrent trials.
Defaults to 10.
metric (str): The training result objective value attribute.
Defaults to "episode_reward_mean".
mode (str): One of {min, max}. Determines whether objective is
@@ -59,7 +57,6 @@ class ZOOptSearch(SuggestionAlgorithm):
algo="Asracos", # only support Asracos currently
budget=config["num_samples"],
dim_dict=dim_dict,
max_concurrent=4,
metric="mean_loss",
mode="min")
@@ -76,20 +73,17 @@ class ZOOptSearch(SuggestionAlgorithm):
algo="asracos",
budget=None,
dim_dict=None,
max_concurrent=10,
metric="episode_reward_mean",
mode="min",
**kwargs):
assert zoopt is not None, "Zoopt not found - please install zoopt."
assert budget is not None, "`budget` should not be None!"
assert dim_dict is not None, "`dim_list` should not be None!"
assert type(max_concurrent) is int and max_concurrent > 0
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
_algo = algo.lower()
assert _algo in ["asracos", "sracos"
], "`algo` must be in ['asracos', 'sracos'] currently"
self._max_concurrent = max_concurrent
self._metric = metric
if mode == "max":
self._metric_op = -1.
@@ -116,9 +110,6 @@ class ZOOptSearch(SuggestionAlgorithm):
metric=self._metric, mode=mode, **kwargs)
def suggest(self, trial_id):
if self._num_live_trials() >= self._max_concurrent:
return None
_solution = self.optimizer.suggest()
if _solution:
self.solution_dict[str(trial_id)] = _solution
@@ -127,14 +118,7 @@ class ZOOptSearch(SuggestionAlgorithm):
self._live_trial_mapping[trial_id] = new_trial
return copy.deepcopy(new_trial)
def on_trial_result(self, trial_id, result):
pass
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
def on_trial_complete(self, trial_id, result=None, error=False):
"""Notification for the completion of trial."""
if result:
_solution = self.solution_dict[str(trial_id)]
@@ -142,17 +126,9 @@ class ZOOptSearch(SuggestionAlgorithm):
_solution, self._metric_op * result[self._metric])
if _best_solution_so_far:
self.best_solution_list.append(_best_solution_so_far)
self._process_result(trial_id, result, early_terminated)
del self._live_trial_mapping[trial_id]
def _process_result(self, trial_id, result, early_terminated=False):
if early_terminated and self._use_early_stopped is False:
return
def _num_live_trials(self):
return len(self._live_trial_mapping)
def save(self, checkpoint_dir):
trials_object = self.optimizer
with open(checkpoint_dir, "wb") as output:
+56 -62
View File
@@ -15,7 +15,7 @@ from ray.tune.trial_runner import TrialRunner
from ray.tune.resources import Resources, json_to_resources, resources_to_json
from ray.tune.suggest.repeater import Repeater
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
SuggestionAlgorithm)
SearchGenerator, Searcher)
class TrialRunnerTest3(unittest.TestCase):
@@ -30,11 +30,11 @@ class TrialRunnerTest3(unittest.TestCase):
def on_step_begin(self, trialrunner):
self._update_avail_resources()
cnt = self.pre_step if hasattr(self, "pre_step") else 0
setattr(self, "pre_step", cnt + 1)
self.pre_step = cnt + 1
def on_step_end(self, trialrunner):
cnt = self.pre_step if hasattr(self, "post_step") else 0
setattr(self, "post_step", 1 + cnt)
self.post_step = 1 + cnt
import types
runner.trial_executor.on_step_begin = types.MethodType(
@@ -101,9 +101,10 @@ class TrialRunnerTest3(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=2)
experiment_spec = {"run": "__fake", "stop": {"training_iteration": 2}}
experiments = [Experiment.from_json("test", experiment_spec)]
searcher = _MockSuggestionAlgorithm(max_concurrent=10)
searcher.add_configurations(experiments)
runner = TrialRunner(search_alg=searcher)
search_alg = _MockSuggestionAlgorithm()
searcher = search_alg.searcher
search_alg.add_configurations(experiments)
runner = TrialRunner(search_alg=search_alg)
runner.step()
trials = runner.get_trials()
self.assertEqual(trials[0].status, Trial.RUNNING)
@@ -122,7 +123,7 @@ class TrialRunnerTest3(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=2)
experiment_spec = {"run": "__fake", "stop": {"training_iteration": 1}}
experiments = [Experiment.from_json("test", experiment_spec)]
searcher = _MockSuggestionAlgorithm(max_concurrent=10)
searcher = _MockSuggestionAlgorithm()
searcher.add_configurations(experiments)
runner = TrialRunner(search_alg=searcher)
runner.step()
@@ -147,7 +148,7 @@ class TrialRunnerTest3(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=2)
experiment_spec = {"run": "__fake", "stop": {"training_iteration": 2}}
experiments = [Experiment.from_json("test", experiment_spec)]
searcher = _MockSuggestionAlgorithm(max_concurrent=10)
searcher = _MockSuggestionAlgorithm()
searcher.add_configurations(experiments)
runner = TrialRunner(search_alg=searcher, scheduler=_MockScheduler())
runner.step()
@@ -162,30 +163,6 @@ class TrialRunnerTest3(unittest.TestCase):
self.assertTrue(searcher.is_finished())
self.assertTrue(runner.is_finished())
def testSearchAlgSchedulerEarlyStop(self):
"""Early termination notif to Searcher can be turned off."""
class _MockScheduler(FIFOScheduler):
def on_trial_result(self, *args, **kwargs):
return TrialScheduler.STOP
ray.init(num_cpus=4, num_gpus=2)
experiment_spec = {"run": "__fake", "stop": {"training_iteration": 2}}
experiments = [Experiment.from_json("test", experiment_spec)]
searcher = _MockSuggestionAlgorithm(use_early_stopped_trials=True)
searcher.add_configurations(experiments)
runner = TrialRunner(search_alg=searcher, scheduler=_MockScheduler())
runner.step()
runner.step()
self.assertEqual(len(searcher.final_results), 1)
searcher = _MockSuggestionAlgorithm(use_early_stopped_trials=False)
searcher.add_configurations(experiments)
runner = TrialRunner(search_alg=searcher, scheduler=_MockScheduler())
runner.step()
runner.step()
self.assertEqual(len(searcher.final_results), 0)
def testSearchAlgStalled(self):
"""Checks that runner and searcher state is maintained when stalled."""
ray.init(num_cpus=4, num_gpus=2)
@@ -197,9 +174,10 @@ class TrialRunnerTest3(unittest.TestCase):
}
}
experiments = [Experiment.from_json("test", experiment_spec)]
searcher = _MockSuggestionAlgorithm(max_concurrent=1)
searcher.add_configurations(experiments)
runner = TrialRunner(search_alg=searcher)
search_alg = _MockSuggestionAlgorithm(max_concurrent=1)
search_alg.add_configurations(experiments)
searcher = search_alg.searcher
runner = TrialRunner(search_alg=search_alg)
runner.step()
trials = runner.get_trials()
self.assertEqual(trials[0].status, Trial.RUNNING)
@@ -219,7 +197,7 @@ class TrialRunnerTest3(unittest.TestCase):
self.assertEqual(len(searcher.live_trials), 0)
self.assertTrue(all(trial.is_finished() for trial in trials))
self.assertFalse(searcher.is_finished())
self.assertFalse(search_alg.is_finished())
self.assertFalse(runner.is_finished())
searcher.stall = False
@@ -232,25 +210,27 @@ class TrialRunnerTest3(unittest.TestCase):
runner.step()
self.assertEqual(trials[2].status, Trial.TERMINATED)
self.assertEqual(len(searcher.live_trials), 0)
self.assertTrue(searcher.is_finished())
self.assertTrue(search_alg.is_finished())
self.assertTrue(runner.is_finished())
def testSearchAlgFinishes(self):
"""Empty SearchAlg changing state in `next_trials` does not crash."""
class FinishFastAlg(SuggestionAlgorithm):
class FinishFastAlg(_MockSuggestionAlgorithm):
_index = 0
def next_trials(self):
spec = self._experiment.spec
trials = []
if self._index < spec["num_samples"]:
trial = Trial(
spec.get("run"), stopping_criterion=spec.get("stop"))
trials.append(trial)
self._index += 1
for trial in self._trial_generator:
trials += [trial]
break
if self._index > 4:
self.set_finished()
return trials
def suggest(self, trial_id):
@@ -406,7 +386,7 @@ class TrialRunnerTest3(unittest.TestCase):
tmpdir = tempfile.mkdtemp()
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
runner.add_trial(trial)
for i in range(5):
for _ in range(5):
runner.step()
# force checkpoint
runner.checkpoint()
@@ -427,14 +407,14 @@ class TrialRunnerTest3(unittest.TestCase):
tmpdir = tempfile.mkdtemp()
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
runner.add_trial(trial)
for i in range(5):
for _ in range(5):
runner.step()
# force checkpoint
runner.checkpoint()
self.assertEquals(count_checkpoints(tmpdir), 1)
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
for i in range(5):
for _ in range(5):
runner2.step()
self.assertEquals(count_checkpoints(tmpdir), 2)
@@ -473,50 +453,64 @@ class SearchAlgorithmTest(unittest.TestCase):
_register_all()
def testNestedSuggestion(self):
class TestSuggestion(SuggestionAlgorithm):
class TestSuggestion(Searcher):
def suggest(self, trial_id):
return {"a": {"b": {"c": {"d": 4, "e": 5}}}}
alg = TestSuggestion()
searcher = TestSuggestion()
alg = SearchGenerator(searcher)
alg.add_configurations({"test": {"run": "__fake"}})
trial = alg.next_trials()[0]
self.assertTrue("e=5" in trial.experiment_tag)
self.assertTrue("d=4" in trial.experiment_tag)
def _test_repeater(self, repeat):
def _test_repeater(self, num_samples, repeat):
ray.init(num_cpus=4)
class TestSuggestion(SuggestionAlgorithm):
count = 0
class TestSuggestion(Searcher):
index = 0
def suggest(self, trial_id):
return {"test_variable": 5}
self.index += 1
return {"test_variable": 5 + self.index}
def on_trial_complete(self, *args, **kwargs):
self.count += 1
return
alg = TestSuggestion(metric="episode_reward_mean")
repeat_alg = Repeater(alg, repeat=repeat, set_index=False)
searcher = TestSuggestion(metric="episode_reward_mean")
repeat_searcher = Repeater(searcher, repeat=repeat, set_index=False)
alg = SearchGenerator(repeat_searcher)
experiment_spec = {
"run": "__fake",
"num_samples": 1,
"num_samples": num_samples,
"stop": {
"training_iteration": 1
}
}
repeat_alg.add_configurations({"test": experiment_spec})
runner = TrialRunner(search_alg=repeat_alg)
for i in range(repeat * 2):
alg.add_configurations({"test": experiment_spec})
runner = TrialRunner(search_alg=alg)
while not runner.is_finished():
runner.step()
trials = runner.get_trials()
self.assertEquals(len(trials), repeat)
return runner.get_trials()
def testRepeat1(self):
self._test_repeater(repeat=1)
trials = self._test_repeater(num_samples=2, repeat=1)
self.assertEquals(len(trials), 2)
parameter_set = {t.evaluated_params["test_variable"] for t in trials}
self.assertEquals(len(parameter_set), 2)
def testRepeat4(self):
self._test_repeater(repeat=4)
trials = self._test_repeater(num_samples=12, repeat=4)
self.assertEquals(len(trials), 12)
parameter_set = {t.evaluated_params["test_variable"] for t in trials}
self.assertEquals(len(parameter_set), 3)
def testOddRepeat(self):
trials = self._test_repeater(num_samples=11, repeat=5)
self.assertEquals(len(trials), 11)
parameter_set = {t.evaluated_params["test_variable"] for t in trials}
self.assertEquals(len(parameter_set), 3)
class ResourcesTest(unittest.TestCase):
@@ -643,6 +643,7 @@ class BOHBSuite(unittest.TestCase):
sched = HyperBandForBOHB(max_t=3, reduction_factor=3)
runner = _MockTrialRunner(sched)
runner._search_alg = MagicMock()
runner._search_alg.searcher = MagicMock()
trials = [Trial("__fake") for i in range(3)]
for t in trials:
runner.add_trial(t)
@@ -656,8 +657,8 @@ class BOHBSuite(unittest.TestCase):
decision = sched.on_trial_result(runner, trials[-1], spy_result)
self.assertEqual(decision, TrialScheduler.STOP)
sched.choose_trial_to_run(runner)
self.assertEqual(runner._search_alg.on_pause.call_count, 2)
self.assertEqual(runner._search_alg.on_unpause.call_count, 1)
self.assertEqual(runner._search_alg.searcher.on_pause.call_count, 2)
self.assertEqual(runner._search_alg.searcher.on_unpause.call_count, 1)
self.assertTrue("hyperband_info" in spy_result)
self.assertEquals(spy_result["hyperband_info"]["budget"], 1)
@@ -668,6 +669,7 @@ class BOHBSuite(unittest.TestCase):
sched = HyperBandForBOHB(max_t=3, reduction_factor=3, mode="min")
runner = _MockTrialRunner(sched)
runner._search_alg = MagicMock()
runner._search_alg.searcher = MagicMock()
trials = [Trial("__fake") for i in range(3)]
for t in trials:
runner.add_trial(t)
@@ -681,7 +683,7 @@ class BOHBSuite(unittest.TestCase):
decision = sched.on_trial_result(runner, trials[-1], spy_result)
self.assertEqual(decision, TrialScheduler.CONTINUE)
sched.choose_trial_to_run(runner)
self.assertEqual(runner._search_alg.on_pause.call_count, 2)
self.assertEqual(runner._search_alg.searcher.on_pause.call_count, 2)
self.assertTrue("hyperband_info" in spy_result)
self.assertEquals(spy_result["hyperband_info"]["budget"], 1)
+14 -14
View File
@@ -13,6 +13,7 @@ import ray
from ray import tune
from ray.test_utils import recursive_fnmatch
from ray.rllib import _register_all
from ray.tune.suggest import ConcurrencyLimiter
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.bayesopt import BayesOptSearch
from ray.tune.suggest.skopt import SkOptSearch
@@ -132,7 +133,7 @@ class AutoInitTest(unittest.TestCase):
class AbstractWarmStartTest:
def setUp(self):
ray.init(local_mode=True)
ray.init(num_cpus=1, local_mode=True)
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
@@ -146,20 +147,26 @@ class AbstractWarmStartTest:
def run_exp_1(self):
np.random.seed(162)
search_alg, cost = self.set_basic_conf()
results_exp_1 = tune.run(cost, num_samples=5, search_alg=search_alg)
search_alg = ConcurrencyLimiter(search_alg, 1)
results_exp_1 = tune.run(
cost, num_samples=5, search_alg=search_alg, verbose=0)
self.log_dir = os.path.join(self.tmpdir, "warmStartTest.pkl")
search_alg.save(self.log_dir)
return results_exp_1
def run_exp_2(self):
search_alg2, cost = self.set_basic_conf()
search_alg2 = ConcurrencyLimiter(search_alg2, 1)
search_alg2.restore(self.log_dir)
return tune.run(cost, num_samples=5, search_alg=search_alg2)
return tune.run(cost, num_samples=5, search_alg=search_alg2, verbose=0)
def run_exp_3(self):
print("FULL RUN")
np.random.seed(162)
search_alg3, cost = self.set_basic_conf()
return tune.run(cost, num_samples=10, search_alg=search_alg3)
search_alg3 = ConcurrencyLimiter(search_alg3, 1)
return tune.run(
cost, num_samples=10, search_alg=search_alg3, verbose=0)
def testWarmStart(self):
results_exp_1 = self.run_exp_1()
@@ -185,10 +192,10 @@ class HyperoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
search_alg = HyperOptSearch(
space,
max_concurrent=1,
metric="loss",
mode="min",
random_state_seed=5)
random_state_seed=5,
n_initial_points=1)
return search_alg, cost
@@ -201,7 +208,6 @@ class BayesoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
search_alg = BayesOptSearch(
space,
max_concurrent=1,
metric="loss",
mode="min",
utility_kwargs={
@@ -223,7 +229,6 @@ class SkoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
search_alg = SkOptSearch(
optimizer, ["width", "height"],
max_concurrent=1,
metric="loss",
mode="min",
points_to_evaluate=previously_run_params,
@@ -242,11 +247,7 @@ class NevergradWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
mean_loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
search_alg = NevergradSearch(
optimizer,
parameter_names,
max_concurrent=1,
metric="mean_loss",
mode="min")
optimizer, parameter_names, metric="mean_loss", mode="min")
return search_alg, cost
@@ -305,7 +306,6 @@ class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
algo="Asracos", # only support ASRacos currently
budget=200,
dim_dict=dim_dict,
max_concurrent=1,
metric="loss",
mode="min")
+1 -32
View File
@@ -8,9 +8,7 @@ from ray.rllib import _register_all
from ray import tune
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.experiment import Experiment
from ray.tune.suggest import grid_search, BasicVariantGenerator
from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm
from ray.tune.suggest.variant_generator import (RecursiveDependencyError,
resolve_nested_dict)
@@ -301,36 +299,7 @@ class VariantGeneratorTest(unittest.TestCase):
except RecursiveDependencyError as e:
assert "`foo` recursively depends on" in str(e), e
else:
assert False
def testMaxConcurrentSuggestions(self):
"""Checks that next_trials() supports throttling."""
experiment_spec = {
"run": "PPO",
"num_samples": 6,
}
experiments = [Experiment.from_json("test", experiment_spec)]
searcher = _MockSuggestionAlgorithm(max_concurrent=4)
searcher.add_configurations(experiments)
trials = searcher.next_trials()
self.assertEqual(len(trials), 4)
self.assertEqual(searcher.next_trials(), [])
finished_trial = trials.pop()
searcher.on_trial_complete(finished_trial.trial_id)
self.assertEqual(len(searcher.next_trials()), 1)
finished_trial = trials.pop()
searcher.on_trial_complete(finished_trial.trial_id)
finished_trial = trials.pop()
searcher.on_trial_complete(finished_trial.trial_id)
finished_trial = trials.pop()
searcher.on_trial_complete(finished_trial.trial_id)
self.assertEqual(len(searcher.next_trials()), 1)
self.assertEqual(len(searcher.next_trials()), 0)
raise
if __name__ == "__main__":
+1 -2
View File
@@ -79,8 +79,7 @@ space = {
"momentum": hp.uniform("momentum", 0.1, 0.9),
}
hyperopt_search = HyperOptSearch(
space, max_concurrent=2, reward_attr="mean_accuracy")
hyperopt_search = HyperOptSearch(space, metric="mean_accuracy", mode="max")
analysis = tune.run(train_mnist, num_samples=10, search_alg=hyperopt_search)
# __run_searchalg_end__
+3 -6
View File
@@ -495,9 +495,7 @@ class TrialRunner:
if decision == TrialScheduler.STOP:
with warn_if_slow("search_alg.on_trial_complete"):
self._search_alg.on_trial_complete(
trial.trial_id,
result=flat_result,
early_terminated=True)
trial.trial_id, result=flat_result)
if not is_duplicate:
trial.update_last_result(
@@ -711,7 +709,7 @@ class TrialRunner:
Trials may be stopped at any time. If trial is in state PENDING
or PAUSED, calls `on_trial_remove` for scheduler and
`on_trial_complete(..., early_terminated=True) for search_alg.
`on_trial_complete() for search_alg.
Otherwise waits for result for the trial and calls
`on_trial_complete` for scheduler and search_alg if RUNNING.
"""
@@ -722,8 +720,7 @@ class TrialRunner:
return
elif trial.status in [Trial.PENDING, Trial.PAUSED]:
self._scheduler_alg.on_trial_remove(self, trial)
self._search_alg.on_trial_complete(
trial.trial_id, early_terminated=True)
self._search_alg.on_trial_complete(trial.trial_id)
elif trial.status is Trial.RUNNING:
try:
result = self.trial_executor.fetch_result(trial)
+5 -2
View File
@@ -4,6 +4,7 @@ from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list, Experiment
from ray.tune.analysis import ExperimentAnalysis
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.suggest.suggestion import Searcher, SearchGenerator
from ray.tune.trial import Trial
from ray.tune.trainable import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
@@ -176,8 +177,7 @@ def run(run_or_experiment,
fail_fast (bool): Whether to fail upon the first error.
restore (str): Path to checkpoint. Only makes sense to set if
running 1 trial. Defaults to None.
search_alg (SearchAlgorithm): Search Algorithm. Defaults to
BasicVariantGenerator.
search_alg (Searcher): Search algorithm for optimization.
scheduler (TrialScheduler): Scheduler for executing
the experiment. Choose among FIFO (default), MedianStopping,
AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to
@@ -275,6 +275,9 @@ def run(run_or_experiment,
if fail_fast and max_failures != 0:
raise ValueError("max_failures must be 0 if fail_fast=True.")
if issubclass(type(search_alg), Searcher):
search_alg = SearchGenerator(search_alg)
runner = TrialRunner(
search_alg=search_alg or BasicVariantGenerator(),
scheduler=scheduler or FIFOScheduler(),