[tune] dynamic global checkpointing interval (#13736)

* Add scalability tests

* Move experiment checkpointing into a manager class

* Dynamic global checkpointing

* Actually write checkpoints

* Remove debug message

* Pass `force`

* Pre-review

* Revert scalability commits

* Revert scalability commits

* Apply suggestions from code review
This commit is contained in:
Kai Fricke
2021-01-29 17:14:46 +01:00
committed by GitHub
parent 0f3a3e14aa
commit 9a413144b1
3 changed files with 168 additions and 50 deletions
@@ -695,6 +695,25 @@ class TrialRunnerTest3(unittest.TestCase):
self.assertTrue(trials[0].has_checkpoint())
self.assertEqual(num_checkpoints(trials[0]), 2)
@patch("ray.tune.syncer.CLOUD_SYNC_PERIOD", 0)
def testCheckpointAutoPeriod(self):
    """Check that ``checkpoint_period="auto"`` scales with sync duration.

    The cloud sync callback is made artificially slow (2 seconds), so after
    a single runner step the automatically chosen checkpoint period must
    be at least 19 times that duration, i.e. 38 seconds.
    """

    def slow_sync(source, target):
        # Simulate an upload that takes 2 seconds and then succeeds.
        time.sleep(2)
        return True

    runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir,
        checkpoint_period="auto",
        sync_to_cloud=slow_sync,
        remote_checkpoint_dir="fake")
    trial = Trial("__fake", config={"user_checkpoint_freq": 1})
    runner.add_trial(trial)

    # A single step is enough to trigger one experiment checkpoint.
    runner.step()

    adjusted_period = runner._checkpoint_manager._checkpoint_period
    self.assertGreaterEqual(adjusted_period, 38.)
