From 87c4f36f02604aa08864fb55c917e61724f8b858 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Tue, 8 Sep 2020 23:38:28 +0100 Subject: [PATCH] [tune] Added `timeout` parameter to `tune.run()`, (#10642) --- python/ray/tune/experiment.py | 11 ++++++- python/ray/tune/stopper.py | 54 +++++++++++++++++++++++++++++++ python/ray/tune/tests/test_api.py | 37 +++++++++++++++++++++ python/ray/tune/tune.py | 5 +++ 4 files changed, 106 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index c67646e40..5e1440f86 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -8,7 +8,8 @@ from ray.tune.error import TuneError from ray.tune.registry import register_trainable, get_trainable_cls from ray.tune.result import DEFAULT_RESULTS_DIR from ray.tune.sample import Domain -from ray.tune.stopper import FunctionStopper, Stopper +from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, \ + TimeoutStopper from ray.tune.utils import detect_checkpoint_function logger = logging.getLogger(__name__) @@ -102,6 +103,7 @@ class Experiment: name, run, stop=None, + time_budget_s=None, config=None, resources_per_trial=None, num_samples=1, @@ -159,6 +161,13 @@ class Experiment: raise ValueError("Invalid stop criteria: {}. Must be a " "callable or dict".format(stop)) + if time_budget_s: + if self._stopper: + self._stopper = CombinedStopper(self._stopper, + TimeoutStopper(time_budget_s)) + else: + self._stopper = TimeoutStopper(time_budget_s) + _raise_on_durable(self._run_identifier, sync_to_driver, upload_dir) stdout_file, stderr_file = _validate_log_to_file(log_to_file) diff --git a/python/ray/tune/stopper.py b/python/ray/tune/stopper.py index 92be7c424..9dc20773d 100644 --- a/python/ray/tune/stopper.py +++ b/python/ray/tune/stopper.py @@ -1,5 +1,9 @@ +import time + import numpy as np +from ray import logger + class Stopper: """Base class for implementing a Tune experiment stopper. @@ -38,6 +42,17 @@ class Stopper: raise NotImplementedError +class CombinedStopper(Stopper): + def __init__(self, *stoppers: Stopper): + self._stoppers = stoppers + + def __call__(self, trial_id, result): + return any(s(trial_id, result) for s in self._stoppers) + + def stop_all(self): + return any(s.stop_all() for s in self._stoppers) + + class NoopStopper(Stopper): def __call__(self, trial_id, result): return False @@ -140,3 +155,42 @@ class EarlyStopping(Stopper): def stop_all(self): """Return whether to stop and prevent trials from starting.""" return self.has_plateaued() and self._iterations >= self._patience + + +class TimeoutStopper(Stopper): + """Stops all trials after a certain timeout. + + Args: + timeout (int|float|datetime.timedelta): Either a number specifying + the timeout in seconds, or a `datetime.timedelta` object. + """ + + def __init__(self, timeout): + from datetime import timedelta + if isinstance(timeout, timedelta): + self._timeout_seconds = timeout.total_seconds() + elif isinstance(timeout, (int, float)): + self._timeout_seconds = timeout + else: + raise ValueError( + "`timeout` parameter has to be either a number or a " + "`datetime.timedelta` object. Found: {}".format(type(timeout))) + + # To account for setup overhead, set the start time only after + # the first call to `stop_all()`. + self._start = None + + def __call__(self, trial_id, result): + return False + + def stop_all(self): + if not self._start: + self._start = time.time() + return False + + now = time.time() + if now - self._start >= self._timeout_seconds: + logger.info(f"Reached timeout of {self._timeout_seconds} seconds. " + f"Stopping all trials.") + return True + return False diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index 1331ea414..fa0213dd8 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -1108,6 +1108,43 @@ class TrainableFunctionApiTest(unittest.TestCase): self.assertIn("PRINT_STDERR", content) self.assertIn("LOG_STDERR", content) + def testTimeout(self): + from ray.tune.stopper import TimeoutStopper + import datetime + + def train(config): + for i in range(20): + tune.report(metric=i) + time.sleep(1) + + register_trainable("f1", train) + + start = time.time() + tune.run("f1", time_budget_s=5) + diff = time.time() - start + self.assertLess(diff, 10) + + # Metric should fire first + start = time.time() + tune.run("f1", stop={"metric": 3}, time_budget_s=7) + diff = time.time() - start + self.assertLess(diff, 7) + + # Timeout should fire first + start = time.time() + tune.run("f1", stop={"metric": 10}, time_budget_s=5) + diff = time.time() - start + self.assertLess(diff, 10) + + # Combined stopper. Shorter timeout should win. + start = time.time() + tune.run( + "f1", + stop=TimeoutStopper(10), + time_budget_s=datetime.timedelta(seconds=3)) + diff = time.time() - start + self.assertLess(diff, 9) + class ShimCreationTest(unittest.TestCase): def testCreateScheduler(self): diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index aca5f234b..f331bebec 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -69,6 +69,7 @@ def run( run_or_experiment, name=None, stop=None, + time_budget_s=None, config=None, resources_per_trial=None, num_samples=1, @@ -155,6 +156,9 @@ def run( ``ray.tune.Stopper``, which allows users to implement custom experiment-wide stopping (i.e., stopping an entire Tune run based on some time constraint). + time_budget_s (int|float|datetime.timedelta): Global time budget in + seconds after which all trials are stopped. Can also be a + ``datetime.timedelta`` object. config (dict): Algorithm-specific configuration for Tune variant generation (e.g. env, hyperparams). Defaults to empty dict. Custom search algorithms may ignore this. @@ -289,6 +293,7 @@ def run( name=name, run=exp, stop=stop, + time_budget_s=time_budget_s, config=config, resources_per_trial=resources_per_trial, num_samples=num_samples,