mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 17:21:06 +08:00
[tune] Split Search from Scheduling (#2452)
Introduces SearchAlgorithm concept, separate from schedulers in Tune. Moves HyperOpt under this concept.
This commit is contained in:
@@ -255,6 +255,7 @@ class _MockAgent(Agent):
|
||||
_default_config = {
|
||||
"mock_error": False,
|
||||
"persistent_error": False,
|
||||
"test_variable": 1
|
||||
}
|
||||
|
||||
def _init(self):
|
||||
|
||||
@@ -14,7 +14,7 @@ and compiles them into a number of `Trial` objects. It schedules trials on the R
|
||||
|
||||
This is implemented as follows:
|
||||
|
||||
- `variant_generator.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/variant_generator.py>`__
|
||||
- `variant_generator.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/suggest/variant_generator.py>`__
|
||||
parses the config and generates the trial variants.
|
||||
|
||||
- `trial.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial.py>`__ manages the lifecycle
|
||||
|
||||
@@ -36,7 +36,6 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ray.tune.variant_generator import generate_trials\n",
|
||||
"import yaml\n",
|
||||
"\n",
|
||||
"with open(\"../rllib/tuned_examples/hyperband-cartpole.yaml\") as f:\n",
|
||||
|
||||
@@ -3,11 +3,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.tune import run_experiments, Experiment
|
||||
from ray.tune.tune import run_experiments
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.variant_generator import grid_search
|
||||
from ray.tune.suggest import grid_search
|
||||
|
||||
__all__ = [
|
||||
"Trainable", "TrainingResult", "TuneError", "grid_search", "register_env",
|
||||
|
||||
@@ -4,10 +4,12 @@ from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.trial import Resources
|
||||
from ray.tune.trial import Resources, Trial
|
||||
from ray.tune.logger import _SafeFallbackEncoder
|
||||
|
||||
|
||||
def json_to_resources(data):
|
||||
@@ -40,10 +42,6 @@ def resources_to_json(resources):
|
||||
}
|
||||
|
||||
|
||||
def _tune_error(msg):
    """Raises a TuneError carrying the given message."""
    error = TuneError(msg)
    raise error
|
||||
|
||||
|
||||
def make_parser(parser_creator=None, **kwargs):
|
||||
"""Returns a base argument parser for the ray.tune tool.
|
||||
|
||||
@@ -137,3 +135,62 @@ def make_parser(parser_creator=None, **kwargs):
|
||||
help="If specified, restore from this checkpoint.")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def to_argv(config):
    """Converts configuration to a command line argument format.

    Each key becomes a ``--flag`` (underscores replaced by dashes) followed
    by its value. Text values are passed through verbatim; every other
    value is serialized to JSON.

    Arguments:
        config (dict): Mapping of option names to option values.

    Returns:
        list: Flat ``[flag, value, flag, value, ...]`` argument list.

    Raises:
        ValueError: If an option name contains ``-``.
    """
    # A plain `str` check misses `unicode` values under Python 2 (e.g. specs
    # produced by json.load), which would then be JSON-quoted instead of
    # being passed through as-is.
    try:
        text_types = (basestring, )  # noqa: F821 (Python 2 only)
    except NameError:
        text_types = (str, )

    argv = []
    for k, v in config.items():
        if "-" in k:
            raise ValueError("Use '_' instead of '-' in `{}`".format(k))
        argv.append("--{}".format(k.replace("_", "-")))
        if isinstance(v, text_types):
            argv.append(v)
        else:
            argv.append(json.dumps(v, cls=_SafeFallbackEncoder))
    return argv
|
||||
|
||||
|
||||
def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
    """Creates a Trial object from parsing the spec.

    Arguments:
        spec (dict): A resolved experiment specification. Its keys should
            correspond to the command line flags in
            ray.tune.config_parser. The dict is not modified.
        output_path (str): A specific output path within the local_dir.
            Typically the name of the experiment.
        parser (ArgumentParser): An argument parser object from
            make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the Trial.

    Returns:
        A trial object with corresponding parameters to the specification.

    Raises:
        TuneError: If the parser rejects the spec (argparse reports this
            by raising SystemExit).
    """
    # Special case the `env` param for RLlib by automatically moving it into
    # the `config` section. This is done on copies so that the caller's spec
    # (and its nested config dict) are left untouched.
    if "env" in spec:
        spec = dict(spec)
        config = dict(spec.get("config", {}))
        config["env"] = spec.pop("env")
        spec["config"] = config
    try:
        args = parser.parse_args(to_argv(spec))
    except SystemExit:
        raise TuneError("Error parsing args, see above message", spec)
    if "trial_resources" in spec:
        trial_kwargs["resources"] = json_to_resources(spec["trial_resources"])
    return Trial(
        # Values are read from `spec` rather than `args` where the parsed
        # form would differ: json.load leads to str -> unicode in py2.7.
        trainable_name=spec["run"],
        config=spec.get("config", {}),
        local_dir=os.path.join(args.local_dir, output_path),
        stopping_criterion=spec.get("stop", {}),
        checkpoint_freq=args.checkpoint_freq,
        # str(None) doesn't create None
        restore_path=spec.get("restore"),
        upload_dir=args.upload_dir,
        max_failures=args.max_failures,
        **trial_kwargs)
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
"""This test checks that HyperOpt is functional.
|
||||
|
||||
It also checks that it is usable with a separate scheduler.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments, register_trainable
|
||||
from ray.tune.hpo_scheduler import HyperOptScheduler
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
from ray.tune.suggest import HyperOptSearch
|
||||
|
||||
|
||||
def easy_objective(config, reporter):
|
||||
import time
|
||||
time.sleep(0.2)
|
||||
assert type(config["activation"]) == str
|
||||
reporter(
|
||||
timesteps_total=1,
|
||||
mean_loss=((config["height"] - 14)**2 + abs(config["width"] - 3)))
|
||||
time.sleep(0.2)
|
||||
assert type(config["activation"]) == str, \
|
||||
"Config is incorrect: {}".format(type(config["activation"]))
|
||||
for i in range(100):
|
||||
reporter(
|
||||
timesteps_total=i,
|
||||
mean_loss=((config["height"] - 14)**2 + abs(config["width"] - 3)))
|
||||
time.sleep(0.02)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -38,15 +45,13 @@ if __name__ == '__main__':
|
||||
config = {
|
||||
"my_exp": {
|
||||
"run": "exp",
|
||||
"repeat": 5 if args.smoke_test else 1000,
|
||||
"repeat": 10 if args.smoke_test else 1000,
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
"training_iteration": 100
|
||||
},
|
||||
"config": {
|
||||
"space": space
|
||||
}
|
||||
}
|
||||
}
|
||||
hpo_sched = HyperOptScheduler(reward_attr="neg_mean_loss")
|
||||
|
||||
run_experiments(config, verbose=False, scheduler=hpo_sched)
|
||||
algo = HyperOptSearch(
|
||||
config, space, max_concurrent=4, reward_attr="neg_mean_loss")
|
||||
scheduler = AsyncHyperBandScheduler(reward_attr="neg_mean_loss")
|
||||
run_experiments(search_alg=algo, scheduler=scheduler)
|
||||
|
||||
@@ -79,3 +79,41 @@ class Experiment(object):
|
||||
exp.name = name
|
||||
exp.spec = spec
|
||||
return exp
|
||||
|
||||
|
||||
def convert_to_experiment_list(experiments):
    """Produces a list of Experiment objects.

    Converts input from dict, single experiment, or list of
    experiments to list of experiments. If input is None,
    will return an empty list.

    Arguments:
        experiments (Experiment | list | dict): Experiments to run. A dict
            (or dict subclass such as OrderedDict) maps experiment names
            to their JSON specs.

    Returns:
        List of experiments.

    Raises:
        TuneError: If the input is not None, an Experiment, a dict of specs
            or a list made up only of Experiment objects.
    """
    # Transform list if necessary. isinstance() is used so that dict
    # subclasses are converted instead of being rejected below.
    if experiments is None:
        exp_list = []
    elif isinstance(experiments, dict):
        exp_list = [
            Experiment.from_json(name, spec)
            for name, spec in experiments.items()
        ]
    elif isinstance(experiments, Experiment):
        exp_list = [experiments]
    else:
        exp_list = experiments

    # Validate exp_list
    is_valid = (isinstance(exp_list, list)
                and all(isinstance(exp, Experiment) for exp in exp_list))
    if not is_valid:
        raise TuneError("Invalid argument: {}".format(experiments))

    if len(exp_list) > 1:
        print("Warning: All experiments will be"
              " using the same Search Algorithm.")

    return exp_list
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import copy
|
||||
import numpy as np
|
||||
try:
|
||||
import hyperopt as hpo
|
||||
except Exception as e:
|
||||
hpo = None
|
||||
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.trial_scheduler import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.config_parser import make_parser
|
||||
from ray.tune.variant_generator import to_argv
|
||||
|
||||
|
||||
class HyperOptScheduler(FIFOScheduler):
    """FIFOScheduler that uses HyperOpt to provide trial suggestions.

    Requires HyperOpt to be installed via source.
    Uses the Tree-structured Parzen Estimators algorithm. Externally added
    trials will not be tracked by HyperOpt; notifications about such trials
    are ignored. Also, variant generation will be limited, as the
    hyperparameter configuration must be specified using HyperOpt primitives.

    Parameters:
        max_concurrent (int | None): Number of maximum concurrent trials.
            If None, then trials will be queued only if resources
            are available.
        reward_attr (str): The TrainingResult objective value attribute.
            This refers to an increasing value, which is internally negated
            when interacting with HyperOpt so that HyperOpt can "maximize"
            this value.

    Examples:
        >>> space = {'param': hp.uniform('param', 0, 20)}
        >>> config = {"my_exp": {
                          "run": "exp",
                          "repeat": 5,
                          "config": {"space": space}}}
        >>> run_experiments(config, scheduler=HyperOptScheduler())
    """

    def __init__(self, max_concurrent=None, reward_attr="episode_reward_mean"):
        assert hpo is not None, "HyperOpt must be installed!"
        assert type(max_concurrent) in [type(None), int]
        if type(max_concurrent) is int:
            assert max_concurrent > 0
        self._max_concurrent = max_concurrent  # NOTE: this is modified later
        self._reward_attr = reward_attr
        self._experiment = None
        # Defined up front so that callbacks arriving before add_experiment()
        # see "nothing tracked, nothing left" rather than a missing attribute.
        self._tune_to_hp = {}
        self._num_trials_left = 0

    def add_experiment(self, experiment, trial_runner):
        """Tracks one experiment.

        Will error if one tries to track multiple experiments.
        """
        assert self._experiment is None, "HyperOpt only tracks one experiment!"
        self._experiment = experiment

        self._output_path = experiment.name
        spec = copy.deepcopy(experiment.spec)

        # Set Scheduler field, as Tune Parser will default to FIFO
        assert spec.get("scheduler") in [None, "HyperOpt"], "Incorrectly " \
            "specified scheduler!"
        spec["scheduler"] = "HyperOpt"

        if "env" in spec:
            spec["config"] = spec.get("config", {})
            spec["config"]["env"] = spec["env"]
            del spec["env"]

        # The HyperOpt search space travels inside the config and must not
        # reach the trials themselves.
        space = spec["config"]["space"]
        del spec["config"]["space"]

        self.parser = make_parser()
        self.args = self.parser.parse_args(to_argv(spec))
        self.args.scheduler = "HyperOpt"
        self.default_config = copy.deepcopy(spec["config"])

        self.algo = hpo.tpe.suggest
        self.domain = hpo.Domain(lambda spc: spc, space)
        self._hpopt_trials = hpo.Trials()
        self._tune_to_hp = {}
        self._num_trials_left = self.args.repeat

        if type(self._max_concurrent) is int:
            self._max_concurrent = min(self._max_concurrent, self.args.repeat)

        self.rstate = np.random.RandomState()
        self.trial_generator = self._trial_generator()
        self._add_new_trials_if_needed(trial_runner)

    def _trial_generator(self):
        """Yields one Trial per remaining repeat, configured by HyperOpt."""
        while self._num_trials_left > 0:
            new_cfg = copy.deepcopy(self.default_config)
            new_ids = self._hpopt_trials.new_trial_ids(1)
            self._hpopt_trials.refresh()

            # Get new suggestion from HyperOpt
            new_trials = self.algo(new_ids, self.domain, self._hpopt_trials,
                                   self.rstate.randint(2**31 - 1))
            self._hpopt_trials.insert_trial_docs(new_trials)
            self._hpopt_trials.refresh()
            new_trial = new_trials[0]
            new_trial_id = new_trial["tid"]

            # Taken from HyperOpt.base.evaluate
            config = hpo.base.spec_from_misc(new_trial["misc"])
            ctrl = hpo.base.Ctrl(self._hpopt_trials, current_trial=new_trial)
            memo = self.domain.memo_from_config(config)
            hpo.utils.use_obj_for_literal_in_memo(self.domain.expr, ctrl,
                                                  hpo.base.Ctrl, memo)

            suggested_config = hpo.pyll.rec_eval(
                self.domain.expr,
                memo=memo,
                print_node_on_error=self.domain.rec_eval_print_node_on_error)

            # Suggested values override the experiment's default config.
            new_cfg.update(suggested_config)

            kv_str = "_".join([
                "{}={}".format(k,
                               str(v)[:5])
                for k, v in sorted(suggested_config.items())
            ])
            experiment_tag = "{}_{}".format(new_trial_id, kv_str)

            # Keep this consistent with tune.variant_generator
            trial = Trial(
                trainable_name=self.args.run,
                config=new_cfg,
                local_dir=os.path.join(self.args.local_dir, self._output_path),
                experiment_tag=experiment_tag,
                resources=self.args.trial_resources,
                stopping_criterion=self.args.stop,
                checkpoint_freq=self.args.checkpoint_freq,
                restore_path=self.args.restore,
                upload_dir=self.args.upload_dir,
                max_failures=self.args.max_failures)

            self._tune_to_hp[trial] = new_trial_id
            self._num_trials_left -= 1
            yield trial

    def on_trial_result(self, trial_runner, trial, result):
        """Refreshes HyperOpt's timestamps for a tracked trial."""
        ho_trial = self._tracked_hyperopt_trial(trial)
        if ho_trial is not None:
            now = hpo.utils.coarse_utcnow()
            ho_trial['book_time'] = now
            ho_trial['refresh_time'] = now
        return TrialScheduler.CONTINUE

    def on_trial_error(self, trial_runner, trial):
        """Reports a failed trial to HyperOpt and stops tracking it."""
        self._mark_errored(trial, "Tune Error")

    def on_trial_remove(self, trial_runner, trial):
        """Reports a removed trial to HyperOpt and stops tracking it."""
        self._mark_errored(trial, "Tune Removed")

    def on_trial_complete(self, trial_runner, trial, result):
        """Reports a finished trial's loss to HyperOpt and untracks it."""
        ho_trial = self._tracked_hyperopt_trial(trial)
        if ho_trial is None:
            return
        ho_trial['refresh_time'] = hpo.utils.coarse_utcnow()
        ho_trial['state'] = hpo.base.JOB_STATE_DONE
        hp_result = self._to_hyperopt_result(result)
        ho_trial['result'] = hp_result
        self._hpopt_trials.refresh()
        del self._tune_to_hp[trial]

    def _mark_errored(self, trial, reason):
        """Flags a tracked trial as errored in HyperOpt, then untracks it."""
        ho_trial = self._tracked_hyperopt_trial(trial)
        if ho_trial is None:
            return
        ho_trial['refresh_time'] = hpo.utils.coarse_utcnow()
        ho_trial['state'] = hpo.base.JOB_STATE_ERROR
        ho_trial['misc']['error'] = (str(TuneError), reason)
        self._hpopt_trials.refresh()
        del self._tune_to_hp[trial]

    def _tracked_hyperopt_trial(self, trial):
        """Returns the HyperOpt trial for `trial`, or None if untracked.

        Trials that were not generated by this scheduler (externally added,
        or already completed/removed) have no entry in the mapping.
        """
        if trial not in self._tune_to_hp:
            return None
        return self._get_hyperopt_trial(self._tune_to_hp[trial])

    def _to_hyperopt_result(self, result):
        # HyperOpt minimizes a loss, so the (increasing) reward is negated.
        return {"loss": -getattr(result, self._reward_attr), "status": "ok"}

    def _get_hyperopt_trial(self, tid):
        return [t for t in self._hpopt_trials.trials if t["tid"] == tid][0]

    def choose_trial_to_run(self, trial_runner):
        self._add_new_trials_if_needed(trial_runner)
        return FIFOScheduler.choose_trial_to_run(self, trial_runner)

    def _add_new_trials_if_needed(self, trial_runner):
        """Checks if there is a next trial ready to be queued.

        This is determined by tracking the number of concurrent
        experiments and trials left to run. If self._max_concurrent is None,
        scheduler will add new trial if there is none that are pending.
        """
        if self._num_trials_left <= 0:
            return
        if self._max_concurrent is None:
            has_pending = any(t.status == Trial.PENDING
                              for t in trial_runner.get_trials())
            if not has_pending:
                self._queue_next_trial(trial_runner)
        else:
            while self._num_live_trials() < self._max_concurrent:
                if not self._queue_next_trial(trial_runner):
                    break

    def _queue_next_trial(self, trial_runner):
        """Adds the next generated trial; returns False if none are left."""
        try:
            trial = next(self.trial_generator)
        except StopIteration:
            return False
        trial_runner.add_trial(trial)
        return True

    def _num_live_trials(self):
        return len(self._tune_to_hp)
