From 5f04ade6ef0f08b6456be169fde11e01fe27c99d Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Sat, 12 Dec 2020 10:47:19 +0100 Subject: [PATCH] [tune] add more stoppers and stopper documentation (#12750) * Add new stoppers & docs * Add tests for maximum iteration stopper and trial plateau stopper * Update python/ray/tune/stopper.py Co-authored-by: Richard Liaw * Update doc/source/tune/api_docs/stoppers.rst Co-authored-by: Richard Liaw * Update doc/source/tune/api_docs/stoppers.rst Co-authored-by: Richard Liaw * Apply suggestions from code review * Apply suggestions from code review * Update python/ray/tune/stopper.py Co-authored-by: Richard Liaw --- doc/source/tune/api_docs/execution.rst | 8 - doc/source/tune/api_docs/overview.rst | 1 + doc/source/tune/api_docs/stoppers.rst | 46 ++++++ doc/source/tune/user-guide.rst | 10 +- python/ray/tune/experiment.py | 7 + python/ray/tune/stopper.py | 206 +++++++++++++++++++++---- python/ray/tune/tests/test_api.py | 39 +++++ 7 files changed, 280 insertions(+), 37 deletions(-) create mode 100644 doc/source/tune/api_docs/stoppers.rst diff --git a/doc/source/tune/api_docs/execution.rst b/doc/source/tune/api_docs/execution.rst index c984d3894..9eebc3c27 100644 --- a/doc/source/tune/api_docs/execution.rst +++ b/doc/source/tune/api_docs/execution.rst @@ -23,14 +23,6 @@ tune.with_parameters .. autofunction:: ray.tune.with_parameters -.. _tune-stop-ref: - -Stopper (tune.Stopper) ----------------------- - -.. autoclass:: ray.tune.Stopper - :members: __call__, stop_all - .. _tune-sync-config: tune.SyncConfig diff --git a/doc/source/tune/api_docs/overview.rst b/doc/source/tune/api_docs/overview.rst index c8cc034a6..cb3f5193c 100644 --- a/doc/source/tune/api_docs/overview.rst +++ b/doc/source/tune/api_docs/overview.rst @@ -21,6 +21,7 @@ on `Github`_. 
suggestion.rst schedulers.rst sklearn.rst + stoppers.rst logging.rst integration.rst internals.rst diff --git a/doc/source/tune/api_docs/stoppers.rst b/doc/source/tune/api_docs/stoppers.rst new file mode 100644 index 000000000..4d65754be --- /dev/null +++ b/doc/source/tune/api_docs/stoppers.rst @@ -0,0 +1,46 @@ +.. _tune-stoppers: + +Stopping mechanisms (tune.stopper) +================================== + +In addition to Trial Schedulers like :ref:`ASHA `, where a number of +trials are stopped if they perform subpar, Ray Tune also supports custom stopping mechanisms to stop trials early. For instance, stopping mechanisms can specify to stop trials when they reached a plateau and the metric +doesn't change anymore. + +Ray Tune comes with several stopping mechanisms out of the box. For custom stopping behavior, you can +inherit from the :class:`Stopper ` class. + +Other stopping behaviors are described :ref:`in the user guide `. + +.. contents:: + :local: + :depth: 1 + + +.. _tune-stop-ref: + +Stopper (tune.Stopper) +---------------------- + +.. autoclass:: ray.tune.Stopper + :members: __call__, stop_all + +MaximumIterationStopper (tune.stopper.MaximumIterationStopper) +-------------------------------------------------------------- + +.. autoclass:: ray.tune.stopper.MaximumIterationStopper + +ExperimentPlateauStopper (tune.stopper.ExperimentPlateauStopper) +---------------------------------------------------------------- + +.. autoclass:: ray.tune.stopper.ExperimentPlateauStopper + +TrialPlateauStopper (tune.stopper.TrialPlateauStopper) +------------------------------------------------------ + +.. autoclass:: ray.tune.stopper.TrialPlateauStopper + +TimeoutStopper (tune.stopper.TimeoutStopper) +-------------------------------------------- + +.. 
autoclass:: ray.tune.stopper.TimeoutStopper diff --git a/doc/source/tune/user-guide.rst b/doc/source/tune/user-guide.rst index 1efa218b1..08fb4cba5 100644 --- a/doc/source/tune/user-guide.rst +++ b/doc/source/tune/user-guide.rst @@ -305,7 +305,9 @@ and passed to your trainable as a parameter. Stopping Trials --------------- -You can control when trials are stopped early by passing the ``stop`` argument to ``tune.run``. This argument takes either a dictionary or a function. +You can control when trials are stopped early by passing the ``stop`` argument to ``tune.run``. +This argument takes a dictionary, a function, or a :class:`Stopper <ray.tune.Stopper>` class +as an argument. If a dictionary is passed in, the keys may be any field in the return result of ``tune.report`` in the Function API or ``step()`` (including the results from ``step`` and auto-filled metrics). @@ -329,7 +331,7 @@ For more flexibility, you can pass in a function instead. If a function is passe tune.run(my_trainable, stop=stopper) -Finally, you can implement the ``Stopper`` abstract class for stopping entire experiments. For example, the following example stops all trials after the criteria is fulfilled by any individual trial, and prevents new ones from starting: +Finally, you can implement the :class:`Stopper <ray.tune.Stopper>` abstract class for stopping entire experiments. For example, the following example stops all trials after the criteria is fulfilled by any individual trial, and prevents new ones from starting: .. code-block:: python @@ -352,7 +354,9 @@ Finally, you can implement the ``Stopper`` abstract class for stopping entire ex tune.run(my_trainable, stop=stopper) -Note that in the above example the currently running trials will not stop immediately but will do so once their current iterations are complete. See the :ref:`tune-stop-ref` documentation. +Note that in the above example the currently running trials will not stop immediately but will do so once their current iterations are complete. 
+ +Ray Tune comes with a set of out-of-the-box stopper classes. See the :ref:`Stopper ` documentation. .. _tune-logging: diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index bf50951a3..4ad5c43c7 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -167,6 +167,13 @@ class Experiment: stopping_criteria = {} if not stop: pass + elif isinstance(stop, list): + if any(not isinstance(s, Stopper) for s in stop): + raise ValueError( + "If you pass a list as the `stop` argument to " + "`tune.run()`, each element must be an instance of " + "`tune.stopper.Stopper`.") + self._stopper = CombinedStopper(*stop) elif isinstance(stop, dict): stopping_criteria = stop elif callable(stop): diff --git a/python/ray/tune/stopper.py b/python/ray/tune/stopper.py index 9dc20773d..bc0940bd7 100644 --- a/python/ray/tune/stopper.py +++ b/python/ray/tune/stopper.py @@ -1,4 +1,7 @@ +import warnings +from typing import Dict, Optional import time +from collections import defaultdict, deque import numpy as np @@ -43,6 +46,27 @@ class Stopper: class CombinedStopper(Stopper): + """Combine several stoppers via 'OR'. + + Args: + *stoppers (Stopper): Stoppers to be combined. + + Example: + + .. code-block:: python + + from ray.tune.stopper import CombinedStopper, \ + MaximumIterationStopper, TrialPlateauStopper + + stopper = CombinedStopper( + MaximumIterationStopper(max_iter=20), + TrialPlateauStopper(metric="my_metric") + ) + + tune.run(train, stop=stopper) + + """ + def __init__(self, *stoppers: Stopper): self._stoppers = stoppers @@ -62,6 +86,18 @@ class NoopStopper(Stopper): class FunctionStopper(Stopper): + """Provide a custom function to check if trial should be stopped. + + The passed function will be called after each iteration. If it returns + True, the trial will be stopped. + + Args: + function (Callable[[str, Dict], bool): Function that checks if a trial + should be stopped. 
class MaximumIterationStopper(Stopper):
    """Stop trials after reaching a maximum number of iterations.

    The stopper keeps its own per-trial counter, incremented once for every
    result passed to ``__call__``. A trial is stopped as soon as its counter
    reaches ``max_iter``.

    Args:
        max_iter (int): Number of iterations before stopping a trial.
    """

    def __init__(self, max_iter: int):
        self._max_iter = max_iter
        # ``int`` is used as the default factory instead of a lambda: it
        # yields the same initial value (0) but keeps the object picklable
        # with the standard ``pickle`` module.
        self._iter = defaultdict(int)

    def __call__(self, trial_id: str, result: Dict) -> bool:
        """Count one more result for ``trial_id``.

        Args:
            trial_id (str): Identifier of the trial that reported.
            result (Dict): The reported result (not inspected here).

        Returns:
            bool: True once ``max_iter`` results have been seen for
            this trial.
        """
        self._iter[trial_id] += 1
        return self._iter[trial_id] >= self._max_iter

    def stop_all(self) -> bool:
        """Never stop the whole experiment; this stopper is per-trial."""
        return False
class EarlyStopping(ExperimentPlateauStopper):
    """Deprecated alias of ``ExperimentPlateauStopper``.

    Accepts the same arguments and forwards them unchanged to the parent
    class. Instantiating it emits a ``DeprecationWarning``.
    """

    def __init__(self, *args, **kwargs):
        # The original message was missing the space before "instead.",
        # which rendered as "ExperimentPlateauStopperinstead.".
        # ``stacklevel=2`` attributes the warning to the caller's line.
        warnings.warn(
            "The `EarlyStopping` stopper has been renamed to "
            "`ExperimentPlateauStopper`. The reference will be removed "
            "in a future version of Ray. Please use ExperimentPlateauStopper "
            "instead.",
            DeprecationWarning,
            stacklevel=2)
        super().__init__(*args, **kwargs)


class TrialPlateauStopper(Stopper):
    """Early stop single trials when they reached a plateau.

    When the standard deviation of the `metric` result of a trial is
    below a threshold `std`, the trial plateaued and will be stopped
    early.

    Args:
        metric (str): Metric to check for convergence.
        std (float): Maximum metric standard deviation to decide if a
            trial plateaued. Defaults to 0.01.
        num_results (int): Number of most recent results to consider for
            the stdev calculation. Defaults to 4.
        grace_period (int): Minimum number of reported results before a
            trial can be early stopped. Defaults to 4.
        metric_threshold (Optional[float]):
            Minimum or maximum value the result has to exceed before it can
            be stopped early.
        mode (Optional[str]): If a `metric_threshold` argument has been
            passed, this must be one of [min, max]. Specifies if we optimize
            for a large metric (max) or a small metric (min). If max, the
            `metric_threshold` has to be exceeded, if min the value has to
            be lower than `metric_threshold` in order to early stop.

    Raises:
        ValueError: If `metric_threshold` is given and `mode` is not one
            of "min" or "max".
    """

    def __init__(self,
                 metric: str,
                 std: float = 0.01,
                 num_results: int = 4,
                 grace_period: int = 4,
                 metric_threshold: Optional[float] = None,
                 mode: Optional[str] = None):
        self._metric = metric
        self._mode = mode

        self._std = std
        self._num_results = num_results
        self._grace_period = grace_period
        self._metric_threshold = metric_threshold

        # Compare against None explicitly: a threshold of 0 / 0.0 is a
        # legitimate value and must still trigger the `mode` validation.
        if self._metric_threshold is not None:
            if mode not in ("min", "max"):
                raise ValueError(
                    f"When specifying a `metric_threshold`, the `mode` "
                    f"argument has to be one of [min, max]. "
                    f"Got: {mode}")

        # Number of results seen per trial.
        self._iter = defaultdict(int)
        # Sliding window holding the last `num_results` metric values
        # of each trial.
        self._trial_results = defaultdict(
            lambda: deque(maxlen=self._num_results))

    def __call__(self, trial_id: str, result: Dict) -> bool:
        """Record the latest result and decide whether to stop the trial.

        Args:
            trial_id (str): Identifier of the trial that reported.
            result (Dict): The reported result; the configured metric is
                read from it with ``dict.get``.

        Returns:
            bool: True if the trial is past its grace period, has a full
            window of results, satisfies the optional metric threshold and
            the standard deviation of the window is below `std`.
        """
        metric_result = result.get(self._metric)
        window = self._trial_results[trial_id]
        window.append(metric_result)
        self._iter[trial_id] += 1

        # If still in grace period, do not stop yet
        if self._iter[trial_id] < self._grace_period:
            return False

        # If not enough results yet, do not stop yet
        if len(window) < self._num_results:
            return False

        # The metric was not reported in this result. Without a value there
        # is nothing to compare against the threshold (the comparison would
        # raise TypeError) and no meaningful stdev, so do not stop.
        if metric_result is None:
            return False

        # If metric threshold value not reached, do not stop yet
        if self._metric_threshold is not None:
            if self._mode == "min" and metric_result > self._metric_threshold:
                return False
            if self._mode == "max" and metric_result < self._metric_threshold:
                return False

        # Calculate stdev of last `num_results` results. The window may
        # still contain non-numeric entries (e.g. None from an earlier
        # result); in that case treat the trial as not plateaued.
        try:
            current_std = np.std(window)
        except Exception:
            current_std = float("inf")

        # If stdev is lower than threshold, stop early. Convert to a plain
        # bool so callers do not receive a numpy scalar.
        return bool(current_std < self._std)

    def stop_all(self) -> bool:
        """Never stop the whole experiment; this stopper is per-trial."""
        return False
+ + out = tune.run(train, stop=stopper) + self.assertEqual(out.trials[0].last_result[TRAINING_ITERATION], 9) + + # num_results = 4, min_metric = 22 --> full 13 iterations + stopper = TrialPlateauStopper( + metric="_metric", num_results=4, metric_threshold=22.0, mode="max") + + out = tune.run(train, stop=stopper) + self.assertEqual(out.trials[0].last_result[TRAINING_ITERATION], 13) + def testCustomTrialDir(self): def train(config): for i in range(10):