@patch("ray.tune.syncer.CLOUD_SYNC_PERIOD", 0)
def testCheckpointAutoPeriod(self):
    """Check that ``checkpoint_period="auto"`` adapts to a slow sync.

    The fake cloud sync below sleeps for 2 seconds, so after a single
    runner step the checkpoint manager's period is expected to be at
    least 38 seconds.
    """

    def sync_up(source, target):
        # Stand-in for an upload; makes checkpointing take 2 seconds.
        time.sleep(2)
        return True

    trial_runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir,
        checkpoint_period="auto",
        sync_to_cloud=sync_up,
        remote_checkpoint_dir="fake")
    fake_trial = Trial("__fake", config={"user_checkpoint_freq": 1})
    trial_runner.add_trial(fake_trial)

    # Run one step, this will trigger checkpointing.
    trial_runner.step()

    adjusted_period = trial_runner._checkpoint_manager._checkpoint_period
    self.assertGreaterEqual(adjusted_period, 38.)
flatten_dict from ray.tune.utils.log import Verbosity, has_verbosity from ray.tune.utils.placement_groups import TUNE_MAX_PENDING_TRIALS_PG from ray.tune.utils.serialization import TuneFunctionDecoder, \ @@ -42,6 +42,106 @@ def _find_newest_ckpt(ckpt_dir): return max(full_paths) +class _ExperimentCheckpointManager: + """Helper class for managing experiment-level checkpoints. + + This class implements the ``checkpoint()`` method used to checkpoint + experiment state. When called, this will serialize and write to disk + the state of the trial runner, trial executor, and search algorithm, to + a specified checkpoint file. + + The checkpoint period is automatically adjusted to + ``max(10, time_per_checkpoint * 19)``. This means that at most 5% of the + time (1/20) will be used for writing checkpoints, while 95% of the time + (19/20) will be used to handle the rest of the training loop. + + """ + + def __init__(self, checkpoint_dir: str, + checkpoint_period: Union[int, float, str], start_time: float, + session_str: str, syncer: CloudSyncer): + self._checkpoint_dir = checkpoint_dir + self._auto_checkpoint_enabled = checkpoint_period == "auto" + if self._auto_checkpoint_enabled: + self._checkpoint_period = 10. # Initial value + else: + self._checkpoint_period = float(checkpoint_period) + + self._start_time = start_time + self._session_str = session_str + + self._syncer = syncer + + self._last_checkpoint_time = 0. + + @property + def auto_checkpoint_enabled(self): + return self._auto_checkpoint_enabled + + def checkpoint(self, + checkpoint_file: str, + trial_runner: "TrialRunner", + trial_executor: RayTrialExecutor, + search_alg: SearchAlgorithm, + force=False): + """Saves execution state to `self._local_checkpoint_dir`. + + Overwrites the current session checkpoint, which starts when self + is instantiated. Throttle depends on self._checkpoint_period. + + Also automatically saves the search algorithm to the local + checkpoint dir. 
+ + Args: + force (bool): Forces a checkpoint despite checkpoint_period. + """ + if not self._checkpoint_dir: + return + + now = time.time() + if now - self._last_checkpoint_time < self._checkpoint_period and ( + not force): + return + + def _serialize_and_write(): + runner_state = { + "checkpoints": list(trial_executor.get_checkpoints().values()), + "runner_data": trial_runner.__getstate__(), + "stats": { + "start_time": self._start_time, + "timestamp": self._last_checkpoint_time + } + } + tmp_file_name = os.path.join(self._checkpoint_dir, + ".tmp_checkpoint") + with open(tmp_file_name, "w") as f: + json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder) + + os.replace(tmp_file_name, checkpoint_file) + search_alg.save_to_dir( + self._checkpoint_dir, session_str=self._session_str) + + checkpoint_time_start = time.monotonic() + _serialize_and_write() + if force: + self._syncer.sync_up() + else: + self._syncer.sync_up_if_needed() + checkpoint_time_taken = time.monotonic() - checkpoint_time_start + + if self._auto_checkpoint_enabled: + # Multiplying this time by 19 means we spend ~5% of the time + # writing global checkpoints and 95% of the time processing trials + self._checkpoint_period = max(10., checkpoint_time_taken * 19) + logger.debug(f"Global experiment checkpointing took " + f"{checkpoint_time_taken:.2f} seconds. " + f"Adjusting checkpoint period to " + f"{self._checkpoint_period:.2f} seconds.") + + self._last_checkpoint_time = time.time() + return self._checkpoint_dir + + class TrialRunner: """A TrialRunner implements the event loop for scheduling trials on Ray. @@ -82,8 +182,10 @@ class TrialRunner: If fail_fast='raise' provided, Tune will automatically raise the exception received by the Trainable. fail_fast='raise' can easily leak resources and should be used with caution. - checkpoint_period (int): Trial runner checkpoint periodicity in - seconds. Defaults to 10. + checkpoint_period (int|str): Trial runner checkpoint periodicity in + seconds. 
def _create_checkpoint_manager(self):
    """Build the experiment checkpoint manager for this runner.

    The manager receives the runner's local checkpoint directory,
    checkpoint period, start time, session string and syncer.
    """
    manager_kwargs = {
        "checkpoint_dir": self._local_checkpoint_dir,
        "checkpoint_period": self._checkpoint_period,
        "start_time": self._start_time,
        "session_str": self._session_str,
        "syncer": self._syncer,
    }
    return _ExperimentCheckpointManager(**manager_kwargs)