|
||||
@@ -9,7 +9,7 @@ import copy
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.variant_generator import _format_vars
|
||||
from ray.tune.suggest.variant_generator import format_vars
|
||||
|
||||
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
|
||||
# the bottom PBT_QUANTILE fraction.
|
||||
@@ -80,7 +80,7 @@ def make_experiment_tag(orig_tag, config, mutations):
|
||||
resolved_vars = {}
|
||||
for k in mutations.keys():
|
||||
resolved_vars[("config", k)] = config[k]
|
||||
return "{}@perturbed[{}]".format(orig_tag, _format_vars(resolved_vars))
|
||||
return "{}@perturbed[{}]".format(orig_tag, format_vars(resolved_vars))
|
||||
|
||||
|
||||
class PopulationBasedTraining(FIFOScheduler):
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.suggest.variant_generator import grid_search
|
||||
|
||||
__all__ = [
|
||||
"SearchAlgorithm", "BasicVariantGenerator", "HyperOptSearch",
|
||||
"SuggestionAlgorithm", "grid_search"
|
||||
]
|
||||
@@ -0,0 +1,77 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from itertools import chain
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
from ray.tune.suggest.variant_generator import generate_variants
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
|
||||
|
||||
class BasicVariantGenerator(SearchAlgorithm):
    """Search algorithm backed by Tune's own variant generation.

    All variables are resolved through
    `ray.tune.suggest.variant_generator`, so every trial can be produced
    in a single batch.

    Example:
        >>> searcher = BasicVariantGenerator({"experiment": { ... }})
        >>> list_of_trials = searcher.next_trials()
        >>> searcher.is_finished() == True
    """

    def __init__(self, experiments=None):
        """Builds the (lazy) trial generator for the given experiments.

        Arguments:
            experiments (Experiment | list | dict): Experiments to run.
        """
        all_experiments = convert_to_experiment_list(experiments)
        self._parser = make_parser()
        # Generators are lazy: no trial is created until next_trials().
        per_experiment = [
            self._generate_trials(exp.spec, exp.name)
            for exp in all_experiments
        ]
        self._trial_generator = chain.from_iterable(per_experiment)
        self._counter = 0
        self._finished = False

    def next_trials(self):
        """Drains the generator and hands every trial to the TrialRunner.

        Returns:
            trials (list): All remaining trials; afterwards the searcher
                reports itself as finished.
        """
        generated = []
        generated.extend(self._trial_generator)
        self._finished = True
        return generated

    def _generate_trials(self, unresolved_spec, output_path=""):
        """Yields one Trial per resolved variant, for each repeat.

        Variants come from `generate_variants`
        (see `ray.tune.suggest.variant_generator`). Each trial is tagged
        with a running counter, followed by the resolved variables when
        there are any.

        Yields:
            Trial object
        """
        if "run" not in unresolved_spec:
            raise TuneError("Must specify `run` in {}".format(unresolved_spec))

        num_repeats = unresolved_spec.get("repeat", 1)
        for _ in range(num_repeats):
            for resolved_vars, spec in generate_variants(unresolved_spec):
                tag = str(self._counter)
                if resolved_vars:
                    tag = "{}_{}".format(tag, resolved_vars)
                self._counter += 1
                yield create_trial_from_spec(
                    spec, output_path, self._parser, experiment_tag=tag)

    def is_finished(self):
        return self._finished
