mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:01:24 +08:00
[tune] add more stoppers and stopper documentation (#12750)
* Add new stoppers & docs * Add tests for maximum iteration stopper and trial plateau stopper * Update python/ray/tune/stopper.py Co-authored-by: Richard Liaw <rliaw@berkeley.edu> * Update doc/source/tune/api_docs/stoppers.rst Co-authored-by: Richard Liaw <rliaw@berkeley.edu> * Update doc/source/tune/api_docs/stoppers.rst Co-authored-by: Richard Liaw <rliaw@berkeley.edu> * Apply suggestions from code review * Apply suggestions from code review * Update python/ray/tune/stopper.py Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -167,6 +167,13 @@ class Experiment:
|
||||
stopping_criteria = {}
|
||||
if not stop:
|
||||
pass
|
||||
elif isinstance(stop, list):
|
||||
if any(not isinstance(s, Stopper) for s in stop):
|
||||
raise ValueError(
|
||||
"If you pass a list as the `stop` argument to "
|
||||
"`tune.run()`, each element must be an instance of "
|
||||
"`tune.stopper.Stopper`.")
|
||||
self._stopper = CombinedStopper(*stop)
|
||||
elif isinstance(stop, dict):
|
||||
stopping_criteria = stop
|
||||
elif callable(stop):
|
||||
|
||||
+180
-26
@@ -1,4 +1,7 @@
|
||||
import warnings
|
||||
from typing import Dict, Optional
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -43,6 +46,27 @@ class Stopper:
|
||||
|
||||
|
||||
class CombinedStopper(Stopper):
|
||||
"""Combine several stoppers via 'OR'.
|
||||
|
||||
Args:
|
||||
*stoppers (Stopper): Stoppers to be combined.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.tune.stopper import CombinedStopper, \
|
||||
MaximumIterationStopper, TrialPlateauStopper
|
||||
|
||||
stopper = CombinedStopper(
|
||||
MaximumIterationStopper(max_iter=20),
|
||||
TrialPlateauStopper(metric="my_metric")
|
||||
)
|
||||
|
||||
tune.run(train, stop=stopper)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *stoppers: Stopper):
|
||||
self._stoppers = stoppers
|
||||
|
||||
@@ -62,6 +86,18 @@ class NoopStopper(Stopper):
|
||||
|
||||
|
||||
class FunctionStopper(Stopper):
|
||||
"""Provide a custom function to check if trial should be stopped.
|
||||
|
||||
The passed function will be called after each iteration. If it returns
|
||||
True, the trial will be stopped.
|
||||
|
||||
Args:
|
||||
function (Callable[[str, Dict], bool]): Function that checks if a trial
|
||||
should be stopped. Must accept the `trial_id` string and `result`
|
||||
dictionary as arguments. Must return a boolean.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, function):
|
||||
self._fn = function
|
||||
|
||||
@@ -81,33 +117,53 @@ class FunctionStopper(Stopper):
|
||||
return is_function
|
||||
|
||||
|
||||
class EarlyStopping(Stopper):
|
||||
class MaximumIterationStopper(Stopper):
    """Stop trials after reaching a maximum number of iterations.

    The stopper keeps its own per-trial counter, which is incremented once
    for every result passed to ``__call__``. A trial is stopped as soon as
    its counter reaches ``max_iter``.

    Args:
        max_iter (int): Number of iterations before stopping a trial.
    """

    def __init__(self, max_iter: int):
        self._max_iter = max_iter
        # ``int`` as default factory yields 0 for unseen trials, exactly like
        # ``lambda: 0``, but (unlike a lambda) can be pickled by the standard
        # ``pickle`` module.
        self._iter = defaultdict(int)

    def __call__(self, trial_id: str, result: Dict) -> bool:
        """Count one more result for ``trial_id`` and check the limit.

        Args:
            trial_id (str): Identifier of the trial that reported.
            result (Dict): The reported result. It is not inspected here.

        Returns:
            True if the trial has reported at least ``max_iter`` results.
        """
        self._iter[trial_id] += 1
        return self._iter[trial_id] >= self._max_iter

    def stop_all(self) -> bool:
        """Return False: this stopper never ends the whole experiment."""
        return False