class SearchAlgorithmTest(unittest.TestCase):
@classmethod
+144 -49
View File
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union
import click
from datetime import datetime
@@ -16,11 +16,11 @@ from ray.tune.stopper import NoopStopper
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import (DEFAULT_METRIC, TIME_THIS_ITER_S,
RESULT_DUPLICATE, SHOULD_CHECKPOINT)
from ray.tune.syncer import get_cloud_syncer
from ray.tune.syncer import CloudSyncer, get_cloud_syncer
from ray.tune.trial import Checkpoint, Trial
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.utils import warn_if_slow, flatten_dict, env_integer
from ray.tune.suggest import BasicVariantGenerator, SearchAlgorithm
from ray.tune.utils import warn_if_slow, flatten_dict
from ray.tune.utils.log import Verbosity, has_verbosity
from ray.tune.utils.placement_groups import TUNE_MAX_PENDING_TRIALS_PG
from ray.tune.utils.serialization import TuneFunctionDecoder, \
@@ -42,6 +42,106 @@ def _find_newest_ckpt(ckpt_dir):
return max(full_paths)
class _ExperimentCheckpointManager:
"""Helper class for managing experiment-level checkpoints.
This class implements the ``checkpoint()`` method used to checkpoint
experiment state. When called, this will serialize and write to disk
the state of the trial runner, trial executor, and search algorithm, to
a specified checkpoint file.
The checkpoint period is automatically adjusted to
``max(10, time_per_checkpoint * 19)``. This means that at most 5% of the
time (1/20) will be used for writing checkpoints, while 95% of the time
(19/20) will be used to handle the rest of the training loop.
"""
def __init__(self, checkpoint_dir: str,
checkpoint_period: Union[int, float, str], start_time: float,
session_str: str, syncer: CloudSyncer):
self._checkpoint_dir = checkpoint_dir
self._auto_checkpoint_enabled = checkpoint_period == "auto"
if self._auto_checkpoint_enabled:
self._checkpoint_period = 10. # Initial value
else:
self._checkpoint_period = float(checkpoint_period)
self._start_time = start_time
self._session_str = session_str
self._syncer = syncer
self._last_checkpoint_time = 0.
@property
def auto_checkpoint_enabled(self):
return self._auto_checkpoint_enabled
def checkpoint(self,
checkpoint_file: str,
trial_runner: "TrialRunner",
trial_executor: RayTrialExecutor,
search_alg: SearchAlgorithm,
force=False):
"""Saves execution state to `self._local_checkpoint_dir`.
Overwrites the current session checkpoint, which starts when self
is instantiated. Throttle depends on self._checkpoint_period.
Also automatically saves the search algorithm to the local
checkpoint dir.
Args:
force (bool): Forces a checkpoint despite checkpoint_period.
"""
if not self._checkpoint_dir:
return
now = time.time()
if now - self._last_checkpoint_time < self._checkpoint_period and (
not force):
return
def _serialize_and_write():
runner_state = {
"checkpoints": list(trial_executor.get_checkpoints().values()),
"runner_data": trial_runner.__getstate__(),
"stats": {
"start_time": self._start_time,
"timestamp": self._last_checkpoint_time
}
}
tmp_file_name = os.path.join(self._checkpoint_dir,
".tmp_checkpoint")
with open(tmp_file_name, "w") as f:
json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder)
os.replace(tmp_file_name, checkpoint_file)
search_alg.save_to_dir(
self._checkpoint_dir, session_str=self._session_str)
checkpoint_time_start = time.monotonic()
_serialize_and_write()
if force:
self._syncer.sync_up()
else:
self._syncer.sync_up_if_needed()
checkpoint_time_taken = time.monotonic() - checkpoint_time_start
if self._auto_checkpoint_enabled:
# Multiplying this time by 19 means we spend ~5% of the time
# writing global checkpoints and 95% of the time processing trials
self._checkpoint_period = max(10., checkpoint_time_taken * 19)
logger.debug(f"Global experiment checkpointing took "
f"{checkpoint_time_taken:.2f} seconds. "
f"Adjusting checkpoint period to "
f"{self._checkpoint_period:.2f} seconds.")
self._last_checkpoint_time = time.time()
return self._checkpoint_dir
class TrialRunner:
"""A TrialRunner implements the event loop for scheduling trials on Ray.
@@ -82,8 +182,10 @@ class TrialRunner:
If fail_fast='raise' provided, Tune will automatically
raise the exception received by the Trainable. fail_fast='raise'
can easily leak resources and should be used with caution.
checkpoint_period (int): Trial runner checkpoint periodicity in
seconds. Defaults to 10.
checkpoint_period (int|str): Trial runner checkpoint periodicity in
seconds. Defaults to ``"auto"``, which adjusts checkpointing
time so that at most 5% of the time is spent on writing
checkpoints.
trial_executor (TrialExecutor): Defaults to RayTrialExecutor.
callbacks (list): List of callbacks that will be called at different
times in the training loop. Must be instances of the
@@ -183,9 +285,7 @@ class TrialRunner:
self._start_time = time.time()
self._last_checkpoint_time = -float("inf")
if checkpoint_period is None:
checkpoint_period = env_integer("TUNE_GLOBAL_CHECKPOINT_S", 10)
self._checkpoint_period = checkpoint_period
self._session_str = datetime.fromtimestamp(
self._start_time).strftime("%Y-%m-%d_%H-%M-%S")
self.checkpoint_file = None
@@ -196,6 +296,20 @@ class TrialRunner:
self._callbacks = CallbackList(callbacks or [])
if checkpoint_period is None:
checkpoint_period = os.getenv("TUNE_GLOBAL_CHECKPOINT_S", "auto")
self._checkpoint_period = checkpoint_period
self._checkpoint_manager = self._create_checkpoint_manager()
def _create_checkpoint_manager(self):
    """Build the experiment checkpoint manager from this runner's settings.

    The manager receives the runner's local checkpoint directory,
    checkpoint period, start time, session string and syncer.
    """
    manager = _ExperimentCheckpointManager(
        checkpoint_dir=self._local_checkpoint_dir,
        checkpoint_period=self._checkpoint_period,
        start_time=self._start_time,
        session_str=self._session_str,
        syncer=self._syncer)
    return manager
@property
def resumed(self):
return self._resumed
@@ -269,36 +383,23 @@ class TrialRunner:
Args:
force (bool): Forces a checkpoint despite checkpoint_period.
"""
if not self._local_checkpoint_dir:
return
now = time.time()
if now - self._last_checkpoint_time < self._checkpoint_period and (
not force):
return
self._last_checkpoint_time = now
runner_state = {
"checkpoints": list(
self.trial_executor.get_checkpoints().values()),
"runner_data": self.__getstate__(),
"stats": {
"start_time": self._start_time,
"timestamp": self._last_checkpoint_time
}
}
tmp_file_name = os.path.join(self._local_checkpoint_dir,
".tmp_checkpoint")
with open(tmp_file_name, "w") as f:
json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder)
with warn_if_slow(
"experiment_checkpoint",
message="Checkpointing the experiment state took "
"{duration:.3f} s, which may be a performance "
"bottleneck. Please ensure the "
"`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
"something significantly higher than this duration "
"to ensure compute time is mostly spent on the main "
"training loop.",
disable=self._checkpoint_manager.auto_checkpoint_enabled):
os.replace(tmp_file_name, self.checkpoint_file)
self._search_alg.save_to_dir(
self._local_checkpoint_dir, session_str=self._session_str)
if force:
self._syncer.sync_up()
else:
self._syncer.sync_up_if_needed()
return self._local_checkpoint_dir
self._checkpoint_manager.checkpoint(
checkpoint_file=self.checkpoint_file,
trial_runner=self,
trial_executor=self.trial_executor,
search_alg=self._search_alg,
force=force)
def resume(self, run_errored_only=False):
"""Resumes all checkpointed trials from previous run.
@@ -406,16 +507,7 @@ class TrialRunner:
self._stop_experiment_if_needed()
try:
with warn_if_slow(
"experiment_checkpoint",
message="Checkpointing the experiment state took "
"{duration:.3f} s, which may be a performance "
"bottleneck. Please ensure the "
"`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
"something significantly higher than this duration "
"to ensure compute time is mostly spent on the main "
"training loop."):
self.checkpoint()
self.checkpoint()
except Exception as e:
logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
self._iteration += 1
@@ -1028,7 +1120,8 @@ class TrialRunner:
for k in [
"_trials", "_stop_queue", "_server", "_search_alg",
"_scheduler_alg", "_pending_trial_queue_times",
"trial_executor", "_syncer", "_callbacks"
"trial_executor", "_syncer", "_callbacks",
"_checkpoint_manager"
]:
del state[k]
state["launch_web_server"] = bool(self._server)
@@ -1045,5 +1138,7 @@ class TrialRunner:
self.__dict__.setdefault("_start_time", start_time)
self.__dict__.update(state)
self._checkpoint_manager = self._create_checkpoint_manager()
if launch_web_server:
self._server = TuneServer(self, self._server_port)
+5 -1
View File
@@ -133,11 +133,13 @@ class warn_if_slow:
def __init__(self,
name: str,
threshold: Optional[float] = None,
message: Optional[str] = None):
message: Optional[str] = None,
disable: bool = False):
self.name = name
self.threshold = threshold or self.DEFAULT_THRESHOLD
self.message = message or self.DEFAULT_MESSAGE
self.too_slow = False
self.disable = disable
def __enter__(self):
self.start = time.time()
@@ -145,6 +147,8 @@ class warn_if_slow:
def __exit__(self, type, value, traceback):
now = time.time()
if self.disable:
return
if now - self.start > self.threshold and now - START_OF_TIME > 60.0:
self.too_slow = True
duration = now - self.start