[tune] Added timeout parameter to tune.run(), (#10642)

This commit is contained in:
Kai Fricke
2020-09-08 23:38:28 +01:00
committed by GitHub
parent 415be78cc0
commit 87c4f36f02
4 changed files with 106 additions and 1 deletions
+10 -1
View File
@@ -8,7 +8,8 @@ from ray.tune.error import TuneError
from ray.tune.registry import register_trainable, get_trainable_cls
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.sample import Domain
from ray.tune.stopper import FunctionStopper, Stopper
from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, \
TimeoutStopper
from ray.tune.utils import detect_checkpoint_function
logger = logging.getLogger(__name__)
@@ -102,6 +103,7 @@ class Experiment:
name,
run,
stop=None,
time_budget_s=None,
config=None,
resources_per_trial=None,
num_samples=1,
@@ -159,6 +161,13 @@ class Experiment:
raise ValueError("Invalid stop criteria: {}. Must be a "
"callable or dict".format(stop))
if time_budget_s:
if self._stopper:
self._stopper = CombinedStopper(self._stopper,
TimeoutStopper(time_budget_s))
else:
self._stopper = TimeoutStopper(time_budget_s)
_raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)
stdout_file, stderr_file = _validate_log_to_file(log_to_file)
+54
View File
@@ -1,5 +1,9 @@
import time
import numpy as np
from ray import logger
class Stopper:
"""Base class for implementing a Tune experiment stopper.
@@ -38,6 +42,17 @@ class Stopper:
raise NotImplementedError
class CombinedStopper(Stopper):
def __init__(self, *stoppers: Stopper):
self._stoppers = stoppers
def __call__(self, trial_id, result):
return any(s(trial_id, result) for s in self._stoppers)
def stop_all(self):
return any(s.stop_all() for s in self._stoppers)
class NoopStopper(Stopper):
def __call__(self, trial_id, result):
return False
@@ -140,3 +155,42 @@ class EarlyStopping(Stopper):
def stop_all(self):
"""Return whether to stop and prevent trials from starting."""
return self.has_plateaued() and self._iterations >= self._patience
class TimeoutStopper(Stopper):
"""Stops all trials after a certain timeout.
Args:
timeout (int|float|datetime.timedelta): Either a number specifying
the timeout in seconds, or a `datetime.timedelta` object.
"""
def __init__(self, timeout):
from datetime import timedelta
if isinstance(timeout, timedelta):
self._timeout_seconds = timeout.total_seconds()
elif isinstance(timeout, (int, float)):
self._timeout_seconds = timeout
else:
raise ValueError(
"`timeout` parameter has to be either a number or a "
"`datetime.timedelta` object. Found: {}".format(type(timeout)))
# To account for setup overhead, set the start time only after
# the first call to `stop_all()`.
self._start = None
def __call__(self, trial_id, result):
return False
def stop_all(self):
if not self._start:
self._start = time.time()
return False
now = time.time()
if now - self._start >= self._timeout_seconds:
logger.info(f"Reached timeout of {self._timeout_seconds} seconds. "
f"Stopping all trials.")
return True
return False
+37
View File
@@ -1108,6 +1108,43 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertIn("PRINT_STDERR", content)
self.assertIn("LOG_STDERR", content)
def testTimeout(self):
from ray.tune.stopper import TimeoutStopper
import datetime
def train(config):
for i in range(20):
tune.report(metric=i)
time.sleep(1)
register_trainable("f1", train)
start = time.time()
tune.run("f1", time_budget_s=5)
diff = time.time() - start
self.assertLess(diff, 10)
# Metric should fire first
start = time.time()
tune.run("f1", stop={"metric": 3}, time_budget_s=7)
diff = time.time() - start
self.assertLess(diff, 7)
# Timeout should fire first
start = time.time()
tune.run("f1", stop={"metric": 10}, time_budget_s=5)
diff = time.time() - start
self.assertLess(diff, 10)
# Combined stopper. Shorter timeout should win.
start = time.time()
tune.run(
"f1",
stop=TimeoutStopper(10),
time_budget_s=datetime.timedelta(seconds=3))
diff = time.time() - start
self.assertLess(diff, 9)
class ShimCreationTest(unittest.TestCase):
def testCreateScheduler(self):
+5
View File
@@ -69,6 +69,7 @@ def run(
run_or_experiment,
name=None,
stop=None,
time_budget_s=None,
config=None,
resources_per_trial=None,
num_samples=1,
@@ -155,6 +156,9 @@ def run(
``ray.tune.Stopper``, which allows users to implement
custom experiment-wide stopping (i.e., stopping an entire Tune
run based on some time constraint).
time_budget_s (int|float|datetime.timedelta): Global time budget in
seconds after which all trials are stopped. Can also be a
``datetime.timedelta`` object.
config (dict): Algorithm-specific configuration for Tune variant
generation (e.g. env, hyperparams). Defaults to empty dict.
Custom search algorithms may ignore this.
@@ -289,6 +293,7 @@ def run(
name=name,
run=exp,
stop=stop,
time_budget_s=time_budget_s,
config=config,
resources_per_trial=resources_per_trial,
num_samples=num_samples,