|
||||
|
||||
|
||||
class ExperimentPlateauStopper(Stopper):
|
||||
"""Early stop the experiment when a metric plateaued across trials.
|
||||
|
||||
Stops the entire experiment when the metric has plateaued
|
||||
for more than the given amount of iterations specified in
|
||||
the patience parameter.
|
||||
|
||||
Args:
|
||||
metric (str): The metric to be monitored.
|
||||
std (float): The minimal standard deviation after which
|
||||
the tuning process has to stop.
|
||||
top (int): The number of best models to consider.
|
||||
mode (str): The mode to select the top results.
|
||||
Can either be "min" or "max".
|
||||
patience (int): Number of epochs to wait for
|
||||
a change in the top models.
|
||||
|
||||
Raises:
|
||||
ValueError: If the mode parameter is not "min" nor "max".
|
||||
ValueError: If the top parameter is not an integer
|
||||
greater than 1.
|
||||
ValueError: If the standard deviation parameter is not
|
||||
a strictly positive float.
|
||||
ValueError: If the patience parameter is not
|
||||
a strictly positive integer.
|
||||
"""
|
||||
|
||||
def __init__(self, metric, std=0.001, top=10, mode="min", patience=0):
|
||||
"""Create the ExperimentPlateauStopper object.
|
||||
|
||||
Stops the entire experiment when the metric has plateaued
|
||||
for more than the given amount of iterations specified in
|
||||
the patience parameter.
|
||||
|
||||
Args:
|
||||
metric (str): The metric to be monitored.
|
||||
std (float): The minimal standard deviation after which
|
||||
the tuning process has to stop.
|
||||
top (int): The number of best models to consider.
|
||||
mode (str): The mode to select the top results.
|
||||
Can either be "min" or "max".
|
||||
patience (int): Number of epochs to wait for
|
||||
a change in the top models.
|
||||
|
||||
Raises:
|
||||
ValueError: If the mode parameter is not "min" nor "max".
|
||||
ValueError: If the top parameter is not an integer
|
||||
greater than 1.
|
||||
ValueError: If the standard deviation parameter is not
|
||||
a strictly positive float.
|
||||
ValueError: If the patience parameter is not
|
||||
a strictly positive integer.
|
||||
"""
|
||||
if mode not in ("min", "max"):
|
||||
raise ValueError("The mode parameter can only be"
|
||||
" either min or max.")
|
||||
@@ -157,9 +213,107 @@ class EarlyStopping(Stopper):
|
||||
return self.has_plateaued() and self._iterations >= self._patience
|
||||
|
||||
|
||||
class EarlyStopping(ExperimentPlateauStopper):
    """Deprecated alias of ``ExperimentPlateauStopper``.

    Instantiating this class emits a ``DeprecationWarning`` and then
    forwards all positional and keyword arguments unchanged to
    ``ExperimentPlateauStopper.__init__``.
    """

    def __init__(self, *args, **kwargs):
        # Adjacent string literals are concatenated verbatim, so every
        # fragment must carry its own separating space. The original message
        # ended in "ExperimentPlateauStopperinstead.".
        warnings.warn(
            "The `EarlyStopping` stopper has been renamed to "
            "`ExperimentPlateauStopper`. The reference will be removed "
            "in a future version of Ray. Please use ExperimentPlateauStopper "
            "instead.",
            DeprecationWarning,
            # Attribute the warning to the code that instantiates the
            # deprecated class rather than to this constructor.
            stacklevel=2)
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class TrialPlateauStopper(Stopper):
    """Early stop single trials when they reached a plateau.

    When the standard deviation of the last `num_results` values of the
    `metric` result of a trial is below a threshold `std`, the trial
    plateaued and will be stopped early.

    Results that do not contain `metric` are counted towards the grace
    period but are not recorded and never trigger a stop.

    Args:
        metric (str): Metric to check for convergence.
        std (float): Maximum metric standard deviation to decide if a
            trial plateaued. Defaults to 0.01.
        num_results (int): Number of results to consider for stdev
            calculation. Defaults to 4.
        grace_period (int): Minimum number of reported results
            (iterations) before a trial can be early stopped.
            Defaults to 4.
        metric_threshold (Optional[float]):
            Minimum or maximum value the result has to reach before the
            trial can be stopped early. Defaults to None (no threshold).
        mode (Optional[str]): If a `metric_threshold` argument has been
            passed, this must be one of [min, max]. Specifies if we optimize
            for a large metric (max) or a small metric (min). If max, the
            latest result must not be below `metric_threshold`, if min it
            must not be above `metric_threshold` in order to early stop.

    Raises:
        ValueError: If `metric_threshold` is given and `mode` is not one
            of "min" or "max".
    """

    def __init__(self,
                 metric: str,
                 std: float = 0.01,
                 num_results: int = 4,
                 grace_period: int = 4,
                 metric_threshold: Optional[float] = None,
                 mode: Optional[str] = None):
        self._metric = metric
        self._mode = mode

        self._std = std
        self._num_results = num_results
        self._grace_period = grace_period
        self._metric_threshold = metric_threshold

        # Compare against None instead of relying on truthiness: a threshold
        # of 0 / 0.0 is a legitimate value and still requires a valid mode,
        # otherwise the threshold would be silently ignored in __call__.
        if self._metric_threshold is not None and mode not in ("min", "max"):
            raise ValueError(
                "When specifying a `metric_threshold`, the `mode` "
                "argument has to be one of [min, max]. "
                f"Got: {mode}")

        # Number of results seen so far, per trial.
        self._iter = defaultdict(int)
        # Sliding window holding the last `num_results` metric values,
        # per trial.
        self._trial_results = defaultdict(
            lambda: deque(maxlen=self._num_results))

    def __call__(self, trial_id: str, result: Dict) -> bool:
        """Record the latest result and decide whether the trial plateaued.

        Args:
            trial_id (str): Identifier of the trial that reported.
            result (Dict): The reported result; the value stored under the
                configured metric key is read from it.

        Returns:
            True if the trial should be stopped, False otherwise.
        """
        self._iter[trial_id] += 1

        metric_result = result.get(self._metric)
        # Without a metric value there is nothing to compare. Do not record
        # a None placeholder (it would break both the threshold comparison
        # and the stdev calculation) and do not stop the trial.
        if metric_result is None:
            return False

        window = self._trial_results[trial_id]
        window.append(metric_result)

        # If still in grace period, do not stop yet
        if self._iter[trial_id] < self._grace_period:
            return False

        # If not enough results yet, do not stop yet
        if len(window) < self._num_results:
            return False

        # If metric threshold value not reached, do not stop yet
        if self._metric_threshold is not None:
            if self._mode == "min" and metric_result > self._metric_threshold:
                return False
            if self._mode == "max" and metric_result < self._metric_threshold:
                return False

        # Calculate stdev of last `num_results` results. Values that cannot
        # be processed numerically are treated as "not converged" rather
        # than aborting the experiment.
        try:
            current_std = np.std(window)
        except (TypeError, ValueError):
            current_std = float("inf")

        # If stdev is lower than threshold, stop early.
        return bool(current_std < self._std)

    def stop_all(self) -> bool:
        """Return False: this stopper never ends the whole experiment."""
        return False
