mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 14:39:44 +08:00
[tune] Tune Facelift (#2472)
This PR introduces the following changes: * Ray Tune -> Tune * [breaking] Creation of `schedulers/`, moving PBT, HyperBand into a submodule * [breaking] Search Algorithms now must take in experiment configurations via `add_configurations` rather through initialization * Support `"run": (function | class | str)` with automatic registering of trainable * Documentation Changes
This commit is contained in:
+14
-20
@@ -1,28 +1,22 @@
|
||||
Ray.tune: Hyperparameter Optimization Framework
|
||||
===============================================
|
||||
Tune: Scalable Hyperparameter Search
|
||||
====================================
|
||||
|
||||
Ray.tune is a hyperparameter tuning framework for long-running tasks such as RL and deep learning training.
|
||||
Tune is a scalable framework for hyperparameter search with a focus on deep learning and deep reinforcement learning.
|
||||
|
||||
User documentation can be `found here <http://ray.readthedocs.io/en/latest/tune.html>`__.
|
||||
|
||||
Implementation overview
|
||||
-----------------------
|
||||
|
||||
At a high level, Ray.tune takes in JSON experiment configs (e.g. that defines the grid or random search)
|
||||
and compiles them into a number of `Trial` objects. It schedules trials on the Ray cluster using a given
|
||||
`TrialScheduler` implementation (e.g. median stopping rule or HyperBand).
|
||||
Citing Tune
|
||||
-----------
|
||||
|
||||
This is implemented as follows:
|
||||
If Tune helps you in your academic research, you are encouraged to cite `our paper <https://arxiv.org/abs/1807.05118>`__. Here is an example bibtex:
|
||||
|
||||
- `variant_generator.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/suggest/variant_generator.py>`__
|
||||
parses the config and generates the trial variants.
|
||||
.. code-block:: tex
|
||||
|
||||
- `trial.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial.py>`__ manages the lifecycle
|
||||
of the Ray actor responsible for executing the trial.
|
||||
|
||||
- `trial_runner.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/tune.py>`__ tracks scheduling
|
||||
state for all the trials of an experiment. TrialRunners are usually
|
||||
created automatically by ``run_experiments(experiment_json)``, which parses and starts the experiments.
|
||||
|
||||
- `trial_scheduler.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial_scheduler.py>`__
|
||||
plugs into TrialRunner to implement custom prioritization or early stopping algorithms.
|
||||
@article{liaw2018tune,
|
||||
title={Tune: A Research Platform for Distributed Model Selection and Training},
|
||||
author={Liaw, Richard and Liang, Eric and Nishihara, Robert and
|
||||
Moritz, Philipp and Gonzalez, Joseph E and Stoica, Ion},
|
||||
journal={arXiv preprint arXiv:1807.05118},
|
||||
year={2018}
|
||||
}
|
||||
|
||||
@@ -168,12 +168,6 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
A trial object with corresponding parameters to the specification.
|
||||
"""
|
||||
try:
|
||||
# Special case the `env` param for RLlib by automatically
|
||||
# moving it into the `config` section.
|
||||
if "env" in spec:
|
||||
spec["config"] = spec.get("config", {})
|
||||
spec["config"]["env"] = spec["env"]
|
||||
del spec["env"]
|
||||
args = parser.parse_args(to_argv(spec))
|
||||
except SystemExit:
|
||||
raise TuneError("Error parsing args, see above message", spec)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
Ray Tune Examples
|
||||
=================
|
||||
Tune Examples
|
||||
=============
|
||||
|
||||
Code examples for various schedulers and Ray Tune features.
|
||||
Code examples for various schedulers and Tune features.
|
||||
|
||||
@@ -12,9 +12,8 @@ import random
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, register_trainable, \
|
||||
run_experiments
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
from ray.tune import Trainable, run_experiments
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
|
||||
@@ -47,8 +46,6 @@ class MyTrainableClass(Trainable):
|
||||
self.timestep = json.loads(f.read())["timestep"]
|
||||
|
||||
|
||||
register_trainable("my_class", MyTrainableClass)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -69,7 +66,7 @@ if __name__ == "__main__":
|
||||
run_experiments(
|
||||
{
|
||||
"asynchyperband_test": {
|
||||
"run": "my_class",
|
||||
"run": MyTrainableClass,
|
||||
"stop": {
|
||||
"training_iteration": 1 if args.smoke_test else 99999
|
||||
},
|
||||
|
||||
@@ -12,9 +12,8 @@ import random
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, register_trainable, \
|
||||
run_experiments, Experiment
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune import Trainable, run_experiments, Experiment
|
||||
from ray.tune.schedulers import HyperBandScheduler
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
|
||||
@@ -47,8 +46,6 @@ class MyTrainableClass(Trainable):
|
||||
self.timestep = json.loads(f.read())["timestep"]
|
||||
|
||||
|
||||
register_trainable("my_class", MyTrainableClass)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -66,7 +63,7 @@ if __name__ == "__main__":
|
||||
|
||||
exp = Experiment(
|
||||
name="hyperband_test",
|
||||
run="my_class",
|
||||
run=MyTrainableClass,
|
||||
repeat=20,
|
||||
stop={"training_iteration": 1 if args.smoke_test else 99999},
|
||||
config={
|
||||
|
||||
@@ -8,7 +8,7 @@ from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments, register_trainable
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
from ray.tune.suggest import HyperOptSearch
|
||||
|
||||
|
||||
@@ -52,7 +52,6 @@ if __name__ == '__main__':
|
||||
},
|
||||
}
|
||||
}
|
||||
algo = HyperOptSearch(
|
||||
config, space, max_concurrent=4, reward_attr="neg_mean_loss")
|
||||
algo = HyperOptSearch(space, max_concurrent=4, reward_attr="neg_mean_loss")
|
||||
scheduler = AsyncHyperBandScheduler(reward_attr="neg_mean_loss")
|
||||
run_experiments(search_alg=algo, scheduler=scheduler)
|
||||
run_experiments(config, search_alg=algo, scheduler=scheduler)
|
||||
|
||||
@@ -11,8 +11,8 @@ import random
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, register_trainable, run_experiments
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
from ray.tune import Trainable, run_experiments
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
|
||||
@@ -54,8 +54,6 @@ class MyTrainableClass(Trainable):
|
||||
self.current_value = data["value"]
|
||||
|
||||
|
||||
register_trainable("my_class", MyTrainableClass)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -79,7 +77,7 @@ if __name__ == "__main__":
|
||||
run_experiments(
|
||||
{
|
||||
"pbt_test": {
|
||||
"run": "my_class",
|
||||
"run": MyTrainableClass,
|
||||
"stop": {
|
||||
"training_iteration": 2 if args.smoke_test else 99999
|
||||
},
|
||||
|
||||
@@ -14,7 +14,7 @@ import random
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
@@ -24,9 +24,8 @@ from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
|
||||
|
||||
import ray
|
||||
from ray.tune import grid_search, run_experiments
|
||||
from ray.tune import register_trainable
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
num_classes = 10
|
||||
|
||||
@@ -179,9 +178,8 @@ if __name__ == "__main__":
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
register_trainable("train_cifar10", Cifar10Model)
|
||||
train_spec = {
|
||||
"run": "train_cifar10",
|
||||
"run": Cifar10Model,
|
||||
"trial_resources": {
|
||||
"cpu": 1,
|
||||
"gpu": 1
|
||||
|
||||
@@ -33,7 +33,7 @@ import tempfile
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.tune import grid_search, run_experiments, register_trainable
|
||||
from ray.tune import grid_search, run_experiments
|
||||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
|
||||
@@ -218,9 +218,8 @@ if __name__ == '__main__':
|
||||
'--smoke-test', action='store_true', help='Finish quickly for testing')
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
register_trainable('train_mnist', train)
|
||||
mnist_spec = {
|
||||
'run': 'train_mnist',
|
||||
'run': train,
|
||||
'repeat': 10,
|
||||
'stop': {
|
||||
'mean_accuracy': 0.99,
|
||||
@@ -237,7 +236,7 @@ if __name__ == '__main__':
|
||||
|
||||
ray.init()
|
||||
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
run_experiments(
|
||||
{
|
||||
'tune_mnist_test': mnist_spec
|
||||
|
||||
@@ -13,7 +13,7 @@ from keras import backend as K
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
|
||||
|
||||
class TuneCallback(keras.callbacks.Callback):
|
||||
|
||||
@@ -32,7 +32,7 @@ import time
|
||||
import ray
|
||||
from ray.tune import grid_search, run_experiments, register_trainable, \
|
||||
Trainable
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.schedulers import HyperBandScheduler
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -2,8 +2,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import six
|
||||
import types
|
||||
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.registry import register_trainable
|
||||
|
||||
|
||||
class Experiment(object):
|
||||
@@ -11,19 +16,21 @@ class Experiment(object):
|
||||
|
||||
Parameters:
|
||||
name (str): Name of experiment.
|
||||
run (str): The algorithm or model to train. This may refer to the
|
||||
name of a built-on algorithm (e.g. RLLib's DQN or PPO), or a
|
||||
user-defined trainable function or class
|
||||
registered in the tune registry.
|
||||
run (function|class|str): The algorithm or model to train.
|
||||
This may refer to the name of a built-on algorithm
|
||||
(e.g. RLLib's DQN or PPO), a user-defined trainable
|
||||
function or class, or the string identifier of a
|
||||
trainable function or class registered in the tune registry.
|
||||
stop (dict): The stopping criteria. The keys may be any field in
|
||||
the return result of 'train()', whichever is reached first.
|
||||
Defaults to empty dict.
|
||||
config (dict): Algorithm-specific configuration
|
||||
(e.g. env, hyperparams). Defaults to empty dict.
|
||||
config (dict): Algorithm-specific configuration for Tune variant
|
||||
generation (e.g. env, hyperparams). Defaults to empty dict.
|
||||
Custom search algorithms may ignore this.
|
||||
trial_resources (dict): Machine resources to allocate per trial,
|
||||
e.g. ``{"cpu": 64, "gpu": 8}``. Note that GPUs will not be
|
||||
assigned unless you specify them here. Defaults to 1 CPU and 0
|
||||
GPUs.
|
||||
GPUs in ``Trainable.default_resource_request()``.
|
||||
repeat (int): Number of times to repeat each trial. Defaults to 1.
|
||||
local_dir (str): Local dir to save training results to.
|
||||
Defaults to ``~/ray_results``.
|
||||
@@ -34,6 +41,29 @@ class Experiment(object):
|
||||
max_failures (int): Try to recover a trial from its last
|
||||
checkpoint at least this many times. Only applies if
|
||||
checkpointing is enabled. Defaults to 3.
|
||||
restore (str): Path to checkpoint. Only makes sense to set if
|
||||
running 1 trial. Defaults to None.
|
||||
|
||||
|
||||
Examples:
|
||||
>>> experiment_spec = Experiment(
|
||||
>>> "my_experiment_name",
|
||||
>>> my_func,
|
||||
>>> stop={"mean_accuracy": 100},
|
||||
>>> config={
|
||||
>>> "alpha": tune.grid_search([0.2, 0.4, 0.6]),
|
||||
>>> "beta": tune.grid_search([1, 2]),
|
||||
>>> },
|
||||
>>> trial_resources={
|
||||
>>> "cpu": 1,
|
||||
>>> "gpu": 0
|
||||
>>> },
|
||||
>>> repeat=10,
|
||||
>>> local_dir="~/ray_results",
|
||||
>>> upload_dir="s3://your_bucket/path",
|
||||
>>> checkpoint_freq=10,
|
||||
>>> max_failures=2)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -46,20 +76,19 @@ class Experiment(object):
|
||||
local_dir=None,
|
||||
upload_dir="",
|
||||
checkpoint_freq=0,
|
||||
max_failures=3):
|
||||
max_failures=3,
|
||||
restore=None):
|
||||
spec = {
|
||||
"run": run,
|
||||
"run": self._register_if_needed(run),
|
||||
"stop": stop or {},
|
||||
"config": config or {},
|
||||
"trial_resources": trial_resources or {
|
||||
"cpu": 1,
|
||||
"gpu": 0
|
||||
},
|
||||
"trial_resources": trial_resources,
|
||||
"repeat": repeat,
|
||||
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
|
||||
"upload_dir": upload_dir,
|
||||
"checkpoint_freq": checkpoint_freq,
|
||||
"max_failures": max_failures
|
||||
"max_failures": max_failures,
|
||||
"restore": restore
|
||||
}
|
||||
|
||||
self.name = name
|
||||
@@ -75,11 +104,56 @@ class Experiment(object):
|
||||
"""
|
||||
if "run" not in spec:
|
||||
raise TuneError("No trainable specified!")
|
||||
exp = cls(name, spec["run"])
|
||||
exp.name = name
|
||||
exp.spec = spec
|
||||
|
||||
# Special case the `env` param for RLlib by automatically
|
||||
# moving it into the `config` section.
|
||||
if "env" in spec:
|
||||
spec["config"] = spec.get("config", {})
|
||||
spec["config"]["env"] = spec["env"]
|
||||
del spec["env"]
|
||||
|
||||
spec = copy.deepcopy(spec)
|
||||
|
||||
run_value = spec.pop("run")
|
||||
try:
|
||||
exp = cls(name, run_value, **spec)
|
||||
except TypeError:
|
||||
raise TuneError("Improper argument from JSON: {}.".format(spec))
|
||||
return exp
|
||||
|
||||
def _register_if_needed(self, run_object):
|
||||
"""Registers Trainable or Function at runtime.
|
||||
|
||||
Assumes already registered if run_object is a string. Does not
|
||||
register lambdas because they could be part of variant generation.
|
||||
Also, does not inspect interface of given run_object.
|
||||
|
||||
Arguments:
|
||||
run_object (str|function|class): Trainable to run. If string,
|
||||
assumes it is an ID and does not modify it. Otherwise,
|
||||
returns a string corresponding to the run_object name.
|
||||
|
||||
Returns:
|
||||
A string representing the trainable identifier.
|
||||
"""
|
||||
|
||||
if isinstance(run_object, six.string_types):
|
||||
return run_object
|
||||
elif isinstance(run_object, types.FunctionType):
|
||||
if run_object.__name__ == "<lambda>":
|
||||
print("Not auto-registering lambdas - resolving as variant.")
|
||||
return run_object
|
||||
else:
|
||||
name = run_object.__name__
|
||||
register_trainable(name, run_object)
|
||||
return name
|
||||
elif isinstance(run_object, type):
|
||||
name = run_object.__name__
|
||||
register_trainable(name, run_object)
|
||||
return name
|
||||
else:
|
||||
raise TuneError("Improper 'run' - not string nor trainable.")
|
||||
|
||||
|
||||
def convert_to_experiment_list(experiments):
|
||||
"""Produces a list of Experiment objects.
|
||||
|
||||
@@ -12,7 +12,12 @@ from ray.tune.result import TIMESTEPS_TOTAL
|
||||
|
||||
|
||||
class StatusReporter(object):
|
||||
"""Object passed into your main() that you can report status through."""
|
||||
"""Object passed into your main() that you can report status through.
|
||||
|
||||
Example:
|
||||
>>> reporter = StatusReporter()
|
||||
>>> reporter(timesteps_total=1)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._latest_result = None
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.tune.schedulers.trial_scheduler import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.schedulers.hyperband import HyperBandScheduler
|
||||
from ray.tune.schedulers.async_hyperband import AsyncHyperBandScheduler
|
||||
from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.schedulers.pbt import PopulationBasedTraining
|
||||
|
||||
__all__ = [
|
||||
"TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler",
|
||||
"MedianStoppingRule", "FIFOScheduler", "PopulationBasedTraining"
|
||||
]
|
||||
@@ -4,7 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
|
||||
class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
@@ -5,7 +5,7 @@ from __future__ import print_function
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
algorithm. It divides trials into brackets of varying sizes, and
|
||||
periodically early stops low-performing trials within each bracket.
|
||||
|
||||
To use this implementation of HyperBand with Ray Tune, all you need
|
||||
To use this implementation of HyperBand with Tune, all you need
|
||||
to do is specify the max length of time a trial can run `max_t`, the time
|
||||
units `time_attr`, and the name of the reported objective value
|
||||
`reward_attr`. We automatically determine reasonable values for the other
|
||||
+1
-1
@@ -6,7 +6,7 @@ import collections
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
|
||||
class MedianStoppingRule(FIFOScheduler):
|
||||
@@ -8,7 +8,7 @@ import copy
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest.variant_generator import format_vars
|
||||
|
||||
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
|
||||
@@ -97,7 +97,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
during training time. This enables very fast hyperparameter discovery and
|
||||
also automatically discovers good annealing schedules.
|
||||
|
||||
This Ray Tune PBT implementation considers all trials added as part of the
|
||||
This Tune PBT implementation considers all trials added as part of the
|
||||
PBT population. If the number of trials exceeds the cluster capacity,
|
||||
they will be time-multiplexed as to balance training progress across the
|
||||
population.
|
||||
@@ -6,9 +6,11 @@ from ray.tune.trial import Trial
|
||||
|
||||
|
||||
class TrialScheduler(object):
|
||||
CONTINUE = "CONTINUE"
|
||||
PAUSE = "PAUSE"
|
||||
STOP = "STOP"
|
||||
"""Interface for implementing a Trial Scheduler class."""
|
||||
|
||||
CONTINUE = "CONTINUE" #: Status for continuing trial execution
|
||||
PAUSE = "PAUSE" #: Status for pausing trial execution
|
||||
STOP = "STOP" #: Status for stopping trial execution
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
"""Called when a new trial is added to the trial runner."""
|
||||
@@ -2,7 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from itertools import chain
|
||||
import itertools
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
@@ -17,25 +17,29 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
See also: `ray.tune.suggest.variant_generator`.
|
||||
|
||||
Example:
|
||||
>>> searcher = BasicVariantGenerator({"experiment": { ... }})
|
||||
>>> searcher = BasicVariantGenerator()
|
||||
>>> searcher.add_configurations({"experiment": { ... }})
|
||||
>>> list_of_trials = searcher.next_trials()
|
||||
>>> searcher.is_finished == True
|
||||
"""
|
||||
|
||||
def __init__(self, experiments=None):
|
||||
"""Constructs a generator given experiment specifications.
|
||||
def __init__(self):
|
||||
self._parser = make_parser()
|
||||
self._trial_generator = []
|
||||
self._counter = 0
|
||||
self._finished = False
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
"""Chains generator given experiment specifications.
|
||||
|
||||
Arguments:
|
||||
experiments (Experiment | list | dict): Experiments to run.
|
||||
"""
|
||||
experiment_list = convert_to_experiment_list(experiments)
|
||||
self._parser = make_parser()
|
||||
self._trial_generator = chain.from_iterable([
|
||||
self._generate_trials(experiment.spec, experiment.name)
|
||||
for experiment in experiment_list
|
||||
])
|
||||
self._counter = 0
|
||||
self._finished = False
|
||||
for experiment in experiment_list:
|
||||
self._trial_generator = itertools.chain(
|
||||
self._trial_generator,
|
||||
self._generate_trials(experiment.spec, experiment.name))
|
||||
|
||||
def next_trials(self):
|
||||
"""Provides Trial objects to be queued into the TrialRunner.
|
||||
|
||||
@@ -22,21 +22,35 @@ class HyperOptSearch(SuggestionAlgorithm):
|
||||
added trials will not be tracked by HyperOpt.
|
||||
|
||||
Parameters:
|
||||
experiments (Experiment | list | dict): Experiments to run. Will be
|
||||
used by SuggestionAlgorithm parent class to initialize Trials.
|
||||
space (dict): HyperOpt configuration. Parameters will be sampled
|
||||
from this configuration and will be used to override
|
||||
parameters generated in the variant generation process.
|
||||
max_concurrent (int): Number of maximum concurrent trials. Defaults
|
||||
to 10.
|
||||
reward_attr (str): The training result objective value attribute.
|
||||
This refers to an increasing value, which is internally negated
|
||||
when interacting with HyperOpt so that HyperOpt can "maximize"
|
||||
this value.
|
||||
This refers to an increasing value.
|
||||
|
||||
Example:
|
||||
>>> space = {
|
||||
>>> 'width': hp.uniform('width', 0, 20),
|
||||
>>> 'height': hp.uniform('height', -100, 100),
|
||||
>>> 'activation': hp.choice("activation", ["relu", "tanh"])
|
||||
>>> }
|
||||
>>> config = {
|
||||
>>> "my_exp": {
|
||||
>>> "run": "exp",
|
||||
>>> "repeat": 10 if args.smoke_test else 1000,
|
||||
>>> "stop": {
|
||||
>>> "training_iteration": 100
|
||||
>>> },
|
||||
>>> }
|
||||
>>> }
|
||||
>>> algo = HyperOptSearch(
|
||||
>>> space, max_concurrent=4, reward_attr="neg_mean_loss")
|
||||
>>> algo.add_configurations(config)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
experiments,
|
||||
space,
|
||||
max_concurrent=10,
|
||||
reward_attr="episode_reward_mean",
|
||||
@@ -51,7 +65,7 @@ class HyperOptSearch(SuggestionAlgorithm):
|
||||
self._live_trial_mapping = {}
|
||||
self.rstate = np.random.RandomState()
|
||||
|
||||
super(HyperOptSearch, self).__init__(experiments=experiments, **kwargs)
|
||||
super(HyperOptSearch, self).__init__(**kwargs)
|
||||
|
||||
def _suggest(self, trial_id):
|
||||
if self._num_live_trials() >= self._max_concurrent:
|
||||
@@ -93,6 +107,11 @@ class HyperOptSearch(SuggestionAlgorithm):
|
||||
result=None,
|
||||
error=False,
|
||||
early_terminated=False):
|
||||
"""Passes the result to HyperOpt unless early terminated or errored.
|
||||
|
||||
The result is internally negated when interacting with HyperOpt
|
||||
so that HyperOpt can "maximize" this value, as it minimizes on default.
|
||||
"""
|
||||
ho_trial = self._get_hyperopt_trial(trial_id)
|
||||
if ho_trial is None:
|
||||
return
|
||||
|
||||
@@ -16,6 +16,14 @@ class SearchAlgorithm(object):
|
||||
See also: `ray.tune.suggest.BasicVariantGenerator`.
|
||||
"""
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
"""Tracks given experiment specifications.
|
||||
|
||||
Arguments:
|
||||
experiments (Experiment | list | dict): Experiments to run.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def next_trials(self):
|
||||
"""Provides Trial objects to be queued into the TrialRunner.
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from itertools import chain
|
||||
import itertools
|
||||
import copy
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
@@ -24,26 +24,35 @@ class SuggestionAlgorithm(SearchAlgorithm):
|
||||
subsequent notifications.
|
||||
|
||||
Example:
|
||||
>>> suggester = SuggestionAlgorithm({ ... })
|
||||
>>> suggester = SuggestionAlgorithm()
|
||||
>>> suggester.add_configurations({ ... })
|
||||
>>> new_parameters = suggester._suggest()
|
||||
>>> suggester.on_trial_complete(trial_id, result)
|
||||
>>> better_parameters = suggester._suggest()
|
||||
"""
|
||||
|
||||
def __init__(self, experiments=None):
|
||||
def __init__(self):
|
||||
"""Constructs a generator given experiment specifications.
|
||||
|
||||
Arguments:
|
||||
experiments (Experiment | list | dict): Experiments to run.
|
||||
"""
|
||||
experiment_list = convert_to_experiment_list(experiments)
|
||||
self._parser = make_parser()
|
||||
self._trial_generator = chain.from_iterable([
|
||||
self._generate_trials(experiment.spec, experiment.name)
|
||||
for experiment in experiment_list
|
||||
])
|
||||
self._trial_generator = []
|
||||
self._finished = False
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
"""Chains generator given experiment specifications.
|
||||
|
||||
Arguments:
|
||||
experiments (Experiment | list | dict): Experiments to run.
|
||||
"""
|
||||
experiment_list = convert_to_experiment_list(experiments)
|
||||
for experiment in experiment_list:
|
||||
self._trial_generator = itertools.chain(
|
||||
self._trial_generator,
|
||||
self._generate_trials(experiment.spec, experiment.name))
|
||||
|
||||
def next_trials(self):
|
||||
"""Provides a batch of Trial objects to be queued into the TrialRunner.
|
||||
|
||||
@@ -104,7 +113,8 @@ class SuggestionAlgorithm(SearchAlgorithm):
|
||||
TrialRunner from querying.
|
||||
|
||||
Example:
|
||||
>>> suggester = SuggestionAlgorithm({ ... }, max_concurrent=1)
|
||||
>>> suggester = SuggestionAlgorithm(max_concurrent=1)
|
||||
>>> suggester.add_configurations({ ... })
|
||||
>>> parameters_1 = suggester._suggest()
|
||||
>>> parameters_2 = suggester._suggest()
|
||||
>>> parameters_2 is None
|
||||
@@ -116,12 +126,12 @@ class SuggestionAlgorithm(SearchAlgorithm):
|
||||
|
||||
|
||||
class _MockSuggestionAlgorithm(SuggestionAlgorithm):
|
||||
def __init__(self, experiments, max_concurrent=2, **kwargs):
|
||||
def __init__(self, max_concurrent=2, **kwargs):
|
||||
self._max_concurrent = max_concurrent
|
||||
self.live_trials = {}
|
||||
self.counter = {"result": 0, "complete": 0}
|
||||
self.stall = False
|
||||
super(_MockSuggestionAlgorithm, self).__init__(experiments, **kwargs)
|
||||
super(_MockSuggestionAlgorithm, self).__init__(**kwargs)
|
||||
|
||||
def _suggest(self, trial_id):
|
||||
if len(self.live_trials) < self._max_concurrent and not self.stall:
|
||||
|
||||
@@ -46,7 +46,11 @@ def generate_variants(unresolved_spec):
|
||||
|
||||
|
||||
def grid_search(values):
|
||||
"""Convenience method for specifying grid search over a value."""
|
||||
"""Convenience method for specifying grid search over a value.
|
||||
|
||||
Arguments:
|
||||
values: An iterable whose parameters will be gridded.
|
||||
"""
|
||||
|
||||
return {"grid_search": values}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray.rllib import _register_all
|
||||
|
||||
from ray.tune import Trainable, TuneError
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.trial_scheduler import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE
|
||||
from ray.tune.util import pin_in_object_store, get_pinned_object
|
||||
@@ -449,7 +449,8 @@ class RunExperimentTest(unittest.TestCase):
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
alg = BasicVariantGenerator({
|
||||
alg = BasicVariantGenerator()
|
||||
alg.add_configurations({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
@@ -462,6 +463,30 @@ class RunExperimentTest(unittest.TestCase):
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||
|
||||
def testAutoregisterTrainable(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
class B(Trainable):
|
||||
def _train(self):
|
||||
return dict(timesteps_this_iter=1, done=True)
|
||||
|
||||
register_trainable("f1", train)
|
||||
trials = run_experiments({
|
||||
"foo": {
|
||||
"run": train,
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
},
|
||||
"bar": {
|
||||
"run": B
|
||||
}
|
||||
})
|
||||
for trial in trials:
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
|
||||
|
||||
class VariantGeneratorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -472,7 +497,8 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def generate_trials(self, spec, name):
|
||||
suggester = BasicVariantGenerator({name: spec})
|
||||
suggester = BasicVariantGenerator()
|
||||
suggester.add_configurations({name: spec})
|
||||
return suggester.next_trials()
|
||||
|
||||
def testParseToTrials(self):
|
||||
@@ -611,7 +637,8 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
}
|
||||
experiments = [Experiment.from_json("test", experiment_spec)]
|
||||
|
||||
searcher = _MockSuggestionAlgorithm(experiments, max_concurrent=4)
|
||||
searcher = _MockSuggestionAlgorithm(max_concurrent=4)
|
||||
searcher.add_configurations(experiments)
|
||||
trials = searcher.next_trials()
|
||||
self.assertEqual(len(trials), 4)
|
||||
self.assertEqual(searcher.next_trials(), [])
|
||||
@@ -667,7 +694,8 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
}
|
||||
|
||||
for name, spec in experiments.items():
|
||||
trial_generator = BasicVariantGenerator({name: spec})
|
||||
trial_generator = BasicVariantGenerator()
|
||||
trial_generator.add_configurations({name: spec})
|
||||
for trial in trial_generator.next_trials():
|
||||
trial.start()
|
||||
self.assertLessEqual(len(trial.logdir), 200)
|
||||
@@ -989,7 +1017,8 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
experiment_spec = {"run": "__fake", "stop": {"training_iteration": 2}}
|
||||
experiments = [Experiment.from_json("test", experiment_spec)]
|
||||
searcher = _MockSuggestionAlgorithm(experiments, max_concurrent=10)
|
||||
searcher = _MockSuggestionAlgorithm(max_concurrent=10)
|
||||
searcher.add_configurations(experiments)
|
||||
runner = TrialRunner(search_alg=searcher)
|
||||
runner.step()
|
||||
trials = runner.get_trials()
|
||||
@@ -1009,7 +1038,8 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
experiment_spec = {"run": "__fake", "stop": {"training_iteration": 1}}
|
||||
experiments = [Experiment.from_json("test", experiment_spec)]
|
||||
searcher = _MockSuggestionAlgorithm(experiments, max_concurrent=10)
|
||||
searcher = _MockSuggestionAlgorithm(max_concurrent=10)
|
||||
searcher.add_configurations(experiments)
|
||||
runner = TrialRunner(search_alg=searcher)
|
||||
runner.step()
|
||||
trials = runner.get_trials()
|
||||
@@ -1033,7 +1063,8 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
experiment_spec = {"run": "__fake", "stop": {"training_iteration": 2}}
|
||||
experiments = [Experiment.from_json("test", experiment_spec)]
|
||||
searcher = _MockSuggestionAlgorithm(experiments, max_concurrent=10)
|
||||
searcher = _MockSuggestionAlgorithm(max_concurrent=10)
|
||||
searcher.add_configurations(experiments)
|
||||
runner = TrialRunner(search_alg=searcher, scheduler=_MockScheduler())
|
||||
runner.step()
|
||||
trials = runner.get_trials()
|
||||
@@ -1058,7 +1089,8 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
experiments = [Experiment.from_json("test", experiment_spec)]
|
||||
searcher = _MockSuggestionAlgorithm(experiments, max_concurrent=1)
|
||||
searcher = _MockSuggestionAlgorithm(max_concurrent=1)
|
||||
searcher.add_configurations(experiments)
|
||||
runner = TrialRunner(search_alg=searcher)
|
||||
runner.step()
|
||||
trials = runner.get_trials()
|
||||
|
||||
@@ -7,12 +7,11 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
from ray.tune.pbt import PopulationBasedTraining, explore
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
PopulationBasedTraining, MedianStoppingRule,
|
||||
TrialScheduler)
|
||||
from ray.tune.schedulers.pbt import explore
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import TrialScheduler
|
||||
|
||||
from ray.rllib import _register_all
|
||||
_register_all()
|
||||
|
||||
@@ -36,26 +36,23 @@ class Trainable(object):
|
||||
|
||||
Note that, if you don't require checkpoint/restore functionality, then
|
||||
instead of implementing this class you can also get away with supplying
|
||||
just a `my_train(config, reporter)` function and calling:
|
||||
|
||||
``register_trainable("my_func", train)``
|
||||
|
||||
to register it for use with Tune. The function will be automatically
|
||||
converted to this interface (sans checkpoint functionality).
|
||||
|
||||
Attributes:
|
||||
config (obj): The hyperparam configuration for this trial.
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
just a ``my_train(config, reporter)`` function to the config.
|
||||
The function will be automatically converted to this interface
|
||||
(sans checkpoint functionality).
|
||||
"""
|
||||
|
||||
def __init__(self, config=None, logger_creator=None):
|
||||
"""Initialize an Trainable.
|
||||
|
||||
Sets up logging and points ``self.logdir`` to a directory in which
|
||||
training outputs should be placed.
|
||||
|
||||
Subclasses should prefer defining ``_setup()`` instead of overriding
|
||||
``__init__()`` directly.
|
||||
|
||||
Args:
|
||||
config (dict): Trainable-specific configuration data.
|
||||
config (dict): Trainable-specific configuration data. By default
|
||||
will be saved as ``self.config``.
|
||||
logger_creator (func): Function that creates a ray.tune.Logger
|
||||
object. If unspecified, a default logger is created.
|
||||
"""
|
||||
@@ -102,28 +99,36 @@ class Trainable(object):
|
||||
"""Runs one logical iteration of training.
|
||||
|
||||
Subclasses should override ``_train()`` instead to return results.
|
||||
|
||||
This class automatically fills the following fields in the result:
|
||||
done (bool): training is terminated. Filled only if not provided.
|
||||
time_this_iter_s (float): Time in seconds
|
||||
this iteration took to run. This may be overriden in order to
|
||||
override the system-computed time difference.
|
||||
time_total_s (float): Accumulated time in seconds
|
||||
for this entire experiment.
|
||||
experiment_id (str): Unique string identifier
|
||||
for this experiment. This id is preserved
|
||||
across checkpoint / restore calls.
|
||||
training_iteration (int): The index of this
|
||||
training iteration, e.g. call to train().
|
||||
pid (str): The pid of the training process.
|
||||
date (str): A formatted date of
|
||||
when the result was processed.
|
||||
timestamp (str): A UNIX timestamp of
|
||||
when the result was processed.
|
||||
hostname (str): The hostname of the machine
|
||||
hosting the training process.
|
||||
node_ip (str): The node ip of the machine
|
||||
hosting the training process.
|
||||
|
||||
`done` (bool): training is terminated. Filled only if not provided.
|
||||
|
||||
`time_this_iter_s` (float): Time in seconds this iteration
|
||||
took to run. This may be overriden in order to override the
|
||||
system-computed time difference.
|
||||
|
||||
`time_total_s` (float): Accumulated time in seconds for this
|
||||
entire experiment.
|
||||
|
||||
`experiment_id` (str): Unique string identifier
|
||||
for this experiment. This id is preserved
|
||||
across checkpoint / restore calls.
|
||||
|
||||
`training_iteration` (int): The index of this
|
||||
training iteration, e.g. call to train().
|
||||
|
||||
`pid` (str): The pid of the training process.
|
||||
|
||||
`date` (str): A formatted date of when the result was processed.
|
||||
|
||||
`timestamp` (str): A UNIX timestamp of when the result
|
||||
was processed.
|
||||
|
||||
`hostname` (str): Hostname of the machine hosting the training
|
||||
process.
|
||||
|
||||
`node_ip` (str): Node ip of the machine hosting the training
|
||||
process.
|
||||
|
||||
Returns:
|
||||
A dict that describes training progress.
|
||||
@@ -283,7 +288,11 @@ class Trainable(object):
|
||||
raise NotImplementedError
|
||||
|
||||
def _setup(self):
|
||||
"""Subclasses should override this for custom initialization."""
|
||||
"""Subclasses should override this for custom initialization.
|
||||
|
||||
Subclasses can access the hyperparameter configuration via
|
||||
``self.config``.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _stop(self):
|
||||
|
||||
@@ -95,7 +95,6 @@ class Trial(object):
|
||||
The args here take the same meaning as the command line flags defined
|
||||
in ray.tune.config_parser.
|
||||
"""
|
||||
|
||||
if not has_trainable(trainable_name):
|
||||
# Make sure rllib agents are registered
|
||||
from ray import rllib # noqa: F401
|
||||
@@ -267,8 +266,8 @@ class Trial(object):
|
||||
self._status_string(),
|
||||
location_string(
|
||||
self.last_result.get(HOSTNAME),
|
||||
self.last_result.get(PID))),
|
||||
'{} s'.format(int(self.last_result.get(TIME_TOTAL_S))),
|
||||
self.last_result.get(PID))), '{} s'.format(
|
||||
int(self.last_result.get(TIME_TOTAL_S)))
|
||||
]
|
||||
|
||||
if self.last_result.get("episode_reward_mean") is not None:
|
||||
|
||||
@@ -12,7 +12,7 @@ from ray.tune import TuneError
|
||||
from ray.tune.result import TIME_THIS_ITER_S
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
|
||||
MAX_DEBUG_TRIALS = 20
|
||||
|
||||
|
||||
+24
-9
@@ -6,13 +6,11 @@ import time
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
|
||||
from ray.tune.log_sync import wait_for_log_sync
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.trial_scheduler import FIFOScheduler
|
||||
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
FIFOScheduler, MedianStoppingRule)
|
||||
from ray.tune.web_server import TuneServer
|
||||
|
||||
_SCHEDULERS = {
|
||||
@@ -38,10 +36,11 @@ def run_experiments(experiments=None,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True,
|
||||
queue_trials=False):
|
||||
"""Tunes experiments.
|
||||
"""Runs and blocks until all trials finish.
|
||||
|
||||
Args:
|
||||
experiments (Experiment | list | dict): Experiments to run.
|
||||
experiments (Experiment | list | dict): Experiments to run. Will be
|
||||
passed to `search_alg` via `add_configurations`.
|
||||
search_alg (SearchAlgorithm): Search Algorithm. Defaults to
|
||||
BasicVariantGenerator.
|
||||
scheduler (TrialScheduler): Scheduler for executing
|
||||
@@ -56,6 +55,22 @@ def run_experiments(experiments=None,
|
||||
be set to True when running on an autoscaling cluster to enable
|
||||
automatic scale-up.
|
||||
|
||||
Examples:
|
||||
>>> experiment_spec = Experiment("experiment", my_func)
|
||||
>>> run_experiments(experiments=experiment_spec)
|
||||
|
||||
>>> experiment_spec = {"experiment": {"run": my_func}}
|
||||
>>> run_experiments(experiments=experiment_spec)
|
||||
|
||||
>>> run_experiments(
|
||||
>>> experiments=experiment_spec,
|
||||
>>> scheduler=MedianStoppingRule(...))
|
||||
|
||||
>>> run_experiments(
|
||||
>>> experiments=experiment_spec,
|
||||
>>> search_alg=SearchAlgorithm(),
|
||||
>>> scheduler=MedianStoppingRule(...))
|
||||
|
||||
Returns:
|
||||
List of Trial objects, holding data for each executed trial.
|
||||
"""
|
||||
@@ -63,9 +78,9 @@ def run_experiments(experiments=None,
|
||||
scheduler = FIFOScheduler()
|
||||
|
||||
if search_alg is None:
|
||||
assert experiments is not None, "Experiments need to be specified" \
|
||||
"if search_alg is not provided."
|
||||
search_alg = BasicVariantGenerator(experiments)
|
||||
search_alg = BasicVariantGenerator()
|
||||
|
||||
search_alg.add_configurations(experiments)
|
||||
|
||||
runner = TrialRunner(
|
||||
search_alg,
|
||||
|
||||
@@ -124,7 +124,8 @@ def RunnerHandler(runner):
|
||||
elif command == TuneClient.ADD:
|
||||
name = args["name"]
|
||||
spec = args["spec"]
|
||||
trial_generator = BasicVariantGenerator({name: spec})
|
||||
trial_generator = BasicVariantGenerator()
|
||||
trial_generator.add_configurations({name: spec})
|
||||
for trial in trial_generator.next_trials():
|
||||
runner.add_trial(trial)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user