|
||||
@@ -0,0 +1,125 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import copy
|
||||
try:
|
||||
import hyperopt as hpo
|
||||
except Exception as e:
|
||||
hpo = None
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
|
||||
|
||||
class HyperOptSearch(SuggestionAlgorithm):
    """Suggestion algorithm that delegates to HyperOpt.

    Requires HyperOpt to be installed from source.
    Suggestions come from the Tree-structured Parzen Estimators algorithm,
    but any other HyperOpt algorithm could be plugged in trivially.
    Trials that were added externally are not tracked by HyperOpt.

    Parameters:
        experiments (Experiment | list | dict): Experiments to run. Will be
            used by SuggestionAlgorithm parent class to initialize Trials.
        space (dict): HyperOpt configuration. Parameters will be sampled
            from this configuration and will be used to override
            parameters generated in the variant generation process.
        max_concurrent (int): Number of maximum concurrent trials. Defaults
            to 10.
        reward_attr (str): The TrainingResult objective value attribute.
            This refers to an increasing value, which is internally negated
            when interacting with HyperOpt so that HyperOpt can "maximize"
            this value.
    """

    def __init__(self,
                 experiments,
                 space,
                 max_concurrent=10,
                 reward_attr="episode_reward_mean",
                 **kwargs):
        assert hpo is not None, "HyperOpt must be installed!"
        assert type(max_concurrent) is int and max_concurrent > 0
        self._max_concurrent = max_concurrent
        self._reward_attr = reward_attr
        self.algo = hpo.tpe.suggest
        self.domain = hpo.Domain(lambda spc: spc, space)
        self._hpopt_trials = hpo.Trials()
        # Tune trial_id -> (HyperOpt tid, HyperOpt trial document)
        self._live_trial_mapping = {}
        self.rstate = np.random.RandomState()

        super(HyperOptSearch, self).__init__(experiments=experiments, **kwargs)

    def _suggest(self, trial_id):
        """Asks HyperOpt for a config, or returns None when at capacity."""
        if self._num_live_trials() >= self._max_concurrent:
            return None

        fresh_ids = self._hpopt_trials.new_trial_ids(1)
        self._hpopt_trials.refresh()

        # Get new suggestion from Hyperopt
        seed = self.rstate.randint(2**31 - 1)
        suggestions = self.algo(fresh_ids, self.domain, self._hpopt_trials,
                                seed)
        self._hpopt_trials.insert_trial_docs(suggestions)
        self._hpopt_trials.refresh()
        ho_doc = suggestions[0]
        self._live_trial_mapping[trial_id] = (ho_doc["tid"], ho_doc)

        # Taken from HyperOpt.base.evaluate
        misc_spec = hpo.base.spec_from_misc(ho_doc["misc"])
        ctrl = hpo.base.Ctrl(self._hpopt_trials, current_trial=ho_doc)
        memo = self.domain.memo_from_config(misc_spec)
        hpo.utils.use_obj_for_literal_in_memo(self.domain.expr, ctrl,
                                              hpo.base.Ctrl, memo)

        resolved = hpo.pyll.rec_eval(
            self.domain.expr,
            memo=memo,
            print_node_on_error=self.domain.rec_eval_print_node_on_error)
        return copy.deepcopy(resolved)

    def on_trial_result(self, trial_id, result):
        """Refreshes HyperOpt's timestamps for a tracked trial."""
        ho_trial = self._get_hyperopt_trial(trial_id)
        if ho_trial is None:
            return
        timestamp = hpo.utils.coarse_utcnow()
        ho_trial['book_time'] = timestamp
        ho_trial['refresh_time'] = timestamp

    def on_trial_complete(self,
                          trial_id,
                          result=None,
                          error=False,
                          early_terminated=False):
        """Records the trial's outcome in HyperOpt and stops tracking it."""
        ho_trial = self._get_hyperopt_trial(trial_id)
        if ho_trial is None:
            return
        ho_trial['refresh_time'] = hpo.utils.coarse_utcnow()
        if error or early_terminated:
            # `error` takes precedence when both flags are set.
            reason = "Tune Error" if error else "Tune Removed"
            ho_trial['state'] = hpo.base.JOB_STATE_ERROR
            ho_trial['misc']['error'] = (str(TuneError), reason)
        else:
            ho_trial['state'] = hpo.base.JOB_STATE_DONE
            ho_trial['result'] = self._to_hyperopt_result(result)
        self._hpopt_trials.refresh()
        del self._live_trial_mapping[trial_id]

    def _to_hyperopt_result(self, result):
        # HyperOpt minimizes a loss, so the (increasing) reward is negated.
        reward = getattr(result, self._reward_attr)
        return {"loss": -reward, "status": "ok"}

    def _get_hyperopt_trial(self, trial_id):
        """Returns the HyperOpt trial for `trial_id`, None if untracked."""
        entry = self._live_trial_mapping.get(trial_id)
        if entry is None:
            return None
        hyperopt_tid = entry[0]
        matches = [
            t for t in self._hpopt_trials.trials if t["tid"] == hyperopt_tid
        ]
        return matches[0]

    def _num_live_trials(self):
        return len(self._live_trial_mapping)