|
||||
|
||||
|
||||
class TimeoutStopper(Stopper):
|
||||
"""Stops all trials after a certain timeout.
|
||||
|
||||
This stopper is automatically created when the `time_budget_s`
|
||||
argument is passed to `tune.run()`.
|
||||
|
||||
Args:
|
||||
timeout (int|float|datetime.timedelta): Either a number specifying
|
||||
the timeout in seconds, or a `datetime.timedelta` object.
|
||||
|
||||
@@ -17,6 +17,7 @@ from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
|
||||
AsyncHyperBandScheduler)
|
||||
from ray.tune.stopper import MaximumIterationStopper, TrialPlateauStopper
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
|
||||
EPISODES_TOTAL, TRAINING_ITERATION,
|
||||
@@ -556,6 +557,44 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
with self.assertRaises(TuneError):
|
||||
tune.run(train, stop=stop)
|
||||
|
||||
def testMaximumIterationStopper(self):
    """A trial that reports 10 results is cut off after `max_iter` of them."""

    def train(config):
        for step in range(10):
            tune.report(it=step)

    analysis = tune.run(train, stop=MaximumIterationStopper(max_iter=6))

    final_result = analysis.trials[0].last_result
    self.assertEqual(final_result[TRAINING_ITERATION], 6)
|
||||
|
||||
def testTrialPlateauStopper(self):
    """Check when `TrialPlateauStopper` cuts off a trial that plateaus.

    The trainable reports three changing values followed by ten identical
    ones, i.e. 13 results in total.
    """

    def train(config):
        for rising_value in (10.0, 11.0, 12.0):
            tune.report(rising_value)
        for _ in range(10):
            tune.report(20.0)

    def iterations_run(stopper):
        # Run the trainable with the given stopper and return the training
        # iteration of the last result of the first trial.
        analysis = tune.run(train, stop=stopper)
        return analysis.trials[0].last_result[TRAINING_ITERATION]

    # num_results = 4, no other constraints --> early stop after 7
    unconstrained = TrialPlateauStopper(metric="_metric", num_results=4)
    self.assertEqual(iterations_run(unconstrained), 7)

    # num_results = 4, grace period 9 --> early stop after 9
    with_grace_period = TrialPlateauStopper(
        metric="_metric", num_results=4, grace_period=9)
    self.assertEqual(iterations_run(with_grace_period), 9)

    # num_results = 4, metric_threshold = 22 (mode max), which the metric
    # never reaches --> full 13 iterations
    with_threshold = TrialPlateauStopper(
        metric="_metric", num_results=4, metric_threshold=22.0, mode="max")
    self.assertEqual(iterations_run(with_threshold), 13)
|
||||
|
||||
def testCustomTrialDir(self):
|
||||
def train(config):
|
||||
for i in range(10):
|
||||
|
||||
Reference in New Issue
Block a user