mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 23:57:45 +08:00
[Tune] Introduced patience to early stopping (#8484)
This commit is contained in:
@@ -67,9 +67,13 @@ class FunctionStopper(Stopper):
|
||||
|
||||
|
||||
class EarlyStopping(Stopper):
|
||||
def __init__(self, metric, std=0.001, top=10, mode="min"):
|
||||
def __init__(self, metric, std=0.001, top=10, mode="min", patience=0):
|
||||
"""Create the EarlyStopping object.
|
||||
|
||||
Stops the entire experiment when the metric has plateaued
|
||||
for more than the given amount of iterations specified in
|
||||
the patience parameter.
|
||||
|
||||
Args:
|
||||
metric (str): The metric to be monitored.
|
||||
std (float): The minimal standard deviation after which
|
||||
@@ -77,6 +81,8 @@ class EarlyStopping(Stopper):
|
||||
top (int): The number of best model to consider.
|
||||
mode (str): The mode to select the top results.
|
||||
Can either be "min" or "max".
|
||||
patience (int): Number of epochs to wait for
|
||||
a change in the top models.
|
||||
|
||||
Raises:
|
||||
ValueError: If the mode parameter is not "min" nor "max".
|
||||
@@ -84,6 +90,8 @@ class EarlyStopping(Stopper):
|
||||
greater than 1.
|
||||
ValueError: If the standard deviation parameter is not
|
||||
a strictly positive float.
|
||||
ValueError: If the patience parameter is not
|
||||
a strictly positive integer.
|
||||
"""
|
||||
if mode not in ("min", "max"):
|
||||
raise ValueError("The mode parameter can only be"
|
||||
@@ -91,11 +99,16 @@ class EarlyStopping(Stopper):
|
||||
if not isinstance(top, int) or top <= 1:
|
||||
raise ValueError("Top results to consider must be"
|
||||
" a positive integer greater than one.")
|
||||
if not isinstance(patience, int) or patience < 0:
|
||||
raise ValueError("Patience must be"
|
||||
" a strictly positive integer.")
|
||||
if not isinstance(std, float) or std <= 0:
|
||||
raise ValueError("The standard deviation must be"
|
||||
" a strictly positive float number.")
|
||||
self._mode = mode
|
||||
self._metric = metric
|
||||
self._patience = patience
|
||||
self._iterations = 0
|
||||
self._std = std
|
||||
self._top = top
|
||||
self._top_values = []
|
||||
@@ -107,9 +120,23 @@ class EarlyStopping(Stopper):
|
||||
self._top_values = sorted(self._top_values)[:self._top]
|
||||
else:
|
||||
self._top_values = sorted(self._top_values)[-self._top:]
|
||||
|
||||
# If the current iteration has to stop
|
||||
if self.has_plateaued():
|
||||
# we increment the total counter of iterations
|
||||
self._iterations += 1
|
||||
else:
|
||||
# otherwise we reset the counter
|
||||
self._iterations = 0
|
||||
|
||||
# and then call the method that re-executes
|
||||
# the checks, including the iterations.
|
||||
return self.stop_all()
|
||||
|
||||
def has_plateaued(self):
|
||||
return (len(self._top_values) == self._top
|
||||
and np.std(self._top_values) <= self._std)
|
||||
|
||||
def stop_all(self):
|
||||
"""Return whether to stop and prevent trials from starting."""
|
||||
return (len(self._top_values) == self._top
|
||||
and np.std(self._top_values) <= self._std)
|
||||
return self.has_plateaued() and self._iterations >= self._patience
|
||||
|
||||
@@ -500,6 +500,8 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
EarlyStopping("test", top="0")
|
||||
with self.assertRaises(ValueError):
|
||||
EarlyStopping("test", std=0)
|
||||
with self.assertRaises(ValueError):
|
||||
EarlyStopping("test", patience=-1)
|
||||
with self.assertRaises(ValueError):
|
||||
EarlyStopping("test", std="0")
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -512,6 +514,14 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
all(t.status == Trial.TERMINATED for t in analysis.trials))
|
||||
self.assertTrue(len(analysis.dataframe()) <= top)
|
||||
|
||||
patience = 10
|
||||
stopper = EarlyStopping("test", top=top, mode="min", patience=patience)
|
||||
|
||||
analysis = tune.run(train, num_samples=100, stop=stopper)
|
||||
self.assertTrue(
|
||||
all(t.status == Trial.TERMINATED for t in analysis.trials))
|
||||
self.assertTrue(len(analysis.dataframe()) <= patience)
|
||||
|
||||
stopper = EarlyStopping("test", top=top, mode="min")
|
||||
|
||||
analysis = tune.run(train, num_samples=10, stop=stopper)
|
||||
|
||||
Reference in New Issue
Block a user