From 5b330de1820b6b0115833dff5b890fd7bbf56080 Mon Sep 17 00:00:00 2001 From: Luca Cappelletti Date: Mon, 18 May 2020 22:12:16 +0200 Subject: [PATCH] [Tune] Introduced patience to early stopping (#8484) --- python/ray/tune/stopper.py | 33 ++++++++++++++++++++++++++++--- python/ray/tune/tests/test_api.py | 10 ++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/python/ray/tune/stopper.py b/python/ray/tune/stopper.py index e72cb2284..92be7c424 100644 --- a/python/ray/tune/stopper.py +++ b/python/ray/tune/stopper.py @@ -67,9 +67,13 @@ class FunctionStopper(Stopper): class EarlyStopping(Stopper): - def __init__(self, metric, std=0.001, top=10, mode="min"): + def __init__(self, metric, std=0.001, top=10, mode="min", patience=0): """Create the EarlyStopping object. + Stops the entire experiment when the metric has plateaued + for at least the given number of iterations specified in + the patience parameter. + Args: metric (str): The metric to be monitored. std (float): The minimal standard deviation after which @@ -77,6 +81,8 @@ class EarlyStopping(Stopper): top (int): The number of best model to consider. mode (str): The mode to select the top results. Can either be "min" or "max". + patience (int): Number of epochs to wait for + a change in the top models. Raises: ValueError: If the mode parameter is not "min" nor "max". @@ -84,6 +90,8 @@ class EarlyStopping(Stopper): greater than 1. ValueError: If the standard deviation parameter is not a strictly positive float. + ValueError: If the patience parameter is not + a non-negative integer. 
""" if mode not in ("min", "max"): raise ValueError("The mode parameter can only be" @@ -91,11 +99,16 @@ class EarlyStopping(Stopper): if not isinstance(top, int) or top <= 1: raise ValueError("Top results to consider must be" " a positive integer greater than one.") + if not isinstance(patience, int) or patience < 0: + raise ValueError("Patience must be" + " a strictly positive integer.") if not isinstance(std, float) or std <= 0: raise ValueError("The standard deviation must be" " a strictly positive float number.") self._mode = mode self._metric = metric + self._patience = patience + self._iterations = 0 self._std = std self._top = top self._top_values = [] @@ -107,9 +120,23 @@ class EarlyStopping(Stopper): self._top_values = sorted(self._top_values)[:self._top] else: self._top_values = sorted(self._top_values)[-self._top:] + + # If the current iteration has to stop + if self.has_plateaued(): + # we increment the total counter of iterations + self._iterations += 1 + else: + # otherwise we reset the counter + self._iterations = 0 + + # and then call the method that re-executes + # the checks, including the iterations. 
return self.stop_all() + def has_plateaued(self): + return (len(self._top_values) == self._top + and np.std(self._top_values) <= self._std) + def stop_all(self): """Return whether to stop and prevent trials from starting.""" - return (len(self._top_values) == self._top - and np.std(self._top_values) <= self._std) + return self.has_plateaued() and self._iterations >= self._patience diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index edac66c3b..e578c5359 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -500,6 +500,8 @@ class TrainableFunctionApiTest(unittest.TestCase): EarlyStopping("test", top="0") with self.assertRaises(ValueError): EarlyStopping("test", std=0) + with self.assertRaises(ValueError): + EarlyStopping("test", patience=-1) with self.assertRaises(ValueError): EarlyStopping("test", std="0") with self.assertRaises(ValueError): @@ -512,6 +514,14 @@ class TrainableFunctionApiTest(unittest.TestCase): all(t.status == Trial.TERMINATED for t in analysis.trials)) self.assertTrue(len(analysis.dataframe()) <= top) + patience = 10 + stopper = EarlyStopping("test", top=top, mode="min", patience=patience) + + analysis = tune.run(train, num_samples=100, stop=stopper) + self.assertTrue( + all(t.status == Trial.TERMINATED for t in analysis.trials)) + self.assertTrue(len(analysis.dataframe()) <= patience) + stopper = EarlyStopping("test", top=top, mode="min") analysis = tune.run(train, num_samples=10, stop=stopper)