mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 05:41:51 +08:00
[tune] dynamic global checkpointing interval (#13736)
* Add scalability tests * Move experiment checkpointing into a manager class * Dynamic global checkpointing * Actually write checkpoints * Remove debug message * Pass `force` * Pre-review * Revert scalability commits * Revert scalability commits * Apply suggestions from code review
This commit is contained in:
@@ -695,6 +695,25 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
self.assertTrue(trials[0].has_checkpoint())
|
||||
self.assertEqual(num_checkpoints(trials[0]), 2)
|
||||
|
||||
@patch("ray.tune.syncer.CLOUD_SYNC_PERIOD", 0)
|
||||
def testCheckpointAutoPeriod(self):
|
||||
# This makes checkpointing take 2 seconds.
|
||||
def sync_up(source, target):
|
||||
time.sleep(2)
|
||||
return True
|
||||
|
||||
runner = TrialRunner(
|
||||
local_checkpoint_dir=self.tmpdir,
|
||||
checkpoint_period="auto",
|
||||
sync_to_cloud=sync_up,
|
||||
remote_checkpoint_dir="fake")
|
||||
runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 1}))
|
||||
|
||||
runner.step() # Run one step, this will trigger checkpointing
|
||||
|
||||
self.assertGreaterEqual(runner._checkpoint_manager._checkpoint_period,
|
||||
38.)
|
||||
|
||||
|
||||
class SearchAlgorithmTest(unittest.TestCase):
|
||||
@classmethod
|
||||
|
||||
+144
-49
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import click
|
||||
from datetime import datetime
|
||||
@@ -16,11 +16,11 @@ from ray.tune.stopper import NoopStopper
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.result import (DEFAULT_METRIC, TIME_THIS_ITER_S,
|
||||
RESULT_DUPLICATE, SHOULD_CHECKPOINT)
|
||||
from ray.tune.syncer import get_cloud_syncer
|
||||
from ray.tune.syncer import CloudSyncer, get_cloud_syncer
|
||||
from ray.tune.trial import Checkpoint, Trial
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.utils import warn_if_slow, flatten_dict, env_integer
|
||||
from ray.tune.suggest import BasicVariantGenerator, SearchAlgorithm
|
||||
from ray.tune.utils import warn_if_slow, flatten_dict
|
||||
from ray.tune.utils.log import Verbosity, has_verbosity
|
||||
from ray.tune.utils.placement_groups import TUNE_MAX_PENDING_TRIALS_PG
|
||||
from ray.tune.utils.serialization import TuneFunctionDecoder, \
|
||||
@@ -42,6 +42,106 @@ def _find_newest_ckpt(ckpt_dir):
|
||||
return max(full_paths)
|
||||
|
||||
|
||||
class _ExperimentCheckpointManager:
|
||||
"""Helper class for managing experiment-level checkpoints.
|
||||
|
||||
This class implements the ``checkpoint()`` method used to checkpoint
|
||||
experiment state. When called, this will serialize and write to disk
|
||||
the state of the trial runner, trial executor, and search algorithm, to
|
||||
a specified checkpoint file.
|
||||
|
||||
The checkpoint period is automatically adjusted to
|
||||
``max(10, time_per_checkpoint * 19)``. This means that at most 5% of the
|
||||
time (1/20) will be used for writing checkpoints, while 95% of the time
|
||||
(19/20) will be used to handle the rest of the training loop.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint_dir: str,
|
||||
checkpoint_period: Union[int, float, str], start_time: float,
|
||||
session_str: str, syncer: CloudSyncer):
|
||||
self._checkpoint_dir = checkpoint_dir
|
||||
self._auto_checkpoint_enabled = checkpoint_period == "auto"
|
||||
if self._auto_checkpoint_enabled:
|
||||
self._checkpoint_period = 10. # Initial value
|
||||
else:
|
||||
self._checkpoint_period = float(checkpoint_period)
|
||||
|
||||
self._start_time = start_time
|
||||
self._session_str = session_str
|
||||
|
||||
self._syncer = syncer
|
||||
|
||||
self._last_checkpoint_time = 0.
|
||||
|
||||
@property
|
||||
def auto_checkpoint_enabled(self):
|
||||
return self._auto_checkpoint_enabled
|
||||
|
||||
def checkpoint(self,
|
||||
checkpoint_file: str,
|
||||
trial_runner: "TrialRunner",
|
||||
trial_executor: RayTrialExecutor,
|
||||
search_alg: SearchAlgorithm,
|
||||
force=False):
|
||||
"""Saves execution state to `self._local_checkpoint_dir`.
|
||||
|
||||
Overwrites the current session checkpoint, which starts when self
|
||||
is instantiated. Throttle depends on self._checkpoint_period.
|
||||
|
||||
Also automatically saves the search algorithm to the local
|
||||
checkpoint dir.
|
||||
|
||||
Args:
|
||||
force (bool): Forces a checkpoint despite checkpoint_period.
|
||||
"""
|
||||
if not self._checkpoint_dir:
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
if now - self._last_checkpoint_time < self._checkpoint_period and (
|
||||
not force):
|
||||
return
|
||||
|
||||
def _serialize_and_write():
|
||||
runner_state = {
|
||||
"checkpoints": list(trial_executor.get_checkpoints().values()),
|
||||
"runner_data": trial_runner.__getstate__(),
|
||||
"stats": {
|
||||
"start_time": self._start_time,
|
||||
"timestamp": self._last_checkpoint_time
|
||||
}
|
||||
}
|
||||
tmp_file_name = os.path.join(self._checkpoint_dir,
|
||||
".tmp_checkpoint")
|
||||
with open(tmp_file_name, "w") as f:
|
||||
json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder)
|
||||
|
||||
os.replace(tmp_file_name, checkpoint_file)
|
||||
search_alg.save_to_dir(
|
||||
self._checkpoint_dir, session_str=self._session_str)
|
||||
|
||||
checkpoint_time_start = time.monotonic()
|
||||
_serialize_and_write()
|
||||
if force:
|
||||
self._syncer.sync_up()
|
||||
else:
|
||||
self._syncer.sync_up_if_needed()
|
||||
checkpoint_time_taken = time.monotonic() - checkpoint_time_start
|
||||
|
||||
if self._auto_checkpoint_enabled:
|
||||
# Multiplying this time by 19 means we spend ~5% of the time
|
||||
# writing global checkpoints and 95% of the time processing trials
|
||||
self._checkpoint_period = max(10., checkpoint_time_taken * 19)
|
||||
logger.debug(f"Global experiment checkpointing took "
|
||||
f"{checkpoint_time_taken:.2f} seconds. "
|
||||
f"Adjusting checkpoint period to "
|
||||
f"{self._checkpoint_period:.2f} seconds.")
|
||||
|
||||
self._last_checkpoint_time = time.time()
|
||||
return self._checkpoint_dir
|
||||
|
||||
|
||||
class TrialRunner:
|
||||
"""A TrialRunner implements the event loop for scheduling trials on Ray.
|
||||
|
||||
@@ -82,8 +182,10 @@ class TrialRunner:
|
||||
If fail_fast='raise' provided, Tune will automatically
|
||||
raise the exception received by the Trainable. fail_fast='raise'
|
||||
can easily leak resources and should be used with caution.
|
||||
checkpoint_period (int): Trial runner checkpoint periodicity in
|
||||
seconds. Defaults to 10.
|
||||
checkpoint_period (int|str): Trial runner checkpoint periodicity in
|
||||
seconds. Defaults to ``"auto"``, which adjusts checkpointing
|
||||
time so that at most 5% of the time is spent on writing
|
||||
checkpoints.
|
||||
trial_executor (TrialExecutor): Defaults to RayTrialExecutor.
|
||||
callbacks (list): List of callbacks that will be called at different
|
||||
times in the training loop. Must be instances of the
|
||||
@@ -183,9 +285,7 @@ class TrialRunner:
|
||||
|
||||
self._start_time = time.time()
|
||||
self._last_checkpoint_time = -float("inf")
|
||||
if checkpoint_period is None:
|
||||
checkpoint_period = env_integer("TUNE_GLOBAL_CHECKPOINT_S", 10)
|
||||
self._checkpoint_period = checkpoint_period
|
||||
|
||||
self._session_str = datetime.fromtimestamp(
|
||||
self._start_time).strftime("%Y-%m-%d_%H-%M-%S")
|
||||
self.checkpoint_file = None
|
||||
@@ -196,6 +296,20 @@ class TrialRunner:
|
||||
|
||||
self._callbacks = CallbackList(callbacks or [])
|
||||
|
||||
if checkpoint_period is None:
|
||||
checkpoint_period = os.getenv("TUNE_GLOBAL_CHECKPOINT_S", "auto")
|
||||
|
||||
self._checkpoint_period = checkpoint_period
|
||||
self._checkpoint_manager = self._create_checkpoint_manager()
|
||||
|
||||
def _create_checkpoint_manager(self):
|
||||
return _ExperimentCheckpointManager(
|
||||
checkpoint_dir=self._local_checkpoint_dir,
|
||||
checkpoint_period=self._checkpoint_period,
|
||||
start_time=self._start_time,
|
||||
session_str=self._session_str,
|
||||
syncer=self._syncer)
|
||||
|
||||
@property
|
||||
def resumed(self):
|
||||
return self._resumed
|
||||
@@ -269,36 +383,23 @@ class TrialRunner:
|
||||
Args:
|
||||
force (bool): Forces a checkpoint despite checkpoint_period.
|
||||
"""
|
||||
if not self._local_checkpoint_dir:
|
||||
return
|
||||
now = time.time()
|
||||
if now - self._last_checkpoint_time < self._checkpoint_period and (
|
||||
not force):
|
||||
return
|
||||
self._last_checkpoint_time = now
|
||||
runner_state = {
|
||||
"checkpoints": list(
|
||||
self.trial_executor.get_checkpoints().values()),
|
||||
"runner_data": self.__getstate__(),
|
||||
"stats": {
|
||||
"start_time": self._start_time,
|
||||
"timestamp": self._last_checkpoint_time
|
||||
}
|
||||
}
|
||||
tmp_file_name = os.path.join(self._local_checkpoint_dir,
|
||||
".tmp_checkpoint")
|
||||
with open(tmp_file_name, "w") as f:
|
||||
json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder)
|
||||
with warn_if_slow(
|
||||
"experiment_checkpoint",
|
||||
message="Checkpointing the experiment state took "
|
||||
"{duration:.3f} s, which may be a performance "
|
||||
"bottleneck. Please ensure the "
|
||||
"`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
|
||||
"something significantly higher than this duration "
|
||||
"to ensure compute time is mostly spent on the main "
|
||||
"training loop.",
|
||||
disable=self._checkpoint_manager.auto_checkpoint_enabled):
|
||||
|
||||
os.replace(tmp_file_name, self.checkpoint_file)
|
||||
self._search_alg.save_to_dir(
|
||||
self._local_checkpoint_dir, session_str=self._session_str)
|
||||
|
||||
if force:
|
||||
self._syncer.sync_up()
|
||||
else:
|
||||
self._syncer.sync_up_if_needed()
|
||||
return self._local_checkpoint_dir
|
||||
self._checkpoint_manager.checkpoint(
|
||||
checkpoint_file=self.checkpoint_file,
|
||||
trial_runner=self,
|
||||
trial_executor=self.trial_executor,
|
||||
search_alg=self._search_alg,
|
||||
force=force)
|
||||
|
||||
def resume(self, run_errored_only=False):
|
||||
"""Resumes all checkpointed trials from previous run.
|
||||
@@ -406,16 +507,7 @@ class TrialRunner:
|
||||
self._stop_experiment_if_needed()
|
||||
|
||||
try:
|
||||
with warn_if_slow(
|
||||
"experiment_checkpoint",
|
||||
message="Checkpointing the experiment state took "
|
||||
"{duration:.3f} s, which may be a performance "
|
||||
"bottleneck. Please ensure the "
|
||||
"`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
|
||||
"something significantly higher than this duration "
|
||||
"to ensure compute time is mostly spent on the main "
|
||||
"training loop."):
|
||||
self.checkpoint()
|
||||
self.checkpoint()
|
||||
except Exception as e:
|
||||
logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
|
||||
self._iteration += 1
|
||||
@@ -1028,7 +1120,8 @@ class TrialRunner:
|
||||
for k in [
|
||||
"_trials", "_stop_queue", "_server", "_search_alg",
|
||||
"_scheduler_alg", "_pending_trial_queue_times",
|
||||
"trial_executor", "_syncer", "_callbacks"
|
||||
"trial_executor", "_syncer", "_callbacks",
|
||||
"_checkpoint_manager"
|
||||
]:
|
||||
del state[k]
|
||||
state["launch_web_server"] = bool(self._server)
|
||||
@@ -1045,5 +1138,7 @@ class TrialRunner:
|
||||
self.__dict__.setdefault("_start_time", start_time)
|
||||
|
||||
self.__dict__.update(state)
|
||||
self._checkpoint_manager = self._create_checkpoint_manager()
|
||||
|
||||
if launch_web_server:
|
||||
self._server = TuneServer(self, self._server_port)
|
||||
|
||||
@@ -133,11 +133,13 @@ class warn_if_slow:
|
||||
def __init__(self,
|
||||
name: str,
|
||||
threshold: Optional[float] = None,
|
||||
message: Optional[str] = None):
|
||||
message: Optional[str] = None,
|
||||
disable: bool = False):
|
||||
self.name = name
|
||||
self.threshold = threshold or self.DEFAULT_THRESHOLD
|
||||
self.message = message or self.DEFAULT_MESSAGE
|
||||
self.too_slow = False
|
||||
self.disable = disable
|
||||
|
||||
def __enter__(self):
|
||||
self.start = time.time()
|
||||
@@ -145,6 +147,8 @@ class warn_if_slow:
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
now = time.time()
|
||||
if self.disable:
|
||||
return
|
||||
if now - self.start > self.threshold and now - START_OF_TIME > 60.0:
|
||||
self.too_slow = True
|
||||
duration = now - self.start
|
||||
|
||||
Reference in New Issue
Block a user