[tune] Update API Reference Page (#7671)

* widerdocs

* init

* docs

* fix

* moveit

* mix

* better_docs

* remove

* Apply suggestions from code review

Co-Authored-By: Sven Mika <sven@anyscale.io>

Co-authored-by: Sven Mika <sven@anyscale.io>
This commit is contained in:
Richard Liaw
2020-03-22 16:42:20 -07:00
committed by GitHub
parent 288933ec6b
commit 81d311031b
27 changed files with 744 additions and 394 deletions
+15 -11
View File
@@ -17,7 +17,10 @@ logger = logging.getLogger(__name__)
class Analysis:
"""Analyze all results from a directory of experiments."""
"""Analyze all results from a directory of experiments.
To use this class, the experiment must be executed with the JsonLogger.
"""
def __init__(self, experiment_dir):
experiment_dir = os.path.expanduser(experiment_dir)
@@ -42,6 +45,9 @@ class Analysis:
metric (str): Key for trial info to order on.
If None, uses last result.
mode (str): One of [min, max].
Returns:
pd.DataFrame: Constructed from a result dict of each trial.
"""
rows = self._retrieve_rows(metric=metric, mode=mode)
all_configs = self.get_all_configs(prefix=True)
@@ -97,6 +103,9 @@ class Analysis:
Args:
prefix (bool): If True, flattens the config dict
and prepends `config/`.
Returns:
List[dict]: List of all configurations of trials.
"""
fail_count = 0
for path in self._get_trial_paths():
@@ -124,8 +133,7 @@ class Analysis:
"training_iteration" is used by default.
Returns:
A list of [path, metric] lists for all persistent checkpoints of
the trial.
List of [path, metric] for all persistent checkpoints of the trial.
"""
if isinstance(trial, str):
trial_dir = os.path.expanduser(trial)
@@ -177,10 +185,14 @@ class Analysis:
class ExperimentAnalysis(Analysis):
"""Analyze results from a Tune experiment.
To use this class, the experiment must be executed with the JsonLogger.
Parameters:
experiment_checkpoint_path (str): Path to a json file
representing an experiment state. Corresponds to
Experiment.local_dir/Experiment.name/experiment_state.json
trials (list|None): List of trials that can be accessed via
`analysis.trials`.
Example:
>>> tune.run(my_trainable, name="my_exp", local_dir="~/tune_results")
@@ -189,14 +201,6 @@ class ExperimentAnalysis(Analysis):
"""
def __init__(self, experiment_checkpoint_path, trials=None):
"""Initializer.
Args:
experiment_checkpoint_path (str): Path to where experiment is
located.
trials (list|None): List of trials that can be accessed via
`analysis.trials`.
"""
with open(experiment_checkpoint_path) as f:
_experiment_state = json.load(f)
self._experiment_state = _experiment_state
+21 -35
View File
@@ -48,25 +48,27 @@ def _raise_on_durable(trainable_name, sync_to_driver, upload_dir):
class Experiment:
"""Tracks experiment specifications.
Implicitly registers the Trainable if needed.
Implicitly registers the Trainable if needed. The args here take
the same meaning as the arguments defined in `tune.py:run`.
Examples:
>>> experiment_spec = Experiment(
>>> "my_experiment_name",
>>> my_func,
>>> stop={"mean_accuracy": 100},
>>> config={
>>> "alpha": tune.grid_search([0.2, 0.4, 0.6]),
>>> "beta": tune.grid_search([1, 2]),
>>> },
>>> resources_per_trial={
>>> "cpu": 1,
>>> "gpu": 0
>>> },
>>> num_samples=10,
>>> local_dir="~/ray_results",
>>> checkpoint_freq=10,
>>> max_failures=2)
.. code-block:: python
experiment_spec = Experiment(
"my_experiment_name",
my_func,
stop={"mean_accuracy": 100},
config={
"alpha": tune.grid_search([0.2, 0.4, 0.6]),
"beta": tune.grid_search([1, 2]),
},
resources_per_trial={
"cpu": 1,
"gpu": 0
},
num_samples=10,
local_dir="~/ray_results",
checkpoint_freq=10,
max_failures=2)
"""
def __init__(self,
@@ -88,23 +90,7 @@ class Experiment:
checkpoint_score_attr=None,
export_formats=None,
max_failures=0,
restore=None,
repeat=None,
trial_resources=None,
sync_function=None):
"""Initialize a new Experiment.
The args here take the same meaning as the command line flags defined
in `tune.py:run`.
"""
if repeat:
_raise_deprecation_note("repeat", "num_samples", soft=False)
if trial_resources:
_raise_deprecation_note(
"trial_resources", "resources_per_trial", soft=False)
if sync_function:
_raise_deprecation_note(
"sync_function", "sync_to_driver", soft=False)
restore=None):
config = config or {}
self._run_identifier = Experiment.register_if_needed(run)
+13
View File
@@ -31,6 +31,7 @@ class Logger:
Arguments:
config: Configuration passed to all logger creators.
logdir: Directory for all logger creators to log to.
trial (Trial): Trial object for the logger to access.
"""
def __init__(self, config, logdir, trial=None):
@@ -97,6 +98,13 @@ class MLFLowLogger(Logger):
class JsonLogger(Logger):
"""Logs trial results in json format.
Also writes to a results file and param.json file when results or
configurations are updated. Experiments must be executed with the
JsonLogger to be compatible with the ExperimentAnalysis tool.
"""
def _init(self):
self.update_config(self.config)
local_file = os.path.join(self.logdir, EXPR_RESULT_FILE)
@@ -278,6 +286,11 @@ class UnifiedLogger(Logger):
self._logger_cls_list = DEFAULT_LOGGERS
else:
self._logger_cls_list = loggers
if JsonLogger not in self._logger_cls_list:
if log_once("JsonLogger"):
logger.warning(
"JsonLogger not provided. The ExperimentAnalysis tool is "
"disabled.")
self._sync_function = sync_function
self._log_syncer = None
+49 -51
View File
@@ -44,7 +44,22 @@ class ProgressReporter:
class TuneReporterBase(ProgressReporter):
"""Abstract base class for the default Tune reporters."""
"""Abstract base class for the default Tune reporters.
Args:
metric_columns (dict[str, str]|list[str]): Names of metrics to
include in progress table. If this is a dict, the keys should
be metric names and the values should be the displayed names.
If this is a list, the metric name is used directly.
max_progress_rows (int): Maximum number of rows to print
in the progress table. The progress table describes the
progress of each trial. Defaults to 20.
max_error_rows (int): Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_report_frequency (int): Maximum report frequency in seconds.
Defaults to 5s.
"""
# Truncated representations of column names (to accommodate small screens).
DEFAULT_COLUMNS = collections.OrderedDict({
@@ -61,22 +76,6 @@ class TuneReporterBase(ProgressReporter):
max_progress_rows=20,
max_error_rows=20,
max_report_frequency=5):
"""Initializes a new TuneReporterBase.
Args:
metric_columns (dict[str, str]|list[str]): Names of metrics to
include in progress table. If this is a dict, the keys should
be metric names and the values should be the displayed names.
If this is a list, the metric name is used directly.
max_progress_rows (int): Maximum number of rows to print
in the progress table. The progress table describes the
progress of each trial. Defaults to 20.
max_error_rows (int): Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_report_frequency (int): Maximum report frequency in seconds.
Defaults to 5s.
"""
self._metric_columns = metric_columns or self.DEFAULT_COLUMNS
self._max_progress_rows = max_progress_rows
self._max_error_rows = max_error_rows
@@ -145,7 +144,23 @@ class TuneReporterBase(ProgressReporter):
class JupyterNotebookReporter(TuneReporterBase):
"""Jupyter notebook-friendly Reporter that can update display in-place."""
"""Jupyter notebook-friendly Reporter that can update display in-place.
Args:
overwrite (bool): Flag for overwriting the last reported progress.
metric_columns (dict[str, str]|list[str]): Names of metrics to
include in progress table. If this is a dict, the keys should
be metric names and the values should be the displayed names.
If this is a list, the metric name is used directly.
max_progress_rows (int): Maximum number of rows to print
in the progress table. The progress table describes the
progress of each trial. Defaults to 20.
max_error_rows (int): Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_report_frequency (int): Maximum report frequency in seconds.
Defaults to 5s.
"""
def __init__(self,
overwrite,
@@ -153,23 +168,6 @@ class JupyterNotebookReporter(TuneReporterBase):
max_progress_rows=20,
max_error_rows=20,
max_report_frequency=5):
"""Initializes a new JupyterNotebookReporter.
Args:
overwrite (bool): Flag for overwriting the last reported progress.
metric_columns (dict[str, str]|list[str]): Names of metrics to
include in progress table. If this is a dict, the keys should
be metric names and the values should be the displayed names.
If this is a list, the metric name is used directly.
max_progress_rows (int): Maximum number of rows to print
in the progress table. The progress table describes the
progress of each trial. Defaults to 20.
max_error_rows (int): Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_report_frequency (int): Maximum report frequency in seconds.
Defaults to 5s.
"""
super(JupyterNotebookReporter,
self).__init__(metric_columns, max_progress_rows, max_error_rows,
max_report_frequency)
@@ -186,29 +184,29 @@ class JupyterNotebookReporter(TuneReporterBase):
class CLIReporter(TuneReporterBase):
"""Command-line reporter"""
"""Command-line reporter.
Args:
metric_columns (dict[str, str]|list[str]): Names of metrics to
include in progress table. If this is a dict, the keys should
be metric names and the values should be the displayed names.
If this is a list, the metric name is used directly.
max_progress_rows (int): Maximum number of rows to print
in the progress table. The progress table describes the
progress of each trial. Defaults to 20.
max_error_rows (int): Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_report_frequency (int): Maximum report frequency in seconds.
Defaults to 5s.
"""
def __init__(self,
metric_columns=None,
max_progress_rows=20,
max_error_rows=20,
max_report_frequency=5):
"""Initializes a CLIReporter.
Args:
metric_columns (dict[str, str]|list[str]): Names of metrics to
include in progress table. If this is a dict, the keys should
be metric names and the values should be the displayed names.
If this is a list, the metric name is used directly.
max_progress_rows (int): Maximum number of rows to print
in the progress table. The progress table describes the
progress of each trial. Defaults to 20.
max_error_rows (int): Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_report_frequency (int): Maximum report frequency in seconds.
Defaults to 5s.
"""
super(CLIReporter, self).__init__(metric_columns, max_progress_rows,
max_error_rows, max_report_frequency)
+6
View File
@@ -41,6 +41,9 @@ def validate_trainable(trainable_name):
def register_trainable(name, trainable):
"""Register a trainable function or class.
This enables a class or function to be accessed on every Ray process
in the cluster.
Args:
name (str): Name to register.
trainable (obj): Function or tune.Trainable class. Functions must
@@ -70,6 +73,9 @@ def register_trainable(name, trainable):
def register_env(name, env_creator):
"""Register a custom environment for use with RLlib.
This enables the environment to be accessed on every Ray process
in the cluster.
Args:
name (str): Name to register.
env_creator (obj): Function that creates an env.
+1 -1
View File
@@ -19,7 +19,7 @@ class Resources(
])):
"""Ray resources required to schedule a trial.
Attributes:
Parameters:
cpu (float): Number of CPUs to allocate to the trial.
gpu (float): Number of GPUs to allocate to the trial.
memory (float): Memory to reserve for the trial.
+21 -16
View File
@@ -149,22 +149,27 @@ class PopulationBasedTraining(FIFOScheduler):
local_dir at each exploit. Allows config schedule to be
reconstructed.
Example:
>>> pbt = PopulationBasedTraining(
>>> time_attr="training_iteration",
>>> metric="episode_reward_mean",
>>> mode="max",
>>> perturbation_interval=10, # every 10 `time_attr` units
>>> # (training_iterations in this case)
>>> hyperparam_mutations={
>>> # Perturb factor1 by scaling it by 0.8 or 1.2. Resampling
>>> # resets it to a value sampled from the lambda function.
>>> "factor_1": lambda: random.uniform(0.0, 20.0),
>>> # Perturb factor2 by changing it to an adjacent value, e.g.
>>> # 10 -> 1 or 10 -> 100. Resampling will choose at random.
>>> "factor_2": [1, 10, 100, 1000, 10000],
>>> })
>>> tune.run({...}, num_samples=8, scheduler=pbt)
.. code-block:: python
import random
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
pbt = PopulationBasedTraining(
time_attr="training_iteration",
metric="episode_reward_mean",
mode="max",
perturbation_interval=10, # every 10 `time_attr` units
# (training_iterations in this case)
hyperparam_mutations={
# Perturb factor1 by scaling it by 0.8 or 1.2. Resampling
# resets it to a value sampled from the lambda function.
"factor_1": lambda: random.uniform(0.0, 20.0),
# Perturb factor2 by changing it to an adjacent value, e.g.
# 10 -> 1 or 10 -> 100. Resampling will choose at random.
"factor_2": [1, 10, 100, 1000, 10000],
})
tune.run({...}, num_samples=8, scheduler=pbt)
"""
def __init__(self,
+14 -7
View File
@@ -40,13 +40,20 @@ class AxSearch(SuggestionAlgorithm):
trial results in the optimization process.
Example:
>>> parameters = [
>>> {"name": "x1", "type": "range", "bounds": [0.0, 1.0]},
>>> {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
>>> ]
>>> algo = AxSearch(parameters=parameters,
>>> objective_name="hartmann6", max_concurrent=4)
.. code-block:: python
from ray import tune
from ray.tune.suggest.ax import AxSearch
parameters = [
{"name": "x1", "type": "range", "bounds": [0.0, 1.0]},
{"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
]
algo = AxSearch(parameters=parameters,
objective_name="hartmann6", max_concurrent=4)
tune.run(my_func, search_alg=algo)
"""
def __init__(self, ax_client, max_concurrent=10, mode="max", **kwargs):
+24 -7
View File
@@ -14,18 +14,35 @@ class BasicVariantGenerator(SearchAlgorithm):
See also: `ray.tune.suggest.variant_generator`.
Example:
>>> searcher = BasicVariantGenerator()
>>> searcher.add_configurations({"experiment": { ... }})
>>> list_of_trials = searcher.next_trials()
>>> searcher.is_finished == True
Parameters:
shuffle (bool): Shuffles the generated list of configurations.
User API:
.. code-block:: python
from ray import tune
from ray.tune.suggest import BasicVariantGenerator
searcher = BasicVariantGenerator()
tune.run(my_trainable_func, search_alg=searcher)
Internal API:
.. code-block:: python
from ray.tune.suggest import BasicVariantGenerator
searcher = BasicVariantGenerator()
searcher.add_configurations({"experiment": { ... }})
list_of_trials = searcher.next_trials()
searcher.is_finished == True
"""
def __init__(self, shuffle=False):
"""Initializes the Variant Generator.
Arguments:
shuffle (bool): Shuffles the generated list of configurations.
"""
self._parser = make_parser()
self._trial_generator = []
+16 -8
View File
@@ -15,7 +15,7 @@ class BayesOptSearch(SuggestionAlgorithm):
"""A wrapper around BayesOpt to provide trial suggestions.
Requires BayesOpt to be installed. You can install BayesOpt with the
command: `pip install bayesian-optimization`.
command: ``pip install bayesian-optimization``.
Parameters:
space (dict): Continuous search space. Parameters will be sampled from
@@ -32,14 +32,22 @@ class BayesOptSearch(SuggestionAlgorithm):
use_early_stopped_trials (bool): Whether to use early terminated
trial results in the optimization process.
Example:
>>> space = {
>>> 'width': (0, 20),
>>> 'height': (-100, 100),
>>> }
>>> algo = BayesOptSearch(
>>> space, max_concurrent=4, metric="mean_loss", mode="min")
.. code-block:: python
from ray import tune
from ray.tune.suggest.bayesopt import BayesOptSearch
space = {
'width': (0, 20),
'height': (-100, 100),
}
algo = BayesOptSearch(
space, max_concurrent=4, metric="mean_loss", mode="min")
tune.run(my_func, search_alg=algo)
"""
# bayes_opt.BayesianOptimization: Optimization object
optimizer = None
def __init__(self,
space,
+23 -18
View File
@@ -22,7 +22,7 @@ class TuneBOHB(SuggestionAlgorithm):
Requires HpBandSter and ConfigSpace to be installed. You can install
HpBandSter and ConfigSpace with: `pip install hpbandster ConfigSpace`.
HpBandSter and ConfigSpace with: ``pip install hpbandster ConfigSpace``.
This should be used in conjunction with HyperBandForBOHB.
@@ -38,23 +38,28 @@ class TuneBOHB(SuggestionAlgorithm):
minimizing or maximizing the metric attribute.
Example:
>>> import ConfigSpace as CS
>>> config_space = CS.ConfigurationSpace()
>>> config_space.add_hyperparameter(
CS.UniformFloatHyperparameter('width', lower=0, upper=20))
>>> config_space.add_hyperparameter(
CS.UniformFloatHyperparameter('height', lower=-100, upper=100))
>>> config_space.add_hyperparameter(
CS.CategoricalHyperparameter(
name='activation', choices=['relu', 'tanh']))
>>> algo = TuneBOHB(
config_space, max_concurrent=4, metric='mean_loss', mode='min')
>>> bohb = HyperBandForBOHB(
time_attr='training_iteration',
metric='mean_loss',
mode='min',
max_t=100)
>>> run(MyTrainableClass, scheduler=bohb, search_alg=algo)
.. code-block:: python
import ConfigSpace as CS
config_space = CS.ConfigurationSpace()
config_space.add_hyperparameter(
CS.UniformFloatHyperparameter('width', lower=0, upper=20))
config_space.add_hyperparameter(
CS.UniformFloatHyperparameter('height', lower=-100, upper=100))
config_space.add_hyperparameter(
CS.CategoricalHyperparameter(
name='activation', choices=['relu', 'tanh']))
algo = TuneBOHB(
config_space, max_concurrent=4, metric='mean_loss', mode='min')
bohb = HyperBandForBOHB(
time_attr='training_iteration',
metric='mean_loss',
mode='min',
max_t=100)
run(MyTrainableClass, scheduler=bohb, search_alg=algo)
"""
+33 -27
View File
@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
class DragonflySearch(SuggestionAlgorithm):
"""A wrapper around Dragonfly to provide trial suggestions.
Requires Dragonfly to be installed.
Requires Dragonfly to be installed via ``pip install dragonfly-opt``.
Parameters:
optimizer (dragonfly.opt.BlackboxOptimiser): Optimizer provided
@@ -40,33 +40,39 @@ class DragonflySearch(SuggestionAlgorithm):
needing to re-compute the trial. Must be the same length as
points_to_evaluate.
Example:
>>> from dragonfly.opt.gp_bandit import EuclideanGPBandit
>>> from dragonfly.exd.experiment_caller import EuclideanFunctionCaller
>>> from dragonfly import load_config
>>> domain_vars = [{
"name": "LiNO3_vol",
"type": "float",
"min": 0,
"max": 7
}, {
"name": "Li2SO4_vol",
"type": "float",
"min": 0,
"max": 7
}, {
"name": "NaClO4_vol",
"type": "float",
"min": 0,
"max": 7
}]
.. code-block:: python
>>> domain_config = load_config({"domain": domain_vars})
>>> func_caller = EuclideanFunctionCaller(None,
domain_config.domain.list_of_domains[0])
>>> optimizer = EuclideanGPBandit(func_caller, ask_tell_mode=True)
>>> algo = DragonflySearch(optimizer, max_concurrent=4,
metric="objective", mode="max")
from ray import tune
from dragonfly.opt.gp_bandit import EuclideanGPBandit
from dragonfly.exd.experiment_caller import EuclideanFunctionCaller
from dragonfly import load_config
domain_vars = [{
"name": "LiNO3_vol",
"type": "float",
"min": 0,
"max": 7
}, {
"name": "Li2SO4_vol",
"type": "float",
"min": 0,
"max": 7
}, {
"name": "NaClO4_vol",
"type": "float",
"min": 0,
"max": 7
}]
domain_config = load_config({"domain": domain_vars})
func_caller = EuclideanFunctionCaller(None,
domain_config.domain.list_of_domains[0])
optimizer = EuclideanGPBandit(func_caller, ask_tell_mode=True)
algo = DragonflySearch(optimizer, max_concurrent=4,
metric="objective", mode="max")
tune.run(my_func, search_alg=algo)
"""
def __init__(self,
+16 -14
View File
@@ -52,20 +52,22 @@ class HyperOptSearch(SuggestionAlgorithm):
use_early_stopped_trials (bool): Whether to use early terminated
trial results in the optimization process.
Example:
>>> space = {
>>> 'width': hp.uniform('width', 0, 20),
>>> 'height': hp.uniform('height', -100, 100),
>>> 'activation': hp.choice("activation", ["relu", "tanh"])
>>> }
>>> current_best_params = [{
>>> 'width': 10,
>>> 'height': 0,
>>> 'activation': 0, # The index of "relu"
>>> }]
>>> algo = HyperOptSearch(
>>> space, max_concurrent=4, metric="mean_loss", mode="min",
>>> points_to_evaluate=current_best_params)
.. code-block:: python
space = {
'width': hp.uniform('width', 0, 20),
'height': hp.uniform('height', -100, 100),
'activation': hp.choice("activation", ["relu", "tanh"])
}
current_best_params = [{
'width': 10,
'height': 0,
'activation': 0, # The index of "relu"
}]
algo = HyperOptSearch(
space, max_concurrent=4, metric="mean_loss", mode="min",
points_to_evaluate=current_best_params)
"""
def __init__(self,
+24 -21
View File
@@ -30,27 +30,30 @@ class SigOptSearch(SuggestionAlgorithm):
minimizing or maximizing the metric attribute.
Example:
>>> space = [
>>> {
>>> 'name': 'width',
>>> 'type': 'int',
>>> 'bounds': {
>>> 'min': 0,
>>> 'max': 20
>>> },
>>> },
>>> {
>>> 'name': 'height',
>>> 'type': 'int',
>>> 'bounds': {
>>> 'min': -100,
>>> 'max': 100
>>> },
>>> },
>>> ]
>>> algo = SigOptSearch(
>>> space, name="SigOpt Example Experiment",
>>> max_concurrent=1, metric="mean_loss", mode="min")
.. code-block:: python
space = [
{
'name': 'width',
'type': 'int',
'bounds': {
'min': 0,
'max': 20
},
},
{
'name': 'height',
'type': 'int',
'bounds': {
'min': -100,
'max': 100
},
},
]
algo = SigOptSearch(
space, name="SigOpt Example Experiment",
max_concurrent=1, metric="mean_loss", mode="min")
"""
def __init__(self,
+7 -6
View File
@@ -20,12 +20,13 @@ class SuggestionAlgorithm(SearchAlgorithm):
`suggest` will be passed a trial_id, which will be used in
subsequent notifications.
Example:
>>> suggester = SuggestionAlgorithm()
>>> suggester.add_configurations({ ... })
>>> new_parameters = suggester.suggest()
>>> suggester.on_trial_complete(trial_id, result)
>>> better_parameters = suggester.suggest()
.. code-block:: python
suggester = SuggestionAlgorithm()
suggester.add_configurations({ ... })
new_parameters = suggester.suggest()
suggester.on_trial_complete(trial_id, result)
better_parameters = suggester.suggest()
"""
def __init__(self, metric=None, mode="max", use_early_stopped_trials=True):
+22 -5
View File
@@ -1,6 +1,6 @@
import logging
from ray.tune.track.session import TrackSession
from ray.tune.track.session import TrackSession as _TrackSession
logger = logging.getLogger(__name__)
@@ -37,7 +37,7 @@ def init(ignore_reinit_error=True, **session_kwargs):
else:
raise ValueError(reinit_msg)
_session = TrackSession(**session_kwargs)
_session = _TrackSession(**session_kwargs)
def shutdown():
@@ -50,7 +50,25 @@ def shutdown():
def log(**kwargs):
"""Applies TrackSession.log to the trial in the current context."""
"""Logs all keyword arguments.
.. code-block:: python
import time
from ray import tune
from ray.tune import track
def run_me(config):
for iter in range(100):
time.sleep(1)
track.log(hello="world", ray="tune")
analysis = tune.run(run_me)
Args:
**kwargs: Any key value pair to be logged by Tune. Any of these
metrics can be used for early stopping or optimization.
"""
_session = get_session()
return _session.log(**kwargs)
@@ -83,6 +101,5 @@ def trial_id():
__all__ = [
"TrackSession", "session", "log", "trial_dir", "init", "shutdown",
"trial_name", "trial_id"
"session", "log", "trial_dir", "init", "shutdown", "trial_name", "trial_id"
]
+19 -10
View File
@@ -125,7 +125,7 @@ class Trainable:
When using Tune, Tune will convert this class into a Ray actor, which
runs on a separate process. Tune will also change the current working
directory of this process to `self.logdir`.
directory of this process to ``self.logdir``.
"""
@@ -184,18 +184,23 @@ class Trainable:
@classmethod
def default_resource_request(cls, config):
"""Returns the resource requirement for the given configuration.
"""Provides a static resource requirement for the given configuration.
This can be overridden by sub-classes to set the correct trial resource
allocation, so the user does not need to.
Example:
>>> def default_resource_request(cls, config):
>>> return Resources(
>>> cpu=0,
>>> gpu=0,
>>> extra_cpu=config["workers"],
>>> extra_gpu=int(config["use_gpu"]) * config["workers"])
.. code-block:: python
@classmethod
def default_resource_request(cls, config):
return Resources(
cpu=0,
gpu=0,
extra_cpu=config["workers"],
extra_gpu=int(config["use_gpu"]) * config["workers"])
Returns:
Resources: A Resources object consumed by Tune for queueing.
"""
return None
@@ -329,7 +334,7 @@ class Trainable:
checkpoint_dir (str): Optional dir to place the checkpoint.
Returns:
Checkpoint path or prefix that may be passed to restore().
str: Checkpoint path or prefix that may be passed to restore().
"""
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
"checkpoint_{}".format(self._iteration))
@@ -658,6 +663,10 @@ class Trainable:
def _log_result(self, result):
"""Subclasses can optionally override this to customize logging.
The logging here is done on the worker process rather than
the driver. You may want to turn off driver logging via the
``loggers`` parameter in ``tune.run`` when overriding this function.
Args:
result (dict): Training result returned by _train().
"""
+13
View File
@@ -131,6 +131,19 @@ class Trial:
Trials start in the PENDING state, and transition to RUNNING once started.
On error it transitions to ERROR, otherwise TERMINATED on success.
Attributes:
trainable_name (str): Name of the trainable object to be executed.
config (dict): Provided configuration dictionary with evaluated params.
trial_id (str): Unique identifier for the trial.
local_dir (str): Local_dir as passed to tune.run.
logdir (str): Directory where the trial logs are saved.
evaluated_params (dict): Evaluated parameters by search algorithm.
experiment_tag (str): Identifying trial name to show in the console.
resources (Resources): Amount of resources that this trial will use.
status (str): One of PENDING, RUNNING, PAUSED, TERMINATED, ERROR.
error_file (str): Path to the errors that this trial has raised.
"""
PENDING = "PENDING"
+4 -2
View File
@@ -8,7 +8,9 @@ logger = logging.getLogger(__name__)
class TrialExecutor:
"""Manages platform-specific details such as resource handling
"""Module for interacting with remote trainables.
Manages platform-specific details such as resource handling
and starting/stopping trials.
"""
@@ -75,7 +77,7 @@ class TrialExecutor:
Args:
trial (Trial): Trial to be started.
checkpoint(Checkpoint): A Python object or path storing the state
checkpoint (Checkpoint): A Python object or path storing the state
of trial.
"""
raise NotImplementedError("Subclasses of TrialExecutor must provide "
+23 -24
View File
@@ -70,7 +70,8 @@ class _TuneFunctionDecoder(json.JSONDecoder):
class TrialRunner:
"""A TrialRunner implements the event loop for scheduling trials on Ray.
Example:
.. code-block:: python
runner = TrialRunner()
runner.add_trial(Trial(...))
runner.add_trial(Trial(...))
@@ -87,6 +88,27 @@ class TrialRunner:
could deadlock waiting for new resources to become available. Furthermore,
oversubscribing the cluster could degrade training performance, leading to
misleading benchmark results.
Args:
search_alg (SearchAlgorithm): SearchAlgorithm for generating
Trial objects.
scheduler (TrialScheduler): Defaults to FIFOScheduler.
launch_web_server (bool): Flag for starting TuneServer.
local_checkpoint_dir (str): Path where
global checkpoints are stored and restored from.
remote_checkpoint_dir (str): Remote path where
global checkpoints are stored and restored from. Used
if `resume` == REMOTE.
stopper: Custom class for stopping whole experiments. See
``Stopper``.
resume (str|False): see `tune.py:run`.
sync_to_cloud (func|str): See `tune.py:run`.
server_port (int): Port number for launching TuneServer.
verbose (bool): Flag for verbosity. If False, trial results
will not be output.
checkpoint_period (int): Trial runner checkpoint periodicity in
seconds. Defaults to 10.
trial_executor (TrialExecutor): Defaults to RayTrialExecutor.
"""
CKPT_FILE_TMPL = "experiment_state-{}.json"
@@ -105,29 +127,6 @@ class TrialRunner:
verbose=True,
checkpoint_period=10,
trial_executor=None):
"""Initializes a new TrialRunner.
Args:
search_alg (SearchAlgorithm): SearchAlgorithm for generating
Trial objects.
scheduler (TrialScheduler): Defaults to FIFOScheduler.
launch_web_server (bool): Flag for starting TuneServer
local_checkpoint_dir (str): Path where
global checkpoints are stored and restored from.
remote_checkpoint_dir (str): Remote path where
global checkpoints are stored and restored from. Used
if `resume` == REMOTE.
stopper: Custom class for stopping whole experiments. See
``Stopper``.
resume (str|False): see `tune.py:run`.
sync_to_cloud (func|str): See `tune.py:run`.
server_port (int): Port number for launching TuneServer.
verbose (bool): Flag for verbosity. If False, trial results
will not be output.
checkpoint_period (int): Trial runner checkpoint periodicity in
seconds. Defaults to 10.
trial_executor (TrialExecutor): Defaults to RayTrialExecutor.
"""
self._search_alg = search_alg or BasicVariantGenerator()
self._scheduler_alg = scheduler or FIFOScheduler()
self.trial_executor = trial_executor or RayTrialExecutor()
+10 -14
View File
@@ -98,12 +98,11 @@ def run(run_or_experiment,
trial_executor=None,
raise_on_failed_trial=True,
return_trials=False,
ray_auto_init=True,
sync_function=None):
ray_auto_init=True):
"""Executes training.
Args:
run_or_experiment (function|class|str|Experiment): If
run_or_experiment (function | class | str | :class:`Experiment`): If
function|class|str, this is the algorithm or model to train.
This may refer to the name of a built-in algorithm
(e.g. RLLib's DQN or PPO), a user-defined trainable
@@ -112,11 +111,11 @@ def run(run_or_experiment,
If Experiment, then Tune will execute training based on
Experiment.spec.
name (str): Name of experiment.
stop (dict|callable): The stopping criteria. If dict, the keys may be
any field in the return result of 'train()', whichever is
reached first. If function, it must take (trial_id, result) as
arguments and return a boolean (True if trial should be stopped,
False otherwise). This can also be a subclass of
stop (dict | callable | :class:`Stopper`): Stopping criteria. If dict,
the keys may be any field in the return result of 'train()',
whichever is reached first. If function, it must take (trial_id,
result) as arguments and return a boolean (True if trial should be
stopped, False otherwise). This can also be a subclass of
``ray.tune.Stopper``, which allows users to implement
custom experiment-wide stopping (i.e., stopping an entire Tune
run based on some time constraint).
@@ -211,14 +210,12 @@ def run(run_or_experiment,
ray_auto_init (bool): Automatically starts a local Ray cluster
if using a RayTrialExecutor (which is the default) and
if Ray is not initialized. Defaults to True.
sync_function: Deprecated. See `sync_to_cloud` and
`sync_to_driver`.
Returns:
List of Trial objects.
ExperimentAnalysis: Object for experiment analysis.
Raises:
TuneError if any trials failed and `raise_on_failed_trial` is True.
TuneError: Any trials failed and `raise_on_failed_trial` is True.
Examples:
>>> tune.run(mytrainable, scheduler=PopulationBasedTraining())
@@ -265,8 +262,7 @@ def run(run_or_experiment,
checkpoint_score_attr=checkpoint_score_attr,
export_formats=export_formats,
max_failures=max_failures,
restore=restore,
sync_function=sync_function)
restore=restore)
else:
logger.debug("Ignoring some parameters passed into tune.run.")