[tune] add more stoppers and stopper documentation (#12750)

* Add new stoppers & docs

* Add tests for maximum iteration stopper and trial plateau stopper

* Update python/ray/tune/stopper.py

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

* Update doc/source/tune/api_docs/stoppers.rst

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

* Update doc/source/tune/api_docs/stoppers.rst

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

* Apply suggestions from code review

* Apply suggestions from code review

* Update python/ray/tune/stopper.py

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Kai Fricke
2020-12-12 10:47:19 +01:00
committed by GitHub
parent 905652cdd6
commit 5f04ade6ef
7 changed files with 280 additions and 37 deletions
-8
View File
@@ -23,14 +23,6 @@ tune.with_parameters
.. autofunction:: ray.tune.with_parameters
.. _tune-stop-ref:
Stopper (tune.Stopper)
----------------------
.. autoclass:: ray.tune.Stopper
:members: __call__, stop_all
.. _tune-sync-config:
tune.SyncConfig
+1
View File
@@ -21,6 +21,7 @@ on `Github`_.
suggestion.rst
schedulers.rst
sklearn.rst
stoppers.rst
logging.rst
integration.rst
internals.rst
+46
View File
@@ -0,0 +1,46 @@
.. _tune-stoppers:
Stopping mechanisms (tune.stopper)
==================================
In addition to Trial Schedulers like :ref:`ASHA <tune-scheduler-hyperband>`, which stop a number of
trials if they perform subpar, Ray Tune also supports custom stopping mechanisms to stop trials early. For instance, a stopping mechanism can stop trials once they have reached a plateau and the metric
no longer changes.
Ray Tune comes with several stopping mechanisms out of the box. For custom stopping behavior, you can
inherit from the :class:`Stopper <ray.tune.Stopper>` class.
Other stopping behaviors are described :ref:`in the user guide <tune-stopping>`.
.. contents::
:local:
:depth: 1
.. _tune-stop-ref:
Stopper (tune.Stopper)
----------------------
.. autoclass:: ray.tune.Stopper
:members: __call__, stop_all
MaximumIterationStopper (tune.stopper.MaximumIterationStopper)
--------------------------------------------------------------
.. autoclass:: ray.tune.stopper.MaximumIterationStopper
ExperimentPlateauStopper (tune.stopper.ExperimentPlateauStopper)
----------------------------------------------------------------
.. autoclass:: ray.tune.stopper.ExperimentPlateauStopper
TrialPlateauStopper (tune.stopper.TrialPlateauStopper)
------------------------------------------------------
.. autoclass:: ray.tune.stopper.TrialPlateauStopper
TimeoutStopper (tune.stopper.TimeoutStopper)
--------------------------------------------
.. autoclass:: ray.tune.stopper.TimeoutStopper
+7 -3
View File
@@ -305,7 +305,9 @@ and passed to your trainable as a parameter.
Stopping Trials
---------------
You can control when trials are stopped early by passing the ``stop`` argument to ``tune.run``. This argument takes either a dictionary or a function.
You can control when trials are stopped early by passing the ``stop`` argument to ``tune.run``.
This argument takes a dictionary, a function, or an instance of a
:class:`Stopper <ray.tune.stopper.Stopper>` class.
If a dictionary is passed in, the keys may be any field in the return result of ``tune.report`` in the Function API or ``step()`` (including the results from ``step`` and auto-filled metrics).
@@ -329,7 +331,7 @@ For more flexibility, you can pass in a function instead. If a function is passe
tune.run(my_trainable, stop=stopper)
Finally, you can implement the ``Stopper`` abstract class for stopping entire experiments. For example, the following example stops all trials after the criteria is fulfilled by any individual trial, and prevents new ones from starting:
Finally, you can implement the :class:`Stopper <ray.tune.stopper.Stopper>` abstract class for stopping entire experiments. For example, the following example stops all trials once the criteria are fulfilled by any individual trial, and prevents new ones from starting:
.. code-block:: python
@@ -352,7 +354,9 @@ Finally, you can implement the ``Stopper`` abstract class for stopping entire ex
tune.run(my_trainable, stop=stopper)
Note that in the above example the currently running trials will not stop immediately but will do so once their current iterations are complete. See the :ref:`tune-stop-ref` documentation.
Note that in the above example the currently running trials will not stop immediately but will do so once their current iterations are complete.
Ray Tune comes with a set of out-of-the-box stopper classes. See the :ref:`Stopper <tune-stoppers>` documentation.
.. _tune-logging:
+7
View File
@@ -167,6 +167,13 @@ class Experiment:
stopping_criteria = {}
if not stop:
pass
elif isinstance(stop, list):
if any(not isinstance(s, Stopper) for s in stop):
raise ValueError(
"If you pass a list as the `stop` argument to "
"`tune.run()`, each element must be an instance of "
"`tune.stopper.Stopper`.")
self._stopper = CombinedStopper(*stop)
elif isinstance(stop, dict):
stopping_criteria = stop
elif callable(stop):
+180 -26
View File
@@ -1,4 +1,7 @@
import warnings
from typing import Dict, Optional
import time
from collections import defaultdict, deque
import numpy as np
@@ -43,6 +46,27 @@ class Stopper:
class CombinedStopper(Stopper):
"""Combine several stoppers via 'OR'.
Args:
*stoppers (Stopper): Stoppers to be combined.
Example:
.. code-block:: python
from ray.tune.stopper import CombinedStopper, \
MaximumIterationStopper, TrialPlateauStopper
stopper = CombinedStopper(
MaximumIterationStopper(max_iter=20),
TrialPlateauStopper(metric="my_metric")
)
tune.run(train, stop=stopper)
"""
def __init__(self, *stoppers: Stopper):
self._stoppers = stoppers
@@ -62,6 +86,18 @@ class NoopStopper(Stopper):
class FunctionStopper(Stopper):
"""Provide a custom function to check if trial should be stopped.
The passed function will be called after each iteration. If it returns
True, the trial will be stopped.
Args:
function (Callable[[str, Dict], bool]): Function that checks if a trial
should be stopped. Must accept the `trial_id` string and `result`
dictionary as arguments. Must return a boolean.
"""
def __init__(self, function):
self._fn = function
@@ -81,33 +117,53 @@ class FunctionStopper(Stopper):
return is_function
class EarlyStopping(Stopper):
class MaximumIterationStopper(Stopper):
    """Stop each trial once it has reported a maximum number of results.

    The stopper counts how often it has been called for every trial id and
    asks for the trial to be stopped as soon as that count reaches
    `max_iter`.

    Args:
        max_iter (int): Number of iterations before stopping a trial.
    """

    def __init__(self, max_iter: int):
        self._max_iter = max_iter
        # Per-trial call counter; trial ids not seen before start at 0.
        self._iter = defaultdict(int)

    def __call__(self, trial_id: str, result: Dict):
        """Count one more result for `trial_id` and check the limit.

        Args:
            trial_id (str): Id of the trial the result belongs to.
            result (Dict): The reported result (not inspected here).

        Returns:
            True if the trial has now been seen at least `max_iter` times.
        """
        seen = self._iter[trial_id] + 1
        self._iter[trial_id] = seen
        return seen >= self._max_iter

    def stop_all(self):
        """Never request stopping the whole experiment."""
        return False
class ExperimentPlateauStopper(Stopper):
"""Early stop the experiment when a metric plateaued across trials.
Stops the entire experiment when the metric has plateaued
for more than the given amount of iterations specified in
the patience parameter.
Args:
metric (str): The metric to be monitored.
std (float): The minimal standard deviation after which
the tuning process has to stop.
top (int): The number of best models to consider.
mode (str): The mode to select the top results.
Can either be "min" or "max".
patience (int): Number of epochs to wait for
a change in the top models.
Raises:
ValueError: If the mode parameter is not "min" nor "max".
ValueError: If the top parameter is not an integer
greater than 1.
ValueError: If the standard deviation parameter is not
a strictly positive float.
ValueError: If the patience parameter is not
a strictly positive integer.
"""
def __init__(self, metric, std=0.001, top=10, mode="min", patience=0):
"""Create the EarlyStopping object.
Stops the entire experiment when the metric has plateaued
for more than the given amount of iterations specified in
the patience parameter.
Args:
metric (str): The metric to be monitored.
std (float): The minimal standard deviation after which
the tuning process has to stop.
top (int): The number of best model to consider.
mode (str): The mode to select the top results.
Can either be "min" or "max".
patience (int): Number of epochs to wait for
a change in the top models.
Raises:
ValueError: If the mode parameter is not "min" nor "max".
ValueError: If the top parameter is not an integer
greater than 1.
ValueError: If the standard deviation parameter is not
a strictly positive float.
ValueError: If the patience parameter is not
a strictly positive integer.
"""
if mode not in ("min", "max"):
raise ValueError("The mode parameter can only be"
" either min or max.")
@@ -157,9 +213,107 @@ class EarlyStopping(Stopper):
return self.has_plateaued() and self._iterations >= self._patience
class EarlyStopping(ExperimentPlateauStopper):
    """Deprecated alias of `ExperimentPlateauStopper`.

    Emits a `DeprecationWarning` on construction and otherwise forwards all
    positional and keyword arguments unchanged to
    `ExperimentPlateauStopper`.
    """

    def __init__(self, *args, **kwargs):
        # Adjacent string literals are concatenated verbatim, so each
        # fragment needs its own trailing space. Without it the message
        # ended in "ExperimentPlateauStopperinstead.".
        warnings.warn(
            "The `EarlyStopping` stopper has been renamed to "
            "`ExperimentPlateauStopper`. The reference will be removed "
            "in a future version of Ray. Please use ExperimentPlateauStopper "
            "instead.", DeprecationWarning)
        super(EarlyStopping, self).__init__(*args, **kwargs)
class TrialPlateauStopper(Stopper):
    """Early stop single trials when they reached a plateau.

    When the standard deviation of the `metric` result of a trial is
    below a threshold `std`, the trial plateaued and will be stopped
    early.

    Args:
        metric (str): Metric to check for convergence.
        std (float): Maximum metric standard deviation to decide if a
            trial plateaued. Defaults to 0.01.
        num_results (int): Number of results to consider for stdev
            calculation.
        grace_period (int): Minimum number of results a trial must have
            reported before it can be early stopped.
        metric_threshold (Optional[float]):
            Minimum or maximum value the result has to exceed before it can
            be stopped early. Any value other than None (including 0)
            enables the check.
        mode (Optional[str]): If a `metric_threshold` argument has been
            passed, this must be one of [min, max]. Specifies if we optimize
            for a large metric (max) or a small metric (min). If max, the
            `metric_threshold` has to be exceeded, if min the value has to
            be lower than `metric_threshold` in order to early stop.

    Raises:
        ValueError: If a `metric_threshold` is passed and `mode` is not
            one of "min" or "max".
    """

    def __init__(self,
                 metric: str,
                 std: float = 0.01,
                 num_results: int = 4,
                 grace_period: int = 4,
                 metric_threshold: Optional[float] = None,
                 mode: Optional[str] = None):
        self._metric = metric
        self._mode = mode

        self._std = std
        self._num_results = num_results
        self._grace_period = grace_period
        self._metric_threshold = metric_threshold

        # Compare against None explicitly. A truthiness check would treat a
        # threshold of 0 / 0.0 as "not set", skip the mode validation and
        # leave `__call__` with a threshold but no usable mode, so the
        # threshold would be silently ignored.
        if self._metric_threshold is not None:
            if mode not in ["min", "max"]:
                raise ValueError(
                    f"When specifying a `metric_threshold`, the `mode` "
                    f"argument has to be one of [min, max]. "
                    f"Got: {mode}")

        # Number of results seen per trial.
        self._iter = defaultdict(lambda: 0)
        # Sliding window holding the last `num_results` metric values of
        # each trial; older values are dropped automatically by the deque.
        self._trial_results = defaultdict(
            lambda: deque(maxlen=self._num_results))

    def __call__(self, trial_id: str, result: Dict):
        """Record the latest result of a trial and decide whether to stop it.

        Args:
            trial_id (str): Id of the trial the result belongs to.
            result (Dict): Reported result; the value stored under the
                configured metric key is added to the trial's window.

        Returns:
            True if the grace period has passed, the window is full, the
            optional metric threshold is satisfied and the standard
            deviation of the window is below `std`. False otherwise.
        """
        metric_result = result.get(self._metric)
        self._trial_results[trial_id].append(metric_result)
        self._iter[trial_id] += 1

        # If still in grace period, do not stop yet
        if self._iter[trial_id] < self._grace_period:
            return False

        # If not enough results yet, do not stop yet
        if len(self._trial_results[trial_id]) < self._num_results:
            return False

        # If metric threshold value not reached, do not stop yet
        if self._metric_threshold is not None:
            if self._mode == "min" and metric_result > self._metric_threshold:
                return False
            elif self._mode == "max" and \
                    metric_result < self._metric_threshold:
                return False

        # Calculate stdev of last `num_results` results. If the window
        # cannot be reduced to a number (e.g. it holds a None because the
        # metric was missing from `result`), treat the trial as not
        # plateaued instead of failing.
        try:
            current_std = np.std(self._trial_results[trial_id])
        except Exception:
            current_std = float("inf")

        # If stdev is lower than threshold, stop early.
        return current_std < self._std

    def stop_all(self):
        """Never request stopping the whole experiment."""
        return False
class TimeoutStopper(Stopper):
"""Stops all trials after a certain timeout.
This stopper is automatically created when the `time_budget_s`
argument is passed to `tune.run()`.
Args:
timeout (int|float|datetime.timedelta): Either a number specifying
the timeout in seconds, or a `datetime.timedelta` object.
+39
View File
@@ -17,6 +17,7 @@ from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
AsyncHyperBandScheduler)
from ray.tune.stopper import MaximumIterationStopper, TrialPlateauStopper
from ray.tune.trial import Trial
from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
EPISODES_TOTAL, TRAINING_ITERATION,
@@ -556,6 +557,44 @@ class TrainableFunctionApiTest(unittest.TestCase):
with self.assertRaises(TuneError):
tune.run(train, stop=stop)
def testMaximumIterationStopper(self):
    """A trial reporting 10 results is cut off after `max_iter=6`."""

    def train(config):
        for step in range(10):
            tune.report(it=step)

    analysis = tune.run(train, stop=MaximumIterationStopper(max_iter=6))
    final_result = analysis.trials[0].last_result
    self.assertEqual(final_result[TRAINING_ITERATION], 6)
def testTrialPlateauStopper(self):
    """Check the plateau, grace period and threshold conditions in turn."""

    def train(config):
        # Three changing values first, then a constant plateau of ten.
        for warmup_value in (10.0, 11.0, 12.0):
            tune.report(warmup_value)
        for _ in range(10):
            tune.report(20.0)

    def last_iteration(stopper):
        analysis = tune.run(train, stop=stopper)
        return analysis.trials[0].last_result[TRAINING_ITERATION]

    # num_results = 4, no other constraints --> early stop after 7
    plain = TrialPlateauStopper(metric="_metric", num_results=4)
    self.assertEqual(last_iteration(plain), 7)

    # num_results = 4, grace period 9 --> early stop after 9
    with_grace = TrialPlateauStopper(
        metric="_metric", num_results=4, grace_period=9)
    self.assertEqual(last_iteration(with_grace), 9)

    # num_results = 4, metric_threshold = 22 (mode max) is never reached
    # --> full 13 iterations
    with_threshold = TrialPlateauStopper(
        metric="_metric", num_results=4, metric_threshold=22.0, mode="max")
    self.assertEqual(last_iteration(with_threshold), 13)
def testCustomTrialDir(self):
def train(config):
for i in range(10):