From 2ff26f13d210cc4c2d9d1088b9f97e9886603166 Mon Sep 17 00:00:00 2001 From: Luca Cappelletti Date: Sun, 17 May 2020 21:18:59 +0200 Subject: [PATCH] [tune] Added EarlyStopping and relative test suite (#8459) --- python/ray/tune/__init__.py | 11 ++++--- python/ray/tune/stopper.py | 52 +++++++++++++++++++++++++++++++ python/ray/tune/tests/test_api.py | 34 +++++++++++++++++++- 3 files changed, 91 insertions(+), 6 deletions(-) diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index 16f2b9c24..58753f98f 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -2,7 +2,7 @@ from ray.tune.error import TuneError from ray.tune.tune import run_experiments, run from ray.tune.experiment import Experiment from ray.tune.analysis import ExperimentAnalysis, Analysis -from ray.tune.stopper import Stopper +from ray.tune.stopper import Stopper, EarlyStopping from ray.tune.registry import register_env, register_trainable from ray.tune.trainable import Trainable from ray.tune.durable_trainable import DurableTrainable @@ -17,8 +17,9 @@ from ray.tune.sample import (function, sample_from, uniform, choice, randint, __all__ = [ "Trainable", "DurableTrainable", "TuneError", "grid_search", "register_env", "register_trainable", "run", "run_experiments", "Stopper", - "Experiment", "function", "sample_from", "track", "uniform", "choice", - "randint", "randn", "loguniform", "ExperimentAnalysis", "Analysis", - "CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report", - "get_trial_dir", "get_trial_name", "get_trial_id" + "EarlyStopping", "Experiment", "function", "sample_from", "track", + "uniform", "choice", "randint", "randn", "loguniform", + "ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter", + "ProgressReporter", "report", "get_trial_dir", "get_trial_name", + "get_trial_id" ] diff --git a/python/ray/tune/stopper.py b/python/ray/tune/stopper.py index 984239105..e72cb2284 100644 --- a/python/ray/tune/stopper.py +++ 
b/python/ray/tune/stopper.py
@@ -1,3 +1,6 @@
+import numpy as np
+
+
 class Stopper:
     """Base class for implementing a Tune experiment stopper.
 
@@ -61,3 +64,78 @@ class FunctionStopper(Stopper):
                 "Stop object must be ray.tune.Stopper subclass to be detected "
                 "correctly.")
         return is_function
+
+
+class EarlyStopping(Stopper):
+    """Stop the whole experiment once the best results have converged.
+
+    The stopper keeps the ``top`` best values reported for ``metric``,
+    whichever trial they come from. Once ``top`` values are available
+    and their standard deviation is not greater than ``std``, both
+    ``__call__`` and ``stop_all`` report that tuning has to stop.
+    """
+
+    def __init__(self, metric, std=0.001, top=10, mode="min"):
+        """Create the EarlyStopping object.
+
+        Args:
+            metric (str): The metric to be monitored.
+            std (float): The standard deviation threshold: tuning stops
+                once the standard deviation of the top results is
+                lower than or equal to this value.
+            top (int): The number of best results to consider.
+            mode (str): The mode to select the top results.
+                Can either be "min" or "max".
+
+        Raises:
+            ValueError: If the mode parameter is not "min" nor "max".
+            ValueError: If the top parameter is not an integer
+                greater than 1.
+            ValueError: If the standard deviation parameter is not
+                a strictly positive number.
+        """
+        if mode not in ("min", "max"):
+            raise ValueError("The mode parameter can only be"
+                             " either min or max.")
+        if not isinstance(top, int) or top <= 1:
+            raise ValueError("Top results to consider must be"
+                             " a positive integer greater than one.")
+        # Integers such as ``std=1`` are valid thresholds as well; bool is
+        # an int subclass and is rejected explicitly.
+        if (isinstance(std, bool) or not isinstance(std, (int, float))
+                or std <= 0):
+            raise ValueError("The standard deviation must be"
+                             " a strictly positive float number.")
+        self._mode = mode
+        self._metric = metric
+        self._std = std
+        self._top = top
+        self._top_values = []
+
+    def __call__(self, trial_id, result):
+        """Record the metric of a result and return whether to stop.
+
+        Args:
+            trial_id: The identifier of the reporting trial (unused).
+            result (dict): The reported result; must contain the metric.
+
+        Returns:
+            bool: True if the tuning has to stop.
+        """
+        value = result[self._metric]
+        # NaN compares False with everything (even itself): sorting it
+        # gives an arbitrary order and it turns np.std into NaN, which
+        # would disable the stopper for as long as it stays in the list.
+        if value == value:
+            ranked = sorted(self._top_values + [value])
+            if self._mode == "min":
+                self._top_values = ranked[:self._top]
+            else:
+                self._top_values = ranked[-self._top:]
+        return self.stop_all()
+
+    def stop_all(self):
+        """Return whether to stop and prevent trials from starting."""
+        # bool() turns the numpy.bool_ of the comparison into a plain bool.
+        return (len(self._top_values) == self._top
+                and bool(np.std(self._top_values) <= self._std))
diff --git a/python/ray/tune/tests/test_api.py
b/python/ray/tune/tests/test_api.py
index 60bfbe23b..edac66c3b 100644
--- a/python/ray/tune/tests/test_api.py
+++ b/python/ray/tune/tests/test_api.py
@@ -10,7 +10,8 @@ import ray
 from ray.rllib import _register_all
 
 from ray import tune
-from ray.tune import DurableTrainable, Trainable, TuneError, Stopper
+from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
+                      EarlyStopping)
 from ray.tune import register_env, register_trainable, run_experiments
 from ray.tune.schedulers import TrialScheduler, FIFOScheduler
 from ray.tune.trial import Trial
@@ -487,6 +488,30 @@ class TrainableFunctionApiTest(unittest.TestCase):
                 t.last_result.get("training_iteration") is None
                 for t in trials))
 
+    def testEarlyStopping(self):
+        """Check EarlyStopping argument validation and both modes."""
+
+        def train(config, reporter):
+            reporter(test=0)
+
+        top = 3
+
+        # Each invalid constructor argument must raise a ValueError.
+        for bad_kwargs in (dict(top=0), dict(top="0"), dict(std=0),
+                           dict(std="0"), dict(mode="0")):
+            with self.assertRaises(ValueError):
+                EarlyStopping("test", **bad_kwargs)
+
+        # The metric is constant, so the top results converge at once
+        # whichever end of the ranking is kept: cover both modes.
+        for mode in ("min", "max"):
+            stopper = EarlyStopping("test", top=top, mode=mode)
+
+            analysis = tune.run(train, num_samples=10, stop=stopper)
+            self.assertTrue(
+                all(t.status == Trial.TERMINATED for t in analysis.trials))
+            self.assertLessEqual(len(analysis.dataframe()), top)
+
     def testBadStoppingFunction(self):
         def train(config, reporter):
             for i in range(10):