|
||||
@@ -0,0 +1,62 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class SearchAlgorithm(object):
    """Event-handler interface for hyperparameter search algorithms.

    In contrast to TrialSchedulers, a SearchAlgorithm cannot modify the
    execution (i.e., it cannot stop or pause trials).

    Events for trials added manually (i.e., via the Client API) are also
    delivered to this class, so custom search algorithms should keep
    track of the trial IDs that they generated themselves.

    See also: `ray.tune.suggest.BasicVariantGenerator`.
    """

    def next_trials(self):
        """Provides Trial objects to be queued into the TrialRunner.

        Returns:
            trials (list): Returns a list of trials.
        """
        raise NotImplementedError

    def on_trial_result(self, trial_id, result):
        """Hook invoked for each intermediate result a trial reports.

        Only called while the trial is in the RUNNING state. The default
        implementation does nothing.

        Arguments:
            trial_id: Identifier for the trial.
            result: The intermediate result reported by the trial.
        """

    def on_trial_complete(self,
                          trial_id,
                          result=None,
                          error=False,
                          early_terminated=False):
        """Hook invoked when a trial is done.

        The default implementation does nothing.

        Arguments:
            trial_id: Identifier for the trial.
            result (TrainingResult): Defaults to None. A TrainingResult will
                be provided with this notification when the trial is in
                the RUNNING state AND either completes naturally or
                by manual termination.
            error (bool): Defaults to False. True if the trial is in
                the RUNNING state and errors.
            early_terminated (bool): Defaults to False. True if the trial
                is stopped while in PAUSED or PENDING state.
        """

    def is_finished(self):
        """Returns True if no trials left to be queued into TrialRunner.

        Can return True before all trials have finished executing.
        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,137 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from itertools import chain
|
||||
import copy
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.variant_generator import format_vars
|
||||
|
||||
|
||||
class SuggestionAlgorithm(SearchAlgorithm):
    """Abstract class for suggestion-based algorithms.

    Custom search algorithms can extend this class easily by overriding the
    `_suggest` method to provide generated parameters for the trials.

    To track suggestions and their corresponding evaluations, the method
    `_suggest` will be passed a trial_id, which will be used in
    subsequent notifications.

    Example:
        >>> suggester = SuggestionAlgorithm({ ... })
        >>> new_parameters = suggester._suggest(trial_id)
        >>> suggester.on_trial_complete(trial_id, result)
        >>> better_parameters = suggester._suggest(next_trial_id)
    """

    def __init__(self, experiments=None):
        """Constructs a generator given experiment specifications.

        Arguments:
            experiments (Experiment | list | dict): Experiments to run.
        """
        experiment_list = convert_to_experiment_list(experiments)
        self._parser = make_parser()
        # Generators are lazy: `_suggest` is only queried from next_trials().
        self._trial_generator = chain.from_iterable([
            self._generate_trials(experiment.spec, experiment.name)
            for experiment in experiment_list
        ])
        self._finished = False

    def next_trials(self):
        """Provides a batch of Trial objects to be queued into the TrialRunner.

        A batch ends when self._trial_generator returns None.

        Returns:
            trials (list): Returns a list of trials.
        """
        trials = []

        for trial in self._trial_generator:
            if trial is None:
                # `_suggest` has nothing to offer right now; the generator
                # stays alive and is resumed by the next call.
                return trials
            trials.append(trial)

        self._finished = True
        return trials

    def _generate_trials(self, experiment_spec, output_path=""):
        """Generates trials with configurations from `_suggest`.

        Creates a trial_id that is passed into `_suggest`. The suggested
        parameters override same-named entries of the experiment's own
        `config`; other entries of that config are kept.

        Yields:
            Trial objects constructed according to `spec`, or None whenever
            `_suggest` cannot currently provide a configuration.
        """
        if "run" not in experiment_spec:
            raise TuneError("Must specify `run` in {}".format(experiment_spec))
        for _ in range(experiment_spec.get("repeat", 1)):
            trial_id = Trial.generate_id()
            while True:
                suggested_config = self._suggest(trial_id)
                if suggested_config is None:
                    yield None
                else:
                    break
            spec = copy.deepcopy(experiment_spec)
            # Layer the suggestion on top of the base config instead of
            # discarding the base config altogether.
            merged_config = dict(spec.get("config") or {})
            merged_config.update(suggested_config)
            spec["config"] = merged_config
            yield create_trial_from_spec(
                spec,
                output_path,
                self._parser,
                # The tag only lists the suggested parameters.
                experiment_tag=format_vars(suggested_config),
                trial_id=trial_id)

    def is_finished(self):
        return self._finished

    def _suggest(self, trial_id):
        """Queries the algorithm to retrieve the next set of parameters.

        Arguments:
            trial_id: Trial ID used for subsequent notifications.

        Returns:
            dict|None: Configuration for a trial, if possible.
                Else, returns None, which will temporarily stop the
                TrialRunner from querying.

        Example:
            >>> # `MySuggester` stands for a subclass limiting concurrency.
            >>> suggester = MySuggester({ ... }, max_concurrent=1)
            >>> parameters_1 = suggester._suggest(trial_id_1)
            >>> parameters_2 = suggester._suggest(trial_id_2)
            >>> parameters_2 is None
            >>> suggester.on_trial_complete(trial_id_1, result)
            >>> parameters_2 = suggester._suggest(trial_id_2)
            >>> parameters_2 is not None
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class _MockSuggestionAlgorithm(SuggestionAlgorithm):
    """Test double that caps live suggestions and counts notifications."""

    def __init__(self, experiments, max_concurrent=2, **kwargs):
        # Own state is set up before delegating to the base initializer.
        self.stall = False
        self.counter = {"result": 0, "complete": 0}
        self.live_trials = {}
        self._max_concurrent = max_concurrent
        super(_MockSuggestionAlgorithm, self).__init__(experiments, **kwargs)

    def _suggest(self, trial_id):
        """Returns a fixed config, or None when at capacity or stalled."""
        at_capacity = len(self.live_trials) >= self._max_concurrent
        if at_capacity or self.stall:
            return None
        self.live_trials[trial_id] = 1
        return {"test_variable": 2}

    def on_trial_result(self, trial_id, result):
        """Counts one result notification."""
        self.counter["result"] = self.counter["result"] + 1

    def on_trial_complete(self, trial_id, **kwargs):
        """Counts one completion and drops the trial from `live_trials`."""
        self.counter["complete"] = self.counter["complete"] + 1
        # Raises KeyError for an id that is not currently live.
        del self.live_trials[trial_id]
