mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 19:00:58 +08:00
[tune] Prevent MEMORY checkpoints from breaking trial FT (#6691)
* Prevent MEMORY checkpoints from breaking FT
* Add save/pause/resume/restore test
* change checkpoint return value based on status
* Fix test_checkpoint_manager_tests.
* Fix test + checkpoint manager bug
* lint
* Add docstring
* Add docstring to checkpoint_manager constructor
* Change variable name for clarity
* Revert on_checkpoint docstring wording
* Break after success
* nit: more informative warning
* Quarantine test
This commit is contained in:
committed by
Edward Oakes
parent
11c6b32c2e
commit
2fca550096
@@ -45,6 +45,9 @@ class CheckpointManager:
|
||||
def __init__(self, keep_checkpoints_num, checkpoint_score_attr, delete_fn):
|
||||
"""Initializes a new CheckpointManager.
|
||||
|
||||
`newest_persistent_checkpoint` and `newest_memory_checkpoint` are
|
||||
initialized to Checkpoint objects with values of None.
|
||||
|
||||
Args:
|
||||
keep_checkpoints_num (int): Keep at least this many checkpoints.
|
||||
checkpoint_score_attr (str): Attribute to use to determine which
|
||||
@@ -60,28 +63,38 @@ class CheckpointManager:
|
||||
self._checkpoint_score_attr = checkpoint_score_attr[4:]
|
||||
else:
|
||||
self._checkpoint_score_attr = checkpoint_score_attr
|
||||
|
||||
self.delete = delete_fn
|
||||
self.newest_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
|
||||
self.newest_persistent_checkpoint = Checkpoint(Checkpoint.PERSISTENT,
|
||||
None)
|
||||
self.newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
|
||||
self._best_checkpoints = []
|
||||
self._membership = set()
|
||||
|
||||
def on_checkpoint(self, checkpoint):
|
||||
"""Starts tracking checkpoint metadata on checkpoint.
|
||||
|
||||
Sets newest checkpoint. Deletes previous checkpoint as long as it isn't
|
||||
one of the best ones. Also deletes the worst checkpoint if at capacity.
|
||||
Sets the newest checkpoint. For PERSISTENT checkpoints: Deletes
|
||||
previous checkpoint as long as it isn't one of the best ones. Also
|
||||
deletes the worst checkpoint if at capacity.
|
||||
|
||||
Args:
|
||||
checkpoint (Checkpoint): Trial state checkpoint.
|
||||
"""
|
||||
old_checkpoint = self.newest_checkpoint
|
||||
self.newest_checkpoint = checkpoint
|
||||
if checkpoint.storage == Checkpoint.MEMORY:
|
||||
self.newest_memory_checkpoint = checkpoint
|
||||
return
|
||||
|
||||
old_checkpoint = self.newest_persistent_checkpoint
|
||||
self.newest_persistent_checkpoint = checkpoint
|
||||
|
||||
# Remove the old checkpoint if it isn't one of the best ones.
|
||||
if old_checkpoint.value and old_checkpoint not in self._membership:
|
||||
self.delete(old_checkpoint)
|
||||
|
||||
try:
|
||||
queue_item = QueueItem(self._priority(checkpoint), checkpoint)
|
||||
except KeyError:
|
||||
if old_checkpoint not in self._membership:
|
||||
self.delete(old_checkpoint)
|
||||
logger.error("Result dict has no key: {}. "
|
||||
"checkpoint_score_attr must be set to a key in the "
|
||||
"result dict.".format(self._checkpoint_score_attr))
|
||||
@@ -95,11 +108,10 @@ class CheckpointManager:
|
||||
self._membership.add(checkpoint)
|
||||
if worst in self._membership:
|
||||
self._membership.remove(worst)
|
||||
self.delete(worst)
|
||||
|
||||
# Remove the old checkpoint if it isn't one of the best ones.
|
||||
if old_checkpoint.value and old_checkpoint not in self._membership:
|
||||
self.delete(old_checkpoint)
|
||||
# Don't delete the newest checkpoint. It will be deleted on the
|
||||
# next on_checkpoint() call since it isn't in self._membership.
|
||||
if worst != checkpoint:
|
||||
self.delete(worst)
|
||||
|
||||
def best_checkpoints(self):
|
||||
"""Returns best checkpoints, sorted by score."""
|
||||
|
||||
@@ -172,13 +172,12 @@ class RayTrialExecutor(TrialExecutor):
|
||||
See `RayTrialExecutor.restore` for possible errors raised.
|
||||
"""
|
||||
prior_status = trial.status
|
||||
self.set_status(trial, Trial.RUNNING)
|
||||
trial.set_runner(
|
||||
runner or self._setup_remote_runner(
|
||||
trial,
|
||||
reuse_allowed=checkpoint is not None
|
||||
or trial.has_checkpoint()))
|
||||
if runner is None:
|
||||
reuse_allowed = checkpoint is not None or trial.has_checkpoint()
|
||||
runner = self._setup_remote_runner(trial, reuse_allowed)
|
||||
trial.set_runner(runner)
|
||||
self.restore(trial, checkpoint)
|
||||
self.set_status(trial, Trial.RUNNING)
|
||||
|
||||
previous_run = self._find_item(self._paused, trial)
|
||||
if prior_status == Trial.PAUSED and previous_run:
|
||||
@@ -421,6 +420,11 @@ class RayTrialExecutor(TrialExecutor):
|
||||
def _update_avail_resources(self, num_retries=5):
|
||||
resources = None
|
||||
for i in range(num_retries):
|
||||
if i > 0:
|
||||
logger.warning(
|
||||
"Cluster resources not detected or are 0. Attempt #"
|
||||
"%s...", i + 1)
|
||||
time.sleep(0.5)
|
||||
try:
|
||||
resources = ray.cluster_resources()
|
||||
except Exception:
|
||||
@@ -428,10 +432,8 @@ class RayTrialExecutor(TrialExecutor):
|
||||
# https://github.com/ray-project/ray/issues/4147
|
||||
logger.debug("Using resources for local machine.")
|
||||
resources = ResourceSpec().resolve(True).to_resource_dict()
|
||||
if not resources:
|
||||
logger.warning(
|
||||
"Cluster resources not detected or are 0. Retrying...")
|
||||
time.sleep(0.5)
|
||||
if resources:
|
||||
break
|
||||
|
||||
if not resources:
|
||||
# NOTE: This hides the possibility that Ray may be waiting for
|
||||
@@ -555,14 +557,14 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"""Saves the trial's state to a checkpoint.
|
||||
|
||||
Args:
|
||||
trial (Trial): The state of this trial to be saved.
|
||||
trial (Trial): The trial to be saved.
|
||||
storage (str): Where to store the checkpoint. Defaults to
|
||||
PERSISTENT.
|
||||
result (dict): The state of this trial as a dictionary to be saved.
|
||||
If result is None, the trial's last result will be used.
|
||||
|
||||
Returns:
|
||||
Checkpoint future, or None if an Exception occurs.
|
||||
Checkpoint object, or None if an Exception occurs.
|
||||
"""
|
||||
result = result or trial.last_result
|
||||
|
||||
@@ -588,11 +590,17 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"syncs by setting sync_on_checkpoint=False. Note that this "
|
||||
"might result in faulty trial restoration for some worker "
|
||||
"failure modes.")
|
||||
return checkpoint.value
|
||||
return checkpoint
|
||||
|
||||
def restore(self, trial, checkpoint=None):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
Args:
|
||||
trial (Trial): The trial to be restored.
|
||||
checkpoint (Checkpoint): The checkpoint to restore from. If None,
|
||||
the most recent PERSISTENT checkpoint is used. Defaults to
|
||||
None.
|
||||
|
||||
Raises:
|
||||
RuntimeError: This error is raised if no runner is found.
|
||||
AbortTrialExecution: This error is raised if the trial is
|
||||
|
||||
@@ -324,14 +324,12 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
reset_successful = trial_executor.reset_trial(trial, new_config,
|
||||
new_tag)
|
||||
if reset_successful:
|
||||
trial_executor.restore(
|
||||
trial, Checkpoint.from_object(new_state.last_checkpoint))
|
||||
trial_executor.restore(trial, new_state.last_checkpoint)
|
||||
else:
|
||||
trial_executor.stop_trial(trial, stop_logger=False)
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = new_tag
|
||||
trial_executor.start_trial(
|
||||
trial, Checkpoint.from_object(new_state.last_checkpoint))
|
||||
trial_executor.start_trial(trial, new_state.last_checkpoint)
|
||||
|
||||
self._num_perturbations += 1
|
||||
# Transfer over the last perturbation time as well
|
||||
|
||||
@@ -28,14 +28,14 @@ class CheckpointManagerTest(unittest.TestCase):
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
with patch.object(checkpoint_manager, "delete") as \
|
||||
delete_mock:
|
||||
with patch.object(checkpoint_manager, "delete") as delete_mock:
|
||||
for j in range(3):
|
||||
checkpoint_manager.on_checkpoint(checkpoints[j])
|
||||
expected_deletes = 0 if j != 2 else 1
|
||||
self.assertEqual(delete_mock.call_count, expected_deletes, j)
|
||||
self.assertEqual(checkpoint_manager.newest_checkpoint,
|
||||
checkpoints[j])
|
||||
self.assertEqual(
|
||||
checkpoint_manager.newest_persistent_checkpoint,
|
||||
checkpoints[j])
|
||||
|
||||
best_checkpoints = checkpoint_manager.best_checkpoints()
|
||||
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
|
||||
@@ -59,8 +59,9 @@ class CheckpointManagerTest(unittest.TestCase):
|
||||
checkpoint_manager.on_checkpoint(checkpoints[j])
|
||||
expected_deletes = 0 if j != 3 else 1
|
||||
self.assertEqual(delete_mock.call_count, expected_deletes)
|
||||
self.assertEqual(checkpoint_manager.newest_checkpoint,
|
||||
checkpoints[j])
|
||||
self.assertEqual(
|
||||
checkpoint_manager.newest_persistent_checkpoint,
|
||||
checkpoints[j])
|
||||
|
||||
best_checkpoints = checkpoint_manager.best_checkpoints()
|
||||
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
|
||||
@@ -74,7 +75,7 @@ class CheckpointManagerTest(unittest.TestCase):
|
||||
keep_checkpoints_num = 4
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
|
||||
checkpoints = [
|
||||
Checkpoint(Checkpoint.MEMORY, i, self.mock_result(i))
|
||||
Checkpoint(Checkpoint.PERSISTENT, i, self.mock_result(i))
|
||||
for i in range(16)
|
||||
]
|
||||
random.shuffle(checkpoints)
|
||||
@@ -92,15 +93,28 @@ class CheckpointManagerTest(unittest.TestCase):
|
||||
Tests that an error is logged when the associated result of the
|
||||
checkpoint has no checkpoint score attribute.
|
||||
"""
|
||||
keep_checkpoints_num = 1
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
|
||||
|
||||
no_attr_checkpoint = Checkpoint(Checkpoint.MEMORY, 0, {})
|
||||
no_attr_checkpoint = Checkpoint(Checkpoint.PERSISTENT, 0, {})
|
||||
with patch.object(logger, "error") as log_error_mock:
|
||||
checkpoint_manager.on_checkpoint(no_attr_checkpoint)
|
||||
log_error_mock.assert_called_once()
|
||||
# The newest checkpoint should still be set despite this error.
|
||||
assert checkpoint_manager.newest_checkpoint == no_attr_checkpoint
|
||||
self.assertEqual(checkpoint_manager.newest_persistent_checkpoint,
|
||||
no_attr_checkpoint)
|
||||
|
||||
def testOnMemoryCheckpoint(self):
|
||||
checkpoints = [
|
||||
Checkpoint(Checkpoint.MEMORY, 0, self.mock_result(0)),
|
||||
Checkpoint(Checkpoint.MEMORY, 0, self.mock_result(0))
|
||||
]
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
|
||||
checkpoint_manager.on_checkpoint(checkpoints[0])
|
||||
checkpoint_manager.on_checkpoint(checkpoints[1])
|
||||
newest = checkpoint_manager.newest_memory_checkpoint
|
||||
|
||||
self.assertEqual(newest, checkpoints[1])
|
||||
self.assertEqual(checkpoint_manager.best_checkpoints(), [])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -555,6 +555,8 @@ tune.run(
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
# TODO(ujvl): Fix test.
|
||||
@pytest.mark.skip(reason="Not very consistent.")
|
||||
def test_cluster_interrupt(start_connected_cluster, tmpdir):
|
||||
"""Tests run_experiment on cluster shutdown with actual interrupt.
|
||||
|
||||
|
||||
@@ -51,6 +51,27 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
def testSavePauseResumeRestore(self):
|
||||
"""Tests that pause checkpoint does not replace restore checkpoint."""
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
# Save
|
||||
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)
|
||||
# Pause
|
||||
self.trial_executor.pause_trial(trial)
|
||||
self.assertEqual(Trial.PAUSED, trial.status)
|
||||
self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY)
|
||||
# Resume
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.assertEqual(trial.checkpoint, checkpoint)
|
||||
# Restore
|
||||
self.trial_executor.restore(trial)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
def testStartFailure(self):
|
||||
_global_registry.register(TRAINABLE_CLASS, "asdf", None)
|
||||
trial = Trial("asdf", resources=Resources(1, 0))
|
||||
@@ -63,9 +84,9 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.fetch_result(trial)
|
||||
self.trial_executor.pause_trial(trial)
|
||||
checkpoint = self.trial_executor.pause_trial(trial)
|
||||
self.assertEqual(Trial.PAUSED, trial.status)
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.trial_executor.start_trial(trial, checkpoint)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
@@ -202,8 +202,8 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
|
||||
path = runner.trial_executor.save(trials[0])
|
||||
kwargs["restore_path"] = path
|
||||
checkpoint = runner.trial_executor.save(trials[0])
|
||||
kwargs["restore_path"] = checkpoint.value
|
||||
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
trials = runner.get_trials()
|
||||
@@ -216,7 +216,7 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||
self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
|
||||
self.addCleanup(os.remove, path)
|
||||
self.addCleanup(os.remove, checkpoint.value)
|
||||
|
||||
def testRestoreMetricsAfterCheckpointing(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
@@ -230,9 +230,9 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
|
||||
path = runner.trial_executor.save(trials[0])
|
||||
checkpoint = runner.trial_executor.save(trials[0])
|
||||
runner.trial_executor.stop_trial(trials[0])
|
||||
kwargs["restore_path"] = path
|
||||
kwargs["restore_path"] = checkpoint.value
|
||||
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
trials = runner.get_trials()
|
||||
@@ -249,7 +249,7 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20)
|
||||
self.assertEqual(trials[1].last_result["iterations_since_restore"], 2)
|
||||
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
|
||||
self.addCleanup(os.remove, path)
|
||||
self.addCleanup(os.remove, checkpoint.value)
|
||||
|
||||
def testCheckpointingAtEnd(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
|
||||
@@ -200,7 +200,7 @@ class _MockTrialExecutor(TrialExecutor):
|
||||
pass
|
||||
|
||||
def save(self, trial, type=Checkpoint.PERSISTENT, result=None):
|
||||
return trial.trainable_name
|
||||
return Checkpoint(Checkpoint.PERSISTENT, trial.trainable_name, result)
|
||||
|
||||
def reset_trial(self, trial, new_config, new_experiment_tag):
|
||||
return False
|
||||
|
||||
@@ -194,11 +194,11 @@ class Trial:
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.checkpoint_at_end = checkpoint_at_end
|
||||
self.sync_on_checkpoint = sync_on_checkpoint
|
||||
newest_checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
|
||||
self.checkpoint_manager = CheckpointManager(
|
||||
keep_checkpoints_num, checkpoint_score_attr,
|
||||
checkpoint_deleter(str(self), self.runner))
|
||||
self.checkpoint_manager.newest_checkpoint = newest_checkpoint
|
||||
checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
|
||||
self.checkpoint_manager.newest_persistent_checkpoint = checkpoint
|
||||
|
||||
# Restoration fields
|
||||
self.restoring_from = None
|
||||
@@ -228,7 +228,16 @@ class Trial:
|
||||
|
||||
@property
|
||||
def checkpoint(self):
|
||||
return self.checkpoint_manager.newest_checkpoint
|
||||
"""Returns the most recent checkpoint.
|
||||
|
||||
If the trial is PAUSED, this is the most recent MEMORY checkpoint.
|
||||
Otherwise, it is the most recent PERSISTENT checkpoint.
|
||||
"""
|
||||
if self.status == Trial.PAUSED:
|
||||
assert self.checkpoint_manager.newest_memory_checkpoint.value
|
||||
return self.checkpoint_manager.newest_memory_checkpoint
|
||||
else:
|
||||
return self.checkpoint_manager.newest_persistent_checkpoint
|
||||
|
||||
@classmethod
|
||||
def generate_id(cls):
|
||||
@@ -351,7 +360,6 @@ class Trial:
|
||||
checkpoint (Checkpoint): Checkpoint taken.
|
||||
"""
|
||||
if checkpoint.storage == Checkpoint.MEMORY:
|
||||
# TODO(ujvl): Handle this separately to avoid restoration failure.
|
||||
self.checkpoint_manager.on_checkpoint(checkpoint)
|
||||
return
|
||||
if self.sync_on_checkpoint:
|
||||
@@ -364,7 +372,7 @@ class Trial:
|
||||
# checkpoint, so it should just be logged.
|
||||
logger.error(
|
||||
"Trial %s: An error occurred during the "
|
||||
"checkpoint pre-sync wait.", str(e))
|
||||
"checkpoint pre-sync wait - %s", self, str(e))
|
||||
# Force sync down and wait before tracking the new checkpoint.
|
||||
try:
|
||||
if self.result_logger.sync_down():
|
||||
|
||||
@@ -196,7 +196,7 @@ class TrialExecutor:
|
||||
|
||||
Assumes the trial is running.
|
||||
|
||||
Return:
|
||||
Returns:
|
||||
Result object for the trial.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -219,7 +219,7 @@ class TrialExecutor:
|
||||
trial (Trial): Trial to be restored.
|
||||
checkpoint (Checkpoint): Checkpoint to restore from.
|
||||
|
||||
Return:
|
||||
Returns:
|
||||
False if error occurred, otherwise return True.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
@@ -236,9 +236,8 @@ class TrialExecutor:
|
||||
PERSISTENT.
|
||||
result (dict): The state of this trial as a dictionary to be saved.
|
||||
|
||||
Return:
|
||||
A Python object if storage==Checkpoint.MEMORY otherwise
|
||||
a path to the checkpoint.
|
||||
Returns:
|
||||
A Checkpoint object.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"save() method")
|
||||
@@ -249,7 +248,7 @@ class TrialExecutor:
|
||||
Args:
|
||||
trial (Trial): The state of this trial to be saved.
|
||||
|
||||
Return:
|
||||
Returns:
|
||||
A dict that maps ExportFormats to successfully exported models.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
|
||||
Reference in New Issue
Block a user