""" - if not self._local_checkpoint_dir: - return - now = time.time() - if now - self._last_checkpoint_time < self._checkpoint_period and ( - not force): - return - self._last_checkpoint_time = now - runner_state = { - "checkpoints": list( - self.trial_executor.get_checkpoints().values()), - "runner_data": self.__getstate__(), - "stats": { - "start_time": self._start_time, - "timestamp": self._last_checkpoint_time - } - } - tmp_file_name = os.path.join(self._local_checkpoint_dir, - ".tmp_checkpoint") - with open(tmp_file_name, "w") as f: - json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder) + with warn_if_slow( + "experiment_checkpoint", + message="Checkpointing the experiment state took " + "{duration:.3f} s, which may be a performance " + "bottleneck. Please ensure the " + "`TUNE_GLOBAL_CHECKPOINT_S` environment variable is " + "something significantly higher than this duration " + "to ensure compute time is mostly spent on the main " + "training loop.", + disable=self._checkpoint_manager.auto_checkpoint_enabled): - os.replace(tmp_file_name, self.checkpoint_file) - self._search_alg.save_to_dir( - self._local_checkpoint_dir, session_str=self._session_str) - - if force: - self._syncer.sync_up() - else: - self._syncer.sync_up_if_needed() - return self._local_checkpoint_dir + self._checkpoint_manager.checkpoint( + checkpoint_file=self.checkpoint_file, + trial_runner=self, + trial_executor=self.trial_executor, + search_alg=self._search_alg, + force=force) def resume(self, run_errored_only=False): """Resumes all checkpointed trials from previous run. @@ -406,16 +507,7 @@ class TrialRunner: self._stop_experiment_if_needed() try: - with warn_if_slow( - "experiment_checkpoint", - message="Checkpointing the experiment state took " - "{duration:.3f} s, which may be a performance " - "bottleneck. 
Please ensure the " - "`TUNE_GLOBAL_CHECKPOINT_S` environment variable is " - "something significantly higher than this duration " - "to ensure compute time is mostly spent on the main " - "training loop."): - self.checkpoint() + self.checkpoint() except Exception as e: logger.warning(f"Trial Runner checkpointing failed: {str(e)}") self._iteration += 1 @@ -1028,7 +1120,8 @@ class TrialRunner: for k in [ "_trials", "_stop_queue", "_server", "_search_alg", "_scheduler_alg", "_pending_trial_queue_times", - "trial_executor", "_syncer", "_callbacks" + "trial_executor", "_syncer", "_callbacks", + "_checkpoint_manager" ]: del state[k] state["launch_web_server"] = bool(self._server) @@ -1045,5 +1138,7 @@ class TrialRunner: self.__dict__.setdefault("_start_time", start_time) self.__dict__.update(state) + self._checkpoint_manager = self._create_checkpoint_manager() + if launch_web_server: self._server = TuneServer(self, self._server_port) diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index 47a6b648e..688261fdb 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -133,11 +133,13 @@ class warn_if_slow: def __init__(self, name: str, threshold: Optional[float] = None, - message: Optional[str] = None): + message: Optional[str] = None, + disable: bool = False): self.name = name self.threshold = threshold or self.DEFAULT_THRESHOLD self.message = message or self.DEFAULT_MESSAGE self.too_slow = False + self.disable = disable def __enter__(self): self.start = time.time() @@ -145,6 +147,8 @@ class warn_if_slow: def __exit__(self, type, value, traceback): now = time.time() + if self.disable: + return if now - self.start > self.threshold and now - START_OF_TIME > 60.0: self.too_slow = True duration = now - self.start