|
||||
+2
-69
@@ -3,78 +3,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import json
|
||||
import numpy
|
||||
import os
|
||||
import random
|
||||
import types
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import _SafeFallbackEncoder
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.config_parser import make_parser, json_to_resources
|
||||
|
||||
|
||||
def to_argv(config):
    """Flattens a config dict into a list of `--flag value` arguments.

    Underscores in keys become dashes in the flag name. String values are
    passed through unchanged; any other value is JSON-encoded with
    `_SafeFallbackEncoder`.

    Arguments:
        config (dict): Mapping from option name to value.

    Returns:
        list: Alternating flag and value entries, in dict iteration order.

    Raises:
        ValueError: If a key contains '-'.
    """
    args = []
    for key, value in config.items():
        if "-" in key:
            raise ValueError("Use '_' instead of '-' in `{}`".format(key))
        flag = "--{}".format(key.replace("_", "-"))
        if isinstance(value, str):
            rendered = value
        else:
            rendered = json.dumps(value, cls=_SafeFallbackEncoder)
        args.extend([flag, rendered])
    return args
|
||||
|
||||
|
||||
def generate_trials(unresolved_spec, output_path=''):
    """Yields a Trial object for each variant from `generate_variants()`.

    The whole variant sweep is repeated `repeat` times (default 1). A
    running counter across all yielded trials is used as the experiment
    tag, suffixed with the resolved variables when there are any.

    See also: generate_variants()

    Arguments:
        unresolved_spec (dict): Experiment spec conforming to the argument
            schema defined in `ray.tune.config_parser`.
        output_path (str): Path where to store experiment outputs.

    Raises:
        TuneError: If `run` is missing from the spec, or if the resolved
            spec cannot be parsed by the argument parser.
    """
    if "run" not in unresolved_spec:
        raise TuneError("Must specify `run` in {}".format(unresolved_spec))

    arg_parser = make_parser()
    trial_index = 0
    repeats = unresolved_spec.get("repeat", 1)
    for _ in range(repeats):
        for resolved_vars, spec in generate_variants(unresolved_spec):
            try:
                # RLlib convenience: a top-level `env` entry is relocated
                # into the `config` section before parsing.
                if "env" in spec:
                    spec["config"] = spec.get("config", {})
                    spec["config"]["env"] = spec["env"]
                    del spec["env"]
                args = arg_parser.parse_args(to_argv(spec))
            except SystemExit:
                raise TuneError("Error parsing args, see above message", spec)

            tag = str(trial_index)
            if resolved_vars:
                tag = "{}_{}".format(trial_index, resolved_vars)
            trial_index += 1

            resources = None
            if "trial_resources" in spec:
                resources = json_to_resources(spec["trial_resources"])

            yield Trial(
                trainable_name=spec["run"],
                config=spec.get("config", {}),
                local_dir=os.path.join(args.local_dir, output_path),
                experiment_tag=tag,
                resources=resources,
                stopping_criterion=spec.get("stop", {}),
                checkpoint_freq=args.checkpoint_freq,
                restore_path=spec.get("restore"),
                upload_dir=args.upload_dir,
                max_failures=args.max_failures)
