[tune] Prevent MEMORY checkpoints from breaking trial FT (#6691)

* Prevent MEMORY checkpoints from breaking FT

* Add save/pause/resume/restore test

* change checkpoint return value based on status

* Fix test_checkpoint_manager_tests.

* Fix test + checkpoint manager bug

* lint

* Add docstring

* Add docstring to checkpoint_manager constructor

* Change variable name for clarity

* Revert on_checkpoint docstring wording

* Break after success

* nit: more informative warning

* Quarantine test
This commit is contained in:
Ujval Misra
2020-01-22 23:17:09 -08:00
committed by Edward Oakes
parent 11c6b32c2e
commit 2fca550096
10 changed files with 122 additions and 60 deletions
+24 -12
View File
@@ -45,6 +45,9 @@ class CheckpointManager:
def __init__(self, keep_checkpoints_num, checkpoint_score_attr, delete_fn):
"""Initializes a new CheckpointManager.
`newest_persistent_checkpoint` and `newest_memory_checkpoint` are
initialized to Checkpoint objects with values of None.
Args:
keep_checkpoints_num (int): Keep at least this many checkpoints.
checkpoint_score_attr (str): Attribute to use to determine which
@@ -60,28 +63,38 @@ class CheckpointManager:
self._checkpoint_score_attr = checkpoint_score_attr[4:]
else:
self._checkpoint_score_attr = checkpoint_score_attr
self.delete = delete_fn
self.newest_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
self.newest_persistent_checkpoint = Checkpoint(Checkpoint.PERSISTENT,
None)
self.newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
self._best_checkpoints = []
self._membership = set()
def on_checkpoint(self, checkpoint):
"""Starts tracking checkpoint metadata on checkpoint.
Sets newest checkpoint. Deletes previous checkpoint as long as it isn't
one of the best ones. Also deletes the worst checkpoint if at capacity.
Sets the newest checkpoint. For PERSISTENT checkpoints: Deletes
previous checkpoint as long as it isn't one of the best ones. Also
deletes the worst checkpoint if at capacity.
Args:
checkpoint (Checkpoint): Trial state checkpoint.
"""
old_checkpoint = self.newest_checkpoint
self.newest_checkpoint = checkpoint
if checkpoint.storage == Checkpoint.MEMORY:
self.newest_memory_checkpoint = checkpoint
return
old_checkpoint = self.newest_persistent_checkpoint
self.newest_persistent_checkpoint = checkpoint
# Remove the old checkpoint if it isn't one of the best ones.
if old_checkpoint.value and old_checkpoint not in self._membership:
self.delete(old_checkpoint)
try:
queue_item = QueueItem(self._priority(checkpoint), checkpoint)
except KeyError:
if old_checkpoint not in self._membership:
self.delete(old_checkpoint)
logger.error("Result dict has no key: {}. "
"checkpoint_score_attr must be set to a key in the "
"result dict.".format(self._checkpoint_score_attr))
@@ -95,11 +108,10 @@ class CheckpointManager:
self._membership.add(checkpoint)
if worst in self._membership:
self._membership.remove(worst)
self.delete(worst)
# Remove the old checkpoint if it isn't one of the best ones.
if old_checkpoint.value and old_checkpoint not in self._membership:
self.delete(old_checkpoint)
# Don't delete the newest checkpoint. It will be deleted on the
# next on_checkpoint() call since it isn't in self._membership.
if worst != checkpoint:
self.delete(worst)
def best_checkpoints(self):
"""Returns best checkpoints, sorted by score."""
+21 -13
View File
@@ -172,13 +172,12 @@ class RayTrialExecutor(TrialExecutor):
See `RayTrialExecutor.restore` for possible errors raised.
"""
prior_status = trial.status
self.set_status(trial, Trial.RUNNING)
trial.set_runner(
runner or self._setup_remote_runner(
trial,
reuse_allowed=checkpoint is not None
or trial.has_checkpoint()))
if runner is None:
reuse_allowed = checkpoint is not None or trial.has_checkpoint()
runner = self._setup_remote_runner(trial, reuse_allowed)
trial.set_runner(runner)
self.restore(trial, checkpoint)
self.set_status(trial, Trial.RUNNING)
previous_run = self._find_item(self._paused, trial)
if prior_status == Trial.PAUSED and previous_run:
@@ -421,6 +420,11 @@ class RayTrialExecutor(TrialExecutor):
def _update_avail_resources(self, num_retries=5):
resources = None
for i in range(num_retries):
if i > 0:
logger.warning(
"Cluster resources not detected or are 0. Attempt #"
"%s...", i + 1)
time.sleep(0.5)
try:
resources = ray.cluster_resources()
except Exception:
@@ -428,10 +432,8 @@ class RayTrialExecutor(TrialExecutor):
# https://github.com/ray-project/ray/issues/4147
logger.debug("Using resources for local machine.")
resources = ResourceSpec().resolve(True).to_resource_dict()
if not resources:
logger.warning(
"Cluster resources not detected or are 0. Retrying...")
time.sleep(0.5)
if resources:
break
if not resources:
# NOTE: This hides the possibility that Ray may be waiting for
@@ -555,14 +557,14 @@ class RayTrialExecutor(TrialExecutor):
"""Saves the trial's state to a checkpoint.
Args:
trial (Trial): The state of this trial to be saved.
trial (Trial): The trial to be saved.
storage (str): Where to store the checkpoint. Defaults to
PERSISTENT.
result (dict): The state of this trial as a dictionary to be saved.
If result is None, the trial's last result will be used.
Returns:
Checkpoint future, or None if an Exception occurs.
Checkpoint object, or None if an Exception occurs.
"""
result = result or trial.last_result
@@ -588,11 +590,17 @@ class RayTrialExecutor(TrialExecutor):
"syncs by setting sync_on_checkpoint=False. Note that this "
"might result in faulty trial restoration for some worker "
"failure modes.")
return checkpoint.value
return checkpoint
def restore(self, trial, checkpoint=None):
"""Restores training state from a given model checkpoint.
Args:
trial (Trial): The trial to be restored.
checkpoint (Checkpoint): The checkpoint to restore from. If None,
the most recent PERSISTENT checkpoint is used. Defaults to
None.
Raises:
RuntimeError: This error is raised if no runner is found.
AbortTrialExecution: This error is raised if the trial is
+2 -4
View File
@@ -324,14 +324,12 @@ class PopulationBasedTraining(FIFOScheduler):
reset_successful = trial_executor.reset_trial(trial, new_config,
new_tag)
if reset_successful:
trial_executor.restore(
trial, Checkpoint.from_object(new_state.last_checkpoint))
trial_executor.restore(trial, new_state.last_checkpoint)
else:
trial_executor.stop_trial(trial, stop_logger=False)
trial.config = new_config
trial.experiment_tag = new_tag
trial_executor.start_trial(
trial, Checkpoint.from_object(new_state.last_checkpoint))
trial_executor.start_trial(trial, new_state.last_checkpoint)
self._num_perturbations += 1
# Transfer over the last perturbation time as well
@@ -28,14 +28,14 @@ class CheckpointManagerTest(unittest.TestCase):
for i in range(3)
]
with patch.object(checkpoint_manager, "delete") as \
delete_mock:
with patch.object(checkpoint_manager, "delete") as delete_mock:
for j in range(3):
checkpoint_manager.on_checkpoint(checkpoints[j])
expected_deletes = 0 if j != 2 else 1
self.assertEqual(delete_mock.call_count, expected_deletes, j)
self.assertEqual(checkpoint_manager.newest_checkpoint,
checkpoints[j])
self.assertEqual(
checkpoint_manager.newest_persistent_checkpoint,
checkpoints[j])
best_checkpoints = checkpoint_manager.best_checkpoints()
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
@@ -59,8 +59,9 @@ class CheckpointManagerTest(unittest.TestCase):
checkpoint_manager.on_checkpoint(checkpoints[j])
expected_deletes = 0 if j != 3 else 1
self.assertEqual(delete_mock.call_count, expected_deletes)
self.assertEqual(checkpoint_manager.newest_checkpoint,
checkpoints[j])
self.assertEqual(
checkpoint_manager.newest_persistent_checkpoint,
checkpoints[j])
best_checkpoints = checkpoint_manager.best_checkpoints()
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
@@ -74,7 +75,7 @@ class CheckpointManagerTest(unittest.TestCase):
keep_checkpoints_num = 4
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
Checkpoint(Checkpoint.MEMORY, i, self.mock_result(i))
Checkpoint(Checkpoint.PERSISTENT, i, self.mock_result(i))
for i in range(16)
]
random.shuffle(checkpoints)
@@ -92,15 +93,28 @@ class CheckpointManagerTest(unittest.TestCase):
Tests that an error is logged when the associated result of the
checkpoint has no checkpoint score attribute.
"""
keep_checkpoints_num = 1
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
no_attr_checkpoint = Checkpoint(Checkpoint.MEMORY, 0, {})
no_attr_checkpoint = Checkpoint(Checkpoint.PERSISTENT, 0, {})
with patch.object(logger, "error") as log_error_mock:
checkpoint_manager.on_checkpoint(no_attr_checkpoint)
log_error_mock.assert_called_once()
# The newest checkpoint should still be set despite this error.
assert checkpoint_manager.newest_checkpoint == no_attr_checkpoint
self.assertEqual(checkpoint_manager.newest_persistent_checkpoint,
no_attr_checkpoint)
def testOnMemoryCheckpoint(self):
checkpoints = [
Checkpoint(Checkpoint.MEMORY, 0, self.mock_result(0)),
Checkpoint(Checkpoint.MEMORY, 0, self.mock_result(0))
]
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
checkpoint_manager.on_checkpoint(checkpoints[0])
checkpoint_manager.on_checkpoint(checkpoints[1])
newest = checkpoint_manager.newest_memory_checkpoint
self.assertEqual(newest, checkpoints[1])
self.assertEqual(checkpoint_manager.best_checkpoints(), [])
if __name__ == "__main__":
+2
View File
@@ -555,6 +555,8 @@ tune.run(
cluster.shutdown()
# TODO(ujvl): Fix test.
@pytest.mark.skip(reason="Not very consistent.")
def test_cluster_interrupt(start_connected_cluster, tmpdir):
"""Tests run_experiment on cluster shutdown with actual interrupt.
@@ -51,6 +51,27 @@ class RayTrialExecutorTest(unittest.TestCase):
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testSavePauseResumeRestore(self):
"""Tests that pause checkpoint does not replace restore checkpoint."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
# Save
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(Trial.RUNNING, trial.status)
self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)
# Pause
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY)
# Resume
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.assertEqual(trial.checkpoint, checkpoint)
# Restore
self.trial_executor.restore(trial)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testStartFailure(self):
_global_registry.register(TRAINABLE_CLASS, "asdf", None)
trial = Trial("asdf", resources=Resources(1, 0))
@@ -63,9 +84,9 @@ class RayTrialExecutorTest(unittest.TestCase):
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.fetch_result(trial)
self.trial_executor.pause_trial(trial)
checkpoint = self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
self.trial_executor.start_trial(trial)
self.trial_executor.start_trial(trial, checkpoint)
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
+6 -6
View File
@@ -202,8 +202,8 @@ class TrialRunnerTest2(unittest.TestCase):
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
path = runner.trial_executor.save(trials[0])
kwargs["restore_path"] = path
checkpoint = runner.trial_executor.save(trials[0])
kwargs["restore_path"] = checkpoint.value
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
@@ -216,7 +216,7 @@ class TrialRunnerTest2(unittest.TestCase):
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
self.addCleanup(os.remove, path)
self.addCleanup(os.remove, checkpoint.value)
def testRestoreMetricsAfterCheckpointing(self):
ray.init(num_cpus=1, num_gpus=1)
@@ -230,9 +230,9 @@ class TrialRunnerTest2(unittest.TestCase):
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
path = runner.trial_executor.save(trials[0])
checkpoint = runner.trial_executor.save(trials[0])
runner.trial_executor.stop_trial(trials[0])
kwargs["restore_path"] = path
kwargs["restore_path"] = checkpoint.value
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
@@ -249,7 +249,7 @@ class TrialRunnerTest2(unittest.TestCase):
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20)
self.assertEqual(trials[1].last_result["iterations_since_restore"], 2)
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
self.addCleanup(os.remove, path)
self.addCleanup(os.remove, checkpoint.value)
def testCheckpointingAtEnd(self):
ray.init(num_cpus=1, num_gpus=1)
@@ -200,7 +200,7 @@ class _MockTrialExecutor(TrialExecutor):
pass
def save(self, trial, type=Checkpoint.PERSISTENT, result=None):
return trial.trainable_name
return Checkpoint(Checkpoint.PERSISTENT, trial.trainable_name, result)
def reset_trial(self, trial, new_config, new_experiment_tag):
return False
+13 -5
View File
@@ -194,11 +194,11 @@ class Trial:
self.checkpoint_freq = checkpoint_freq
self.checkpoint_at_end = checkpoint_at_end
self.sync_on_checkpoint = sync_on_checkpoint
newest_checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
self.checkpoint_manager = CheckpointManager(
keep_checkpoints_num, checkpoint_score_attr,
checkpoint_deleter(str(self), self.runner))
self.checkpoint_manager.newest_checkpoint = newest_checkpoint
checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
self.checkpoint_manager.newest_persistent_checkpoint = checkpoint
# Restoration fields
self.restoring_from = None
@@ -228,7 +228,16 @@ class Trial:
@property
def checkpoint(self):
return self.checkpoint_manager.newest_checkpoint
"""Returns the most recent checkpoint.
If the trial is PAUSED, this is the most recent MEMORY checkpoint.
Otherwise, it is the most recent PERSISTENT checkpoint.
"""
if self.status == Trial.PAUSED:
assert self.checkpoint_manager.newest_memory_checkpoint.value
return self.checkpoint_manager.newest_memory_checkpoint
else:
return self.checkpoint_manager.newest_persistent_checkpoint
@classmethod
def generate_id(cls):
@@ -351,7 +360,6 @@ class Trial:
checkpoint (Checkpoint): Checkpoint taken.
"""
if checkpoint.storage == Checkpoint.MEMORY:
# TODO(ujvl): Handle this separately to avoid restoration failure.
self.checkpoint_manager.on_checkpoint(checkpoint)
return
if self.sync_on_checkpoint:
@@ -364,7 +372,7 @@ class Trial:
# checkpoint, so it should just be logged.
logger.error(
"Trial %s: An error occurred during the "
"checkpoint pre-sync wait.", str(e))
"checkpoint pre-sync wait - %s", self, str(e))
# Force sync down and wait before tracking the new checkpoint.
try:
if self.result_logger.sync_down():
+5 -6
View File
@@ -196,7 +196,7 @@ class TrialExecutor:
Assumes the trial is running.
Return:
Returns:
Result object for the trial.
"""
raise NotImplementedError
@@ -219,7 +219,7 @@ class TrialExecutor:
trial (Trial): Trial to be restored.
checkpoint (Checkpoint): Checkpoint to restore from.
Return:
Returns:
False if error occurred, otherwise return True.
"""
raise NotImplementedError("Subclasses of TrialExecutor must provide "
@@ -236,9 +236,8 @@ class TrialExecutor:
PERSISTENT.
result (dict): The state of this trial as a dictionary to be saved.
Return:
A Python object if storage==Checkpoint.MEMORY otherwise
a path to the checkpoint.
Returns:
A Checkpoint object.
"""
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"save() method")
@@ -249,7 +248,7 @@ class TrialExecutor:
Args:
trial (Trial): The state of this trial to be saved.
Return:
Returns:
A dict that maps ExportFormats to successfully exported models.
"""
raise NotImplementedError("Subclasses of TrialExecutor must provide "