[tune] Added EarlyStopping and related test suite (#8459)

This commit is contained in:
Luca Cappelletti
2020-05-17 21:18:59 +02:00
committed by GitHub
parent 42c9fa19d1
commit 2ff26f13d2
3 changed files with 91 additions and 6 deletions
+6 -5
View File
@@ -2,7 +2,7 @@ from ray.tune.error import TuneError
from ray.tune.tune import run_experiments, run
from ray.tune.experiment import Experiment
from ray.tune.analysis import ExperimentAnalysis, Analysis
from ray.tune.stopper import Stopper
from ray.tune.stopper import Stopper, EarlyStopping
from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.durable_trainable import DurableTrainable
@@ -17,8 +17,9 @@ from ray.tune.sample import (function, sample_from, uniform, choice, randint,
__all__ = [
"Trainable", "DurableTrainable", "TuneError", "grid_search",
"register_env", "register_trainable", "run", "run_experiments", "Stopper",
"Experiment", "function", "sample_from", "track", "uniform", "choice",
"randint", "randn", "loguniform", "ExperimentAnalysis", "Analysis",
"CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
"get_trial_dir", "get_trial_name", "get_trial_id"
"EarlyStopping", "Experiment", "function", "sample_from", "track",
"uniform", "choice", "randint", "randn", "loguniform",
"ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter",
"ProgressReporter", "report", "get_trial_dir", "get_trial_name",
"get_trial_id"
]
+52
View File
@@ -1,3 +1,6 @@
import numpy as np
class Stopper:
"""Base class for implementing a Tune experiment stopper.
@@ -61,3 +64,52 @@ class FunctionStopper(Stopper):
"Stop object must be ray.tune.Stopper subclass to be detected "
"correctly.")
return is_function
class EarlyStopping(Stopper):
    """Stopper that fires once the best reported results have plateaued.

    Each reported result contributes its ``metric`` value. Only the ``top``
    best values are retained (the lowest ones for ``mode="min"``, the highest
    ones for ``mode="max"``). Once ``top`` values have been collected and
    their standard deviation is at most ``std``, ``stop_all`` returns True.
    """

    def __init__(self, metric, std=0.001, top=10, mode="min"):
        """Create the EarlyStopping object.

        Args:
            metric (str): The metric to be monitored.
            std (float): The minimal standard deviation after which
                the tuning process has to stop. Any strictly positive
                real number (int or float) is accepted.
            top (int): The number of best models to consider.
            mode (str): The mode to select the top results.
                Can either be "min" or "max".

        Raises:
            ValueError: If the mode parameter is not "min" nor "max".
            ValueError: If the top parameter is not an integer
                greater than 1.
            ValueError: If the standard deviation parameter is not
                a strictly positive number (NaN is rejected too).
        """
        if mode not in ("min", "max"):
            raise ValueError("The mode parameter can only be"
                             " either min or max.")
        if not isinstance(top, int) or top <= 1:
            raise ValueError("Top results to consider must be"
                             " a positive integer greater than one.")
        # Integers such as ``std=1`` are valid thresholds; bool is excluded
        # explicitly because it is a subclass of int.
        std_is_number = (isinstance(std, (int, float))
                         and not isinstance(std, bool))
        # ``not std > 0`` (instead of ``std <= 0``) also rejects NaN, which
        # would otherwise make the comparison in stop_all always False.
        if not std_is_number or not std > 0:
            raise ValueError("The standard deviation must be"
                             " a strictly positive float number.")
        self._mode = mode
        self._metric = metric
        self._std = std
        self._top = top
        # Best metric values seen so far, kept sorted in ascending order
        # and truncated to at most ``top`` entries.
        self._top_values = []

    def __call__(self, trial_id, result):
        """Return a boolean representing if the tuning has to stop.

        Args:
            trial_id: Identifier of the reporting trial (not used here).
            result (dict): The reported result; must contain ``metric``.

        Raises:
            KeyError: If the monitored metric is missing from ``result``.
        """
        self._top_values.append(result[self._metric])
        ranked = sorted(self._top_values)
        if self._mode == "min":
            # Lowest values are the best ones.
            self._top_values = ranked[:self._top]
        else:
            # Highest values are the best ones.
            self._top_values = ranked[-self._top:]
        return self.stop_all()

    def stop_all(self):
        """Return whether to stop and prevent trials from starting."""
        return (len(self._top_values) == self._top
                and np.std(self._top_values) <= self._std)
+33 -1
View File
@@ -10,7 +10,8 @@ import ray
from ray.rllib import _register_all
from ray import tune
from ray.tune import DurableTrainable, Trainable, TuneError, Stopper
from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
EarlyStopping)
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
from ray.tune.trial import Trial
@@ -487,6 +488,37 @@ class TrainableFunctionApiTest(unittest.TestCase):
t.last_result.get("training_iteration") is None
for t in trials))
def testEarlyStopping(self):
    """Check EarlyStopping argument validation and stopping in both modes."""

    def train(config, reporter):
        # Constant metric: the collected values always have zero spread.
        reporter(test=0)

    top = 3

    # Invalid constructor arguments must raise ValueError.
    with self.assertRaises(ValueError):
        EarlyStopping("test", top=0)
    with self.assertRaises(ValueError):
        EarlyStopping("test", top="0")
    with self.assertRaises(ValueError):
        EarlyStopping("test", std=0)
    with self.assertRaises(ValueError):
        EarlyStopping("test", std="0")
    with self.assertRaises(ValueError):
        EarlyStopping("test", mode="0")

    # Run the experiment once per selection mode so that both the "min"
    # and the "max" branch of the stopper are exercised.
    for mode in ("min", "max"):
        with self.subTest(mode=mode):
            stopper = EarlyStopping("test", top=top, mode=mode)

            analysis = tune.run(train, num_samples=10, stop=stopper)
            self.assertTrue(
                all(t.status == Trial.TERMINATED for t in analysis.trials))
            self.assertTrue(len(analysis.dataframe()) <= top)
def testBadStoppingFunction(self):
def train(config, reporter):
for i in range(10):