|
||||
|
||||
|
||||
def generate_variants(unresolved_spec):
|
||||
@@ -109,7 +42,7 @@ def generate_variants(unresolved_spec):
|
||||
"""
|
||||
for resolved_vars, spec in _generate_variants(unresolved_spec):
|
||||
assert not _unresolved_values(spec)
|
||||
yield _format_vars(resolved_vars), spec
|
||||
yield format_vars(resolved_vars), spec
|
||||
|
||||
|
||||
def grid_search(values):
|
||||
@@ -126,7 +59,7 @@ _STANDARD_IMPORTS = {
|
||||
_MAX_RESOLUTION_PASSES = 20
|
||||
|
||||
|
||||
def _format_vars(resolved_vars):
|
||||
def format_vars(resolved_vars):
|
||||
out = []
|
||||
for path, value in sorted(resolved_vars.items()):
|
||||
if path[0] in ["run", "env", "trial_resources"]:
|
||||
@@ -0,0 +1,61 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
||||
from ray.tune.error import TuneError
|
||||
|
||||
|
||||
class ExperimentTest(unittest.TestCase):
    """Exercises `convert_to_experiment_list` with several input types."""

    def _make_experiment(self):
        # Same arguments every conversion test uses for its Experiment.
        return Experiment(
            name="foo", run="f1", config={"script_min_iter_time_s": 0})

    def _check_conversion(self, converted, expected_len):
        # Length is checked first, then the exact container type.
        self.assertEqual(len(converted), expected_len)
        self.assertEqual(type(converted), list)

    def testConvertExperimentFromExperiment(self):
        single = self._make_experiment()
        self._check_conversion(convert_to_experiment_list(single), 1)

    def testConvertExperimentNone(self):
        self._check_conversion(convert_to_experiment_list(None), 0)

    def testConvertExperimentList(self):
        exp = self._make_experiment()
        self._check_conversion(convert_to_experiment_list([exp, exp]), 2)

    def testConvertExperimentJSON(self):
        # Two named specs with identical contents.
        specs = {
            key: {
                "run": "f1",
                "config": {
                    "script_min_iter_time_s": 0
                }
            }
            for key in ("name", "named")
        }
        self._check_conversion(convert_to_experiment_list(specs), 2)

    def testConvertExperimentIncorrect(self):
        with self.assertRaises(TuneError):
            convert_to_experiment_list("hi")
|
||||
@@ -11,14 +11,16 @@ from ray.rllib import _register_all
|
||||
|
||||
from ray.tune import Trainable, TuneError
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.trial_scheduler import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, TrainingResult
|
||||
from ray.tune.util import pin_in_object_store, get_pinned_object
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.variant_generator import generate_trials, grid_search, \
|
||||
RecursiveDependencyError
|
||||
from ray.tune.suggest import grid_search, BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm
|
||||
from ray.tune.suggest.variant_generator import RecursiveDependencyError
|
||||
|
||||
|
||||
class TrainableFunctionApiTest(unittest.TestCase):
|
||||
@@ -435,6 +437,28 @@ class RunExperimentTest(unittest.TestCase):
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result.timesteps_total, 99)
|
||||
|
||||
def testSpecifyAlgorithm(self):
    """Tests run_experiments works without specifying experiment."""

    def train(config, reporter):
        # Reports timesteps 0..99, so the last reported value is 99.
        for step in range(100):
            reporter(timesteps_total=step)

    register_trainable("f1", train)

    spec = {"run": "f1", "config": {"script_min_iter_time_s": 0}}
    search = BasicVariantGenerator({"foo": spec})

    # Only the search algorithm is passed; no experiments argument.
    for finished in run_experiments(search_alg=search):
        self.assertEqual(finished.status, Trial.TERMINATED)
        self.assertEqual(finished.last_result.timesteps_total, 99)
|
||||
|
||||
|
||||
class VariantGeneratorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -444,8 +468,12 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def generate_trials(self, spec, name):
    """Returns `next_trials()` of a BasicVariantGenerator for one spec."""
    return BasicVariantGenerator({name: spec}).next_trials()
|
||||
|
||||
def testParseToTrials(self):
|
||||
trials = generate_trials({
|
||||
trials = self.generate_trials({
|
||||
"run": "PPO",
|
||||
"repeat": 2,
|
||||
"max_failures": 5,
|
||||
@@ -466,21 +494,21 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
self.assertEqual(trials[1].experiment_tag, "1")
|
||||
|
||||
def testEval(self):
|
||||
trials = generate_trials({
|
||||
trials = self.generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"foo": {
|
||||
"eval": "2 + 2"
|
||||
},
|
||||
},
|
||||
})
|
||||
}, "eval")
|
||||
trials = list(trials)
|
||||
self.assertEqual(len(trials), 1)
|
||||
self.assertEqual(trials[0].config, {"foo": 4})
|
||||
self.assertEqual(trials[0].experiment_tag, "0_foo=4")
|
||||
|
||||
def testGridSearch(self):
|
||||
trials = generate_trials({
|
||||
trials = self.generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"bar": {
|
||||
@@ -490,7 +518,7 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
"grid_search": [1, 2, 3]
|
||||
},
|
||||
},
|
||||
})
|
||||
}, "grid_search")
|
||||
trials = list(trials)
|
||||
self.assertEqual(len(trials), 6)
|
||||
self.assertEqual(trials[0].config, {"bar": True, "foo": 1})
|
||||
@@ -503,47 +531,47 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
self.assertEqual(trials[5].config, {"bar": False, "foo": 3})
|
||||
|
||||
def testGridSearchAndEval(self):
|
||||
trials = generate_trials({
|
||||
trials = self.generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"qux": lambda spec: 2 + 2,
|
||||
"bar": grid_search([True, False]),
|
||||
"foo": grid_search([1, 2, 3]),
|
||||
},
|
||||
})
|
||||
}, "grid_eval")
|
||||
trials = list(trials)
|
||||
self.assertEqual(len(trials), 6)
|
||||
self.assertEqual(trials[0].config, {"bar": True, "foo": 1, "qux": 4})
|
||||
self.assertEqual(trials[0].experiment_tag, "0_bar=True,foo=1,qux=4")
|
||||
|
||||
def testConditionResolution(self):
|
||||
trials = generate_trials({
|
||||
trials = self.generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"x": 1,
|
||||
"y": lambda spec: spec.config.x + 1,
|
||||
"z": lambda spec: spec.config.y + 1,
|
||||
},
|
||||
})
|
||||
}, "condition_resolution")
|
||||
trials = list(trials)
|
||||
self.assertEqual(len(trials), 1)
|
||||
self.assertEqual(trials[0].config, {"x": 1, "y": 2, "z": 3})
|
||||
|
||||
def testDependentLambda(self):
|
||||
trials = generate_trials({
|
||||
trials = self.generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"x": grid_search([1, 2]),
|
||||
"y": lambda spec: spec.config.x * 100,
|
||||
},
|
||||
})
|
||||
}, "dependent_lambda")
|
||||
trials = list(trials)
|
||||
self.assertEqual(len(trials), 2)
|
||||
self.assertEqual(trials[0].config, {"x": 1, "y": 100})
|
||||
self.assertEqual(trials[1].config, {"x": 2, "y": 200})
|
||||
|
||||
def testDependentGridSearch(self):
|
||||
trials = generate_trials({
|
||||
trials = self.generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"x": grid_search([
|
||||
@@ -552,7 +580,7 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
]),
|
||||
"y": lambda spec: 1,
|
||||
},
|
||||
})
|
||||
}, "dependent_grid_search")
|
||||
trials = list(trials)
|
||||
self.assertEqual(len(trials), 2)
|
||||
self.assertEqual(trials[0].config, {"x": 100, "y": 1})
|
||||
@@ -561,17 +589,45 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
def testRecursiveDep(self):
|
||||
try:
|
||||
list(
|
||||
generate_trials({
|
||||
self.generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"foo": lambda spec: spec.config.foo,
|
||||
},
|
||||
}))
|
||||
}, "recursive_dep"))
|
||||
except RecursiveDependencyError as e:
|
||||
assert "`foo` recursively depends on" in str(e), e
|
||||
else:
|
||||
assert False
|
||||
|
||||
def testMaxConcurrentSuggestions(self):
    """Checks that next_trials() supports throttling."""
    spec = {"run": "PPO", "repeat": 6}
    searcher = _MockSuggestionAlgorithm(
        [Experiment.from_json("test", spec)], max_concurrent=4)

    # The first batch is capped at max_concurrent; nothing more is offered.
    trials = searcher.next_trials()
    self.assertEqual(len(trials), 4)
    self.assertEqual(searcher.next_trials(), [])

    # Completing one trial frees one slot.
    searcher.on_trial_complete(trials.pop().trial_id)
    self.assertEqual(len(searcher.next_trials()), 1)

    # Completing the other three leaves only the sixth repeat to hand out.
    for _ in range(3):
        searcher.on_trial_complete(trials.pop().trial_id)
    self.assertEqual(len(searcher.next_trials()), 1)
    self.assertEqual(len(searcher.next_trials()), 0)
|
||||
|
||||
|
||||
class TrialRunnerTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
@@ -608,7 +664,8 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
}
|
||||
|
||||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
trial_generator = BasicVariantGenerator({name: spec})
|
||||
for trial in trial_generator.next_trials():
|
||||
trial.start()
|
||||
self.assertLessEqual(len(trial.logdir), 200)
|
||||
trial.stop()
|
||||
@@ -624,7 +681,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
def testExtraResources(self):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
@@ -645,7 +702,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
def testResourceScheduler(self):
|
||||
ray.init(num_cpus=4, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
@@ -674,7 +731,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
def testMultiStepRun(self):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 5
|
||||
@@ -703,7 +760,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
def testErrorHandling(self):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
@@ -725,7 +782,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
def testFailureRecoveryDisabled(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
"checkpoint_freq": 1,
|
||||
@@ -747,7 +804,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
def testFailureRecoveryEnabled(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
"checkpoint_freq": 1,
|
||||
@@ -771,7 +828,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
def testFailureRecoveryMaxFailures(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
"checkpoint_freq": 1,
|
||||
@@ -800,7 +857,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
def testCheckpointing(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
@@ -833,7 +890,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
def testResultDone(self):
|
||||
"""Tests that last_result is marked `done` after trial is complete."""
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 2
|
||||
@@ -852,7 +909,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
def testPauseThenResume(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 2
|
||||
@@ -883,7 +940,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
def testStopTrial(self):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 5
|
||||
@@ -924,6 +981,117 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertEqual(trials[2].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[-1].status, Trial.TERMINATED)
|
||||
|
||||
def testSearchAlgNotification(self):
    """Checks notification of trial to the Search Algorithm."""
    ray.init(num_cpus=4, num_gpus=2)
    spec = {"run": "__fake", "stop": {"training_iteration": 2}}
    searcher = _MockSuggestionAlgorithm(
        [Experiment.from_json("test", spec)], max_concurrent=10)
    runner = TrialRunner(search_alg=searcher)

    # Statuses expected after each of three consecutive steps.
    runner.step()
    trials = runner.get_trials()
    self.assertEqual(trials[0].status, Trial.RUNNING)

    runner.step()
    self.assertEqual(trials[0].status, Trial.RUNNING)

    runner.step()
    self.assertEqual(trials[0].status, Trial.TERMINATED)

    # The searcher saw one result and one completion.
    counts = searcher.counter
    self.assertEqual(counts["result"], 1)
    self.assertEqual(counts["complete"], 1)
|
||||
|
||||
def testSearchAlgFinished(self):
    """Checks that SearchAlg is Finished before all trials are done."""
    ray.init(num_cpus=4, num_gpus=2)
    spec = {"run": "__fake", "stop": {"training_iteration": 1}}
    searcher = _MockSuggestionAlgorithm(
        [Experiment.from_json("test", spec)], max_concurrent=10)
    runner = TrialRunner(search_alg=searcher)

    # After one step the trial runs; the searcher is done, the runner not.
    runner.step()
    trials = runner.get_trials()
    self.assertEqual(trials[0].status, Trial.RUNNING)
    self.assertTrue(searcher.is_finished())
    self.assertFalse(runner.is_finished())

    # After the second step the trial has terminated and both are finished.
    runner.step()
    self.assertEqual(trials[0].status, Trial.TERMINATED)
    self.assertEqual(len(searcher.live_trials), 0)
    self.assertTrue(searcher.is_finished())
    self.assertTrue(runner.is_finished())
|
||||
|
||||
def testSearchAlgSchedulerInteraction(self):
    """Checks that TrialScheduler killing trial will notify SearchAlg."""

    class _MockScheduler(FIFOScheduler):
        # Asks for a stop on every reported result.
        def on_trial_result(self, *args, **kwargs):
            return TrialScheduler.STOP

    ray.init(num_cpus=4, num_gpus=2)
    spec = {"run": "__fake", "stop": {"training_iteration": 2}}
    searcher = _MockSuggestionAlgorithm(
        [Experiment.from_json("test", spec)], max_concurrent=10)
    runner = TrialRunner(search_alg=searcher, scheduler=_MockScheduler())

    runner.step()
    trials = runner.get_trials()
    self.assertEqual(trials[0].status, Trial.RUNNING)
    self.assertTrue(searcher.is_finished())
    self.assertFalse(runner.is_finished())

    # The scheduler stops the trial on this step; the searcher's live set
    # must be empty afterwards.
    runner.step()
    self.assertEqual(trials[0].status, Trial.TERMINATED)
    self.assertEqual(len(searcher.live_trials), 0)
    self.assertTrue(searcher.is_finished())
    self.assertTrue(runner.is_finished())
|
||||
|
||||
def testSearchAlgStalled(self):
    """Checks that runner and searcher state is maintained when stalled."""
    ray.init(num_cpus=4, num_gpus=2)
    spec = {
        "run": "__fake",
        "repeat": 3,
        "stop": {
            "training_iteration": 1
        }
    }
    searcher = _MockSuggestionAlgorithm(
        [Experiment.from_json("test", spec)], max_concurrent=1)
    runner = TrialRunner(search_alg=searcher)

    # First trial: runs, then terminates.
    runner.step()
    trials = runner.get_trials()
    self.assertEqual(trials[0].status, Trial.RUNNING)

    runner.step()
    self.assertEqual(trials[0].status, Trial.TERMINATED)

    # Second trial is picked up on the next step.
    trials = runner.get_trials()
    runner.step()
    self.assertEqual(trials[1].status, Trial.RUNNING)
    self.assertEqual(len(searcher.live_trials), 1)

    # Stall the searcher while the second trial finishes.
    searcher.stall = True

    runner.step()
    self.assertEqual(trials[1].status, Trial.TERMINATED)
    self.assertEqual(len(searcher.live_trials), 0)

    # Every known trial is done, yet neither searcher nor runner is.
    self.assertTrue(all(t.is_finished() for t in trials))
    self.assertFalse(searcher.is_finished())
    self.assertFalse(runner.is_finished())

    # Un-stall: the third trial is produced and runs to completion.
    searcher.stall = False

    runner.step()
    trials = runner.get_trials()
    self.assertEqual(trials[2].status, Trial.RUNNING)
    self.assertEqual(len(searcher.live_trials), 1)

    runner.step()
    self.assertEqual(trials[2].status, Trial.TERMINATED)
    self.assertEqual(len(searcher.live_trials), 0)
    self.assertTrue(searcher.is_finished())
    self.assertTrue(runner.is_finished())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -9,6 +9,7 @@ import ray
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.web_server import TuneClient
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
|
||||
|
||||
@@ -30,7 +31,8 @@ class TuneServerSuite(unittest.TestCase):
|
||||
def basicSetup(self):
|
||||
ray.init(num_cpus=4, num_gpus=1)
|
||||
port = get_valid_port()
|
||||
self.runner = TrialRunner(launch_web_server=True, server_port=port)
|
||||
self.runner = TrialRunner(
|
||||
BasicVariantGenerator(), launch_web_server=True, server_port=port)
|
||||
runner = self.runner
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
|
||||
@@ -80,6 +80,7 @@ class Trial(object):
|
||||
def __init__(self,
|
||||
trainable_name,
|
||||
config=None,
|
||||
trial_id=None,
|
||||
local_dir=DEFAULT_RESULTS_DIR,
|
||||
experiment_tag="",
|
||||
resources=None,
|
||||
@@ -131,10 +132,17 @@ class Trial(object):
|
||||
self.logdir = None
|
||||
self.result_logger = None
|
||||
self.last_debug = 0
|
||||
self.trial_id = binary_to_hex(random_string())[:8]
|
||||
if trial_id is not None:
|
||||
self.trial_id = trial_id
|
||||
else:
|
||||
self.trial_id = Trial.generate_id()
|
||||
self.error_file = None
|
||||
self.num_failures = 0
|
||||
|
||||
@classmethod
def generate_id(cls):
    """Returns the first 8 characters of `binary_to_hex(random_string())`."""
    hex_id = binary_to_hex(random_string())
    return hex_id[:8]
|
||||
|
||||
def start(self, checkpoint_obj=None):
|
||||
"""Starts this trial.
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class TrialRunner(object):
|
||||
"""A TrialRunner implements the event loop for scheduling trials on Ray.
|
||||
|
||||
Example:
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
runner.add_trial(Trial(...))
|
||||
runner.add_trial(Trial(...))
|
||||
while not runner.is_finished():
|
||||
@@ -39,6 +39,7 @@ class TrialRunner(object):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
search_alg,
|
||||
scheduler=None,
|
||||
launch_web_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
@@ -47,6 +48,8 @@ class TrialRunner(object):
|
||||
"""Initializes a new TrialRunner.
|
||||
|
||||
Args:
|
||||
search_alg (SearchAlgorithm): SearchAlgorithm for generating
|
||||
Trial objects.
|
||||
scheduler (TrialScheduler): Defaults to FIFOScheduler.
|
||||
launch_web_server (bool): Flag for starting TuneServer
|
||||
server_port (int): Port number for launching TuneServer
|
||||
@@ -57,7 +60,7 @@ class TrialRunner(object):
|
||||
be set to True when running on an autoscaling cluster to enable
|
||||
automatic scale-up.
|
||||
"""
|
||||
|
||||
self._search_alg = search_alg
|
||||
self._scheduler_alg = scheduler or FIFOScheduler()
|
||||
self._trials = []
|
||||
self._running = {}
|
||||
@@ -85,10 +88,8 @@ class TrialRunner(object):
|
||||
self._total_time, self._global_time_limit))
|
||||
return True
|
||||
|
||||
for t in self._trials:
|
||||
if t.status in [Trial.PENDING, Trial.RUNNING, Trial.PAUSED]:
|
||||
return False
|
||||
return True
|
||||
trials_done = all(trial.is_finished() for trial in self._trials)
|
||||
return trials_done and self._search_alg.is_finished()
|
||||
|
||||
def step(self):
|
||||
"""Runs one step of the trial event loop.
|
||||
@@ -224,7 +225,15 @@ class TrialRunner(object):
|
||||
return False
|
||||
|
||||
def _get_next_trial(self):
|
||||
"""Replenishes queue.
|
||||
|
||||
Blocks if all trials queued have finished, but search algorithm is
|
||||
still not finished.
|
||||
"""
|
||||
self._update_avail_resources()
|
||||
trials_done = all(trial.is_finished() for trial in self._trials)
|
||||
wait_for_trial = trials_done and not self._search_alg.is_finished()
|
||||
self._update_trial_queue(blocking=wait_for_trial)
|
||||
trial = self._scheduler_alg.choose_trial_to_run(self)
|
||||
return trial
|
||||
|
||||
@@ -258,10 +267,16 @@ class TrialRunner(object):
|
||||
if trial.should_stop(result):
|
||||
# Hook into scheduler
|
||||
self._scheduler_alg.on_trial_complete(self, trial, result)
|
||||
self._search_alg.on_trial_complete(
|
||||
trial.trial_id, result=result)
|
||||
decision = TrialScheduler.STOP
|
||||
else:
|
||||
decision = self._scheduler_alg.on_trial_result(
|
||||
self, trial, result)
|
||||
self._search_alg.on_trial_result(trial.trial_id, result)
|
||||
if decision == TrialScheduler.STOP:
|
||||
self._search_alg.on_trial_complete(
|
||||
trial.trial_id, early_terminated=True)
|
||||
trial.update_last_result(
|
||||
result, terminate=(decision == TrialScheduler.STOP))
|
||||
|
||||
@@ -286,6 +301,8 @@ class TrialRunner(object):
|
||||
self._try_recover(trial, error_msg)
|
||||
else:
|
||||
self._scheduler_alg.on_trial_error(self, trial)
|
||||
self._search_alg.on_trial_complete(
|
||||
trial.trial_id, error=True)
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
|
||||
def _try_recover(self, trial, error_msg):
|
||||
@@ -300,6 +317,28 @@ class TrialRunner(object):
|
||||
print("Error recovering trial from checkpoint, abort:", error_msg)
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
|
||||
def _update_trial_queue(self, blocking=False, timeout=600):
|
||||
"""Adds next trials to queue if possible.
|
||||
|
||||
Note that the timeout is currently unexposed to the user.
|
||||
|
||||
Arguments:
|
||||
blocking (bool): Blocks until either a trial is available
|
||||
or the Runner finishes (i.e., timeout or search algorithm
|
||||
finishes).
|
||||
timeout (int): Seconds before blocking times out."""
|
||||
trials = self._search_alg.next_trials()
|
||||
if blocking and not trials:
|
||||
start = time.time()
|
||||
while (not trials and not self.is_finished()
|
||||
and time.time() - start < timeout):
|
||||
print("Blocking for next trial...")
|
||||
trials = self._search_alg.next_trials()
|
||||
time.sleep(1)
|
||||
|
||||
for trial in trials:
|
||||
self.add_trial(trial)
|
||||
|
||||
def _commit_resources(self, resources):
|
||||
self._committed_resources = Resources(
|
||||
self._committed_resources.cpu + resources.cpu_total(),
|
||||
@@ -324,9 +363,11 @@ class TrialRunner(object):
|
||||
"""Stops trial.
|
||||
|
||||
Trials may be stopped at any time. If trial is in state PENDING
|
||||
or PAUSED, calls `scheduler.on_trial_remove`. Otherwise waits for
|
||||
result for the trial and calls `scheduler.on_trial_complete`
|
||||
if RUNNING."""
|
||||
or PAUSED, calls `on_trial_remove` for scheduler and
|
||||
`on_trial_complete(..., early_terminated=True)` for search_alg.
|
||||
Otherwise waits for result for the trial and calls
|
||||
`on_trial_complete` for scheduler and search_alg if RUNNING.
|
||||
"""
|
||||
error = False
|
||||
error_msg = None
|
||||
|
||||
@@ -334,6 +375,8 @@ class TrialRunner(object):
|
||||
return
|
||||
elif trial.status in [Trial.PENDING, Trial.PAUSED]:
|
||||
self._scheduler_alg.on_trial_remove(self, trial)
|
||||
self._search_alg.on_trial_complete(
|
||||
trial.trial_id, early_terminated=True)
|
||||
elif trial.status is Trial.RUNNING:
|
||||
# NOTE: There should only be one...
|
||||
result_id = [
|
||||
@@ -344,10 +387,13 @@ class TrialRunner(object):
|
||||
result = ray.get(result_id)
|
||||
trial.update_last_result(result, terminate=True)
|
||||
self._scheduler_alg.on_trial_complete(self, trial, result)
|
||||
self._search_alg.on_trial_complete(
|
||||
trial.trial_id, result=result)
|
||||
except Exception:
|
||||
error_msg = traceback.format_exc()
|
||||
print("Error processing event:", error_msg)
|
||||
self._scheduler_alg.on_trial_error(self, trial)
|
||||
self._search_alg.on_trial_complete(trial.trial_id, error=True)
|
||||
error = True
|
||||
|
||||
self._stop_trial(trial, error=error, error_msg=error_msg)
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.variant_generator import generate_trials
|
||||
|
||||
|
||||
class TrialScheduler(object):
|
||||
@@ -48,21 +47,6 @@ class TrialScheduler(object):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def add_experiment(self, experiment, trial_runner):
    """Adds an experiment to the scheduler.

    The scheduler is responsible for adding the trials of the experiment
    to the runner, which can be done immediately (if there are a finite
    set of trials), or over time (if there is an infinite stream of trials
    or if the scheduler is iterative in nature).

    This implementation adds every trial produced by `generate_trials`
    for the experiment's spec and name to the runner right away.

    Arguments:
        experiment: Object providing the `spec` and `name` passed to
            `generate_trials`.
        trial_runner: Runner whose `add_trial` receives each trial.
    """
    # A plain for-loop drains the generator; unlike a manual
    # next()/except StopIteration loop, it does not hide a StopIteration
    # raised from inside `add_trial`.
    for trial in generate_trials(experiment.spec, experiment.name):
        trial_runner.add_trial(trial)
|
||||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
"""Called to choose a new trial to run.
|
||||
|
||||
|
||||
+13
-22
@@ -5,23 +5,21 @@ from __future__ import print_function
|
||||
import time
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.hpo_scheduler import HyperOptScheduler
|
||||
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
|
||||
from ray.tune.log_sync import wait_for_log_sync
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.trial_scheduler import FIFOScheduler
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.tune.experiment import Experiment
|
||||
|
||||
_SCHEDULERS = {
|
||||
"FIFO": FIFOScheduler,
|
||||
"MedianStopping": MedianStoppingRule,
|
||||
"HyperBand": HyperBandScheduler,
|
||||
"AsyncHyperBand": AsyncHyperBandScheduler,
|
||||
"HyperOpt": HyperOptScheduler,
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +31,8 @@ def _make_scheduler(args):
|
||||
args.scheduler, _SCHEDULERS.keys()))
|
||||
|
||||
|
||||
def run_experiments(experiments,
|
||||
def run_experiments(experiments=None,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
@@ -43,9 +42,11 @@ def run_experiments(experiments,
|
||||
|
||||
Args:
|
||||
experiments (Experiment | list | dict): Experiments to run.
|
||||
search_alg (SearchAlgorithm): Search Algorithm. Defaults to
|
||||
BasicVariantGenerator.
|
||||
scheduler (TrialScheduler): Scheduler for executing
|
||||
the experiment. Choose among FIFO (default), MedianStopping,
|
||||
AsyncHyperBand, HyperBand, or HyperOpt.
|
||||
AsyncHyperBand, and HyperBand.
|
||||
with_server (bool): Starts a background Tune server. Needed for
|
||||
using the Client API.
|
||||
server_port (int): Port number for launching TuneServer.
|
||||
@@ -58,31 +59,21 @@ def run_experiments(experiments,
|
||||
Returns:
|
||||
List of Trial objects, holding data for each executed trial.
|
||||
"""
|
||||
|
||||
if scheduler is None:
|
||||
scheduler = FIFOScheduler()
|
||||
|
||||
if search_alg is None:
|
||||
assert experiments is not None, "Experiments need to be specified" \
|
||||
"if search_alg is not provided."
|
||||
search_alg = BasicVariantGenerator(experiments)
|
||||
|
||||
runner = TrialRunner(
|
||||
scheduler,
|
||||
search_alg,
|
||||
scheduler=scheduler,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=verbose,
|
||||
queue_trials=queue_trials)
|
||||
exp_list = experiments
|
||||
if isinstance(experiments, Experiment):
|
||||
exp_list = [experiments]
|
||||
elif type(experiments) is dict:
|
||||
exp_list = [
|
||||
Experiment.from_json(name, spec)
|
||||
for name, spec in experiments.items()
|
||||
]
|
||||
|
||||
if (type(exp_list) is list
|
||||
and all(isinstance(exp, Experiment) for exp in exp_list)):
|
||||
for experiment in exp_list:
|
||||
scheduler.add_experiment(experiment, runner)
|
||||
else:
|
||||
raise TuneError("Invalid argument: {}".format(experiments))
|
||||
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import sys
|
||||
import threading
|
||||
|
||||
from ray.tune.error import TuneError, TuneManagerError
|
||||
from ray.tune.variant_generator import generate_trials
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
from SimpleHTTPServer import SimpleHTTPRequestHandler
|
||||
@@ -124,7 +124,8 @@ def RunnerHandler(runner):
|
||||
elif command == TuneClient.ADD:
|
||||
name = args["name"]
|
||||
spec = args["spec"]
|
||||
for trial in generate_trials(spec, name):
|
||||
trial_generator = BasicVariantGenerator({name: spec})
|
||||
for trial in trial_generator.next_trials():
|
||||
runner.add_trial(trial)
|
||||
else:
|
||||
raise TuneManagerError("Unknown command.")
|
||||
|
||||
Reference in New Issue
Block a user