mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 13:40:14 +08:00
[tune] Added timeout parameter to tune.run(), (#10642)
This commit is contained in:
@@ -8,7 +8,8 @@ from ray.tune.error import TuneError
|
||||
from ray.tune.registry import register_trainable, get_trainable_cls
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.sample import Domain
|
||||
from ray.tune.stopper import FunctionStopper, Stopper
|
||||
from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, \
|
||||
TimeoutStopper
|
||||
from ray.tune.utils import detect_checkpoint_function
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -102,6 +103,7 @@ class Experiment:
|
||||
name,
|
||||
run,
|
||||
stop=None,
|
||||
time_budget_s=None,
|
||||
config=None,
|
||||
resources_per_trial=None,
|
||||
num_samples=1,
|
||||
@@ -159,6 +161,13 @@ class Experiment:
|
||||
raise ValueError("Invalid stop criteria: {}. Must be a "
|
||||
"callable or dict".format(stop))
|
||||
|
||||
if time_budget_s:
|
||||
if self._stopper:
|
||||
self._stopper = CombinedStopper(self._stopper,
|
||||
TimeoutStopper(time_budget_s))
|
||||
else:
|
||||
self._stopper = TimeoutStopper(time_budget_s)
|
||||
|
||||
_raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)
|
||||
|
||||
stdout_file, stderr_file = _validate_log_to_file(log_to_file)
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray import logger
|
||||
|
||||
|
||||
class Stopper:
|
||||
"""Base class for implementing a Tune experiment stopper.
|
||||
@@ -38,6 +42,17 @@ class Stopper:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CombinedStopper(Stopper):
|
||||
def __init__(self, *stoppers: Stopper):
|
||||
self._stoppers = stoppers
|
||||
|
||||
def __call__(self, trial_id, result):
|
||||
return any(s(trial_id, result) for s in self._stoppers)
|
||||
|
||||
def stop_all(self):
|
||||
return any(s.stop_all() for s in self._stoppers)
|
||||
|
||||
|
||||
class NoopStopper(Stopper):
|
||||
def __call__(self, trial_id, result):
|
||||
return False
|
||||
@@ -140,3 +155,42 @@ class EarlyStopping(Stopper):
|
||||
def stop_all(self):
|
||||
"""Return whether to stop and prevent trials from starting."""
|
||||
return self.has_plateaued() and self._iterations >= self._patience
|
||||
|
||||
|
||||
class TimeoutStopper(Stopper):
|
||||
"""Stops all trials after a certain timeout.
|
||||
|
||||
Args:
|
||||
timeout (int|float|datetime.timedelta): Either a number specifying
|
||||
the timeout in seconds, or a `datetime.timedelta` object.
|
||||
"""
|
||||
|
||||
def __init__(self, timeout):
|
||||
from datetime import timedelta
|
||||
if isinstance(timeout, timedelta):
|
||||
self._timeout_seconds = timeout.total_seconds()
|
||||
elif isinstance(timeout, (int, float)):
|
||||
self._timeout_seconds = timeout
|
||||
else:
|
||||
raise ValueError(
|
||||
"`timeout` parameter has to be either a number or a "
|
||||
"`datetime.timedelta` object. Found: {}".format(type(timeout)))
|
||||
|
||||
# To account for setup overhead, set the start time only after
|
||||
# the first call to `stop_all()`.
|
||||
self._start = None
|
||||
|
||||
def __call__(self, trial_id, result):
|
||||
return False
|
||||
|
||||
def stop_all(self):
|
||||
if not self._start:
|
||||
self._start = time.time()
|
||||
return False
|
||||
|
||||
now = time.time()
|
||||
if now - self._start >= self._timeout_seconds:
|
||||
logger.info(f"Reached timeout of {self._timeout_seconds} seconds. "
|
||||
f"Stopping all trials.")
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -1108,6 +1108,43 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertIn("PRINT_STDERR", content)
|
||||
self.assertIn("LOG_STDERR", content)
|
||||
|
||||
def testTimeout(self):
|
||||
from ray.tune.stopper import TimeoutStopper
|
||||
import datetime
|
||||
|
||||
def train(config):
|
||||
for i in range(20):
|
||||
tune.report(metric=i)
|
||||
time.sleep(1)
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
start = time.time()
|
||||
tune.run("f1", time_budget_s=5)
|
||||
diff = time.time() - start
|
||||
self.assertLess(diff, 10)
|
||||
|
||||
# Metric should fire first
|
||||
start = time.time()
|
||||
tune.run("f1", stop={"metric": 3}, time_budget_s=7)
|
||||
diff = time.time() - start
|
||||
self.assertLess(diff, 7)
|
||||
|
||||
# Timeout should fire first
|
||||
start = time.time()
|
||||
tune.run("f1", stop={"metric": 10}, time_budget_s=5)
|
||||
diff = time.time() - start
|
||||
self.assertLess(diff, 10)
|
||||
|
||||
# Combined stopper. Shorter timeout should win.
|
||||
start = time.time()
|
||||
tune.run(
|
||||
"f1",
|
||||
stop=TimeoutStopper(10),
|
||||
time_budget_s=datetime.timedelta(seconds=3))
|
||||
diff = time.time() - start
|
||||
self.assertLess(diff, 9)
|
||||
|
||||
|
||||
class ShimCreationTest(unittest.TestCase):
|
||||
def testCreateScheduler(self):
|
||||
|
||||
@@ -69,6 +69,7 @@ def run(
|
||||
run_or_experiment,
|
||||
name=None,
|
||||
stop=None,
|
||||
time_budget_s=None,
|
||||
config=None,
|
||||
resources_per_trial=None,
|
||||
num_samples=1,
|
||||
@@ -155,6 +156,9 @@ def run(
|
||||
``ray.tune.Stopper``, which allows users to implement
|
||||
custom experiment-wide stopping (i.e., stopping an entire Tune
|
||||
run based on some time constraint).
|
||||
time_budget_s (int|float|datetime.timedelta): Global time budget in
|
||||
seconds after which all trials are stopped. Can also be a
|
||||
``datetime.timedelta`` object.
|
||||
config (dict): Algorithm-specific configuration for Tune variant
|
||||
generation (e.g. env, hyperparams). Defaults to empty dict.
|
||||
Custom search algorithms may ignore this.
|
||||
@@ -289,6 +293,7 @@ def run(
|
||||
name=name,
|
||||
run=exp,
|
||||
stop=stop,
|
||||
time_budget_s=time_budget_s,
|
||||
config=config,
|
||||
resources_per_trial=resources_per_trial,
|
||||
num_samples=num_samples,
|
||||
|
||||
Reference in New Issue
Block a user