mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 10:45:02 +08:00
[tune] Added EarlyStopping and related test suite (#8459)
This commit is contained in:
@@ -2,7 +2,7 @@ from ray.tune.error import TuneError
|
||||
from ray.tune.tune import run_experiments, run
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.analysis import ExperimentAnalysis, Analysis
|
||||
from ray.tune.stopper import Stopper
|
||||
from ray.tune.stopper import Stopper, EarlyStopping
|
||||
from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.durable_trainable import DurableTrainable
|
||||
@@ -17,8 +17,9 @@ from ray.tune.sample import (function, sample_from, uniform, choice, randint,
|
||||
# Public API of the package. Every name is listed exactly once; the new
# "EarlyStopping" stopper is exported right after its "Stopper" base class.
__all__ = [
    "Trainable", "DurableTrainable", "TuneError", "grid_search",
    "register_env", "register_trainable", "run", "run_experiments", "Stopper",
    "EarlyStopping", "Experiment", "function", "sample_from", "track",
    "uniform", "choice", "randint", "randn", "loguniform",
    "ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter",
    "ProgressReporter", "report", "get_trial_dir", "get_trial_name",
    "get_trial_id"
]
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Stopper:
|
||||
"""Base class for implementing a Tune experiment stopper.
|
||||
|
||||
@@ -61,3 +64,52 @@ class FunctionStopper(Stopper):
|
||||
"Stop object must be ray.tune.Stopper subclass to be detected "
|
||||
"correctly.")
|
||||
return is_function
|
||||
|
||||
|
||||
class EarlyStopping(Stopper):
    """Stopper that ends tuning once the best results stop varying.

    Every reported result contributes its ``metric`` value to a shared
    pool, regardless of which trial produced it. Only the ``top`` best
    values (smallest for ``mode="min"``, largest for ``mode="max"``) are
    kept. Tuning is stopped as soon as exactly ``top`` values are held and
    their standard deviation is at most ``std``.
    """

    def __init__(self, metric, std=0.001, top=10, mode="min"):
        """Create the EarlyStopping object.

        Args:
            metric (str): The metric to be monitored.
            std (float): The minimal standard deviation after which
                the tuning process has to stop. Any strictly positive
                real number (int or float) is accepted.
            top (int): The number of best model to consider.
            mode (str): The mode to select the top results.
                Can either be "min" or "max".

        Raises:
            ValueError: If the mode parameter is not "min" nor "max".
            ValueError: If the top parameter is not an integer
                greater than 1.
            ValueError: If the standard deviation parameter is not
                a strictly positive number.
        """
        if mode not in ("min", "max"):
            raise ValueError("The mode parameter can only be"
                             " either min or max.")
        if not isinstance(top, int) or top <= 1:
            raise ValueError("Top results to consider must be"
                             " a positive integer greater than one.")
        # An integer threshold (e.g. std=1) is just as valid as a float one.
        # bool is excluded explicitly because it is a subclass of int.
        is_number = (isinstance(std, (int, float))
                     and not isinstance(std, bool))
        # "not std > 0" (rather than "std <= 0") also rejects NaN, which
        # would otherwise make the stop condition permanently false.
        if not is_number or not std > 0:
            raise ValueError("The standard deviation must be"
                             " a strictly positive float number.")
        self._mode = mode
        self._metric = metric
        self._std = std
        self._top = top
        # Best ``top`` metric values seen so far, kept in ascending order.
        self._top_values = []

    def __call__(self, trial_id, result):
        """Return a boolean representing if the tuning has to stop.

        Args:
            trial_id: Identifier of the reporting trial. It is not used:
                values from all trials share the same pool.
            result (dict): Result mapping that must contain the monitored
                metric; a missing key raises ``KeyError``.
        """
        self._top_values.append(result[self._metric])
        if self._mode == "min":
            # Keep the smallest ``top`` values.
            self._top_values = sorted(self._top_values)[:self._top]
        else:
            # Keep the largest ``top`` values.
            self._top_values = sorted(self._top_values)[-self._top:]
        return self.stop_all()

    def stop_all(self):
        """Return whether to stop and prevent trials from starting."""
        # Never stop before ``top`` values have been collected: the spread
        # of fewer samples is not considered meaningful.
        return (len(self._top_values) == self._top
                and np.std(self._top_values) <= self._std)
|
||||
|
||||
@@ -10,7 +10,8 @@ import ray
|
||||
from ray.rllib import _register_all
|
||||
|
||||
from ray import tune
|
||||
from ray.tune import DurableTrainable, Trainable, TuneError, Stopper
|
||||
from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
|
||||
EarlyStopping)
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.trial import Trial
|
||||
@@ -487,6 +488,37 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
t.last_result.get("training_iteration") is None
|
||||
for t in trials))
|
||||
|
||||
def testEarlyStopping(self):
    """Check EarlyStopping argument validation and stopping in both modes."""

    def train(config, reporter):
        # Constant metric: the standard deviation of the best results is
        # zero, so the stopper fires as soon as ``top`` results are seen.
        reporter(test=0)

    top = 3

    # Invalid constructor arguments must be rejected.
    with self.assertRaises(ValueError):
        EarlyStopping("test", top=0)
    with self.assertRaises(ValueError):
        EarlyStopping("test", top="0")
    with self.assertRaises(ValueError):
        EarlyStopping("test", std=0)
    with self.assertRaises(ValueError):
        EarlyStopping("test", std="0")
    with self.assertRaises(ValueError):
        EarlyStopping("test", mode="0")

    # mode="min": the run must terminate with at most ``top`` results.
    stopper = EarlyStopping("test", top=top, mode="min")

    analysis = tune.run(train, num_samples=10, stop=stopper)
    self.assertTrue(
        all(t.status == Trial.TERMINATED for t in analysis.trials))
    self.assertTrue(len(analysis.dataframe()) <= top)

    # mode="max": same expectations. The original repeated mode="min"
    # here, which left the "max" branch of the stopper untested.
    stopper = EarlyStopping("test", top=top, mode="max")

    analysis = tune.run(train, num_samples=10, stop=stopper)
    self.assertTrue(
        all(t.status == Trial.TERMINATED for t in analysis.trials))
    self.assertTrue(len(analysis.dataframe()) <= top)
|
||||
|
||||
def testBadStoppingFunction(self):
|
||||
def train(config, reporter):
|
||||
for i in range(10):
|
||||
|
||||
Reference in New Issue
Block a user