[tune] Prevent deletion of checkpoint from user-initiated restoration (#7501)

* Fix restore bug

* Add test

* Lint

* Indent
This commit is contained in:
Ujval Misra
2020-03-09 15:53:10 -07:00
committed by GitHub
parent 08d4cb3822
commit 023d4c02a9
4 changed files with 34 additions and 23 deletions
+13 -18
View File
@@ -7,7 +7,6 @@ try:
except ImportError:
pd = None
from ray.tune.checkpoint_manager import Checkpoint
from ray.tune.error import TuneError
from ray.tune.result import EXPR_PROGRESS_FILE, EXPR_PARAM_FILE,\
CONFIG_PREFIX, TRAINING_ITERATION
@@ -43,7 +42,6 @@ class Analysis:
metric (str): Key for trial info to order on.
If None, uses last result.
mode (str): One of [min, max].
"""
rows = self._retrieve_rows(metric=metric, mode=mode)
all_configs = self.get_all_configs(prefix=True)
@@ -59,7 +57,6 @@ class Analysis:
Args:
metric (str): Key for trial info to order on.
mode (str): One of [min, max].
"""
rows = self._retrieve_rows(metric=metric, mode=mode)
all_configs = self.get_all_configs()
@@ -73,7 +70,6 @@ class Analysis:
Args:
metric (str): Key for trial info to order on.
mode (str): One of [min, max].
"""
df = self.dataframe(metric=metric, mode=mode)
if mode == "max":
@@ -98,7 +94,7 @@ class Analysis:
def get_all_configs(self, prefix=False):
"""Returns a list of all configurations.
Parameters:
Args:
prefix (bool): If True, flattens the config dict
and prepends `config/`.
"""
@@ -120,32 +116,30 @@ class Analysis:
return self._configs
def get_trial_checkpoints_paths(self, trial, metric=TRAINING_ITERATION):
"""Returns a list of [path, metric] lists for all disk checkpoints of
a trial.
"""Gets paths and metrics of all persistent checkpoints of a trial.
Arguments:
trial(Trial): The log directory of a trial, or a trial instance.
Args:
trial (Trial): The log directory of a trial, or a trial instance.
metric (str): key for trial info to return, e.g. "mean_accuracy".
"training_iteration" is used by default.
"""
Returns:
A list of [path, metric] lists for all persistent checkpoints of
the trial.
"""
if isinstance(trial, str):
trial_dir = os.path.expanduser(trial)
# get checkpoints from logdir
# Get checkpoints from logdir.
chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)
# join with trial dataframe to get metrics
# Join with trial dataframe to get metrics.
trial_df = self.trial_dataframes[trial_dir]
path_metric_df = chkpt_df.merge(
trial_df, on="training_iteration", how="inner")
return path_metric_df[["chkpt_path", metric]].values.tolist()
elif isinstance(trial, Trial):
checkpoints = trial.checkpoint_manager.best_checkpoints()
# TODO(ujvl): Remove condition once the checkpoint manager is
# modified to only track PERSISTENT checkpoints.
return [[c.value, c.result[metric]] for c in checkpoints
if c.storage == Checkpoint.PERSISTENT]
return [[c.value, c.result[metric]] for c in checkpoints]
else:
raise ValueError("trial should be a string or a Trial instance.")
@@ -198,7 +192,8 @@ class ExperimentAnalysis(Analysis):
"""Initializer.
Args:
experiment_path (str): Path to where experiment is located.
experiment_checkpoint_path (str): Path to where experiment is
located.
trials (list|None): List of trials that can be accessed via
`analysis.trials`.
"""
@@ -239,7 +239,6 @@ class TrialRunnerTest2(unittest.TestCase):
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
# checkpoint = runner.trial_executor.save(trials[0])
runner.step() # Process result, dispatch save
runner.step() # Process save
runner.trial_executor.stop_trial(trials[0])
@@ -58,6 +58,22 @@ class TuneRestoreTest(unittest.TestCase):
},
)
def testPostRestoreCheckpointExistence(self):
"""Tests that checkpoint restored from is not deleted post-restore."""
# Precondition: the checkpoint file to restore from must already exist.
# NOTE(review): self.checkpoint_path is presumably created in setUp, which
# is not visible in this hunk -- confirm.
self.assertTrue(os.path.isfile(self.checkpoint_path))
# Start a run restored from that checkpoint. It checkpoints every iteration
# (checkpoint_freq=1) while keeping only one checkpoint
# (keep_checkpoints_num=1) and stops at training_iteration 2.
# NOTE(review): presumably this makes the checkpoint manager evict older
# checkpoints during the run, which is the situation that used to delete
# the user-supplied restore checkpoint -- confirm against CheckpointManager.
tune.run(
"PG",
name="TuneRestoreTest",
stop={"training_iteration": 2},
checkpoint_freq=1,
keep_checkpoints_num=1,
restore=self.checkpoint_path,
config={
"env": "CartPole-v0",
},
)
# The checkpoint that was restored from must still be on disk after the run.
self.assertTrue(os.path.isfile(self.checkpoint_path))
class TuneExampleTest(unittest.TestCase):
def setUp(self):
+5 -4
View File
@@ -205,10 +205,9 @@ class Trial:
self.checkpoint_manager = CheckpointManager(
keep_checkpoints_num, checkpoint_score_attr,
checkpoint_deleter(self._trainable_name(), self.runner))
checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
self.checkpoint_manager.newest_persistent_checkpoint = checkpoint
# Restoration fields
self.restore_path = restore_path
self.restoring_from = None
self.num_failures = 0
@@ -243,8 +242,10 @@ class Trial:
if self.status == Trial.PAUSED:
assert self.checkpoint_manager.newest_memory_checkpoint.value
return self.checkpoint_manager.newest_memory_checkpoint
else:
return self.checkpoint_manager.newest_persistent_checkpoint
checkpoint = self.checkpoint_manager.newest_persistent_checkpoint
if checkpoint.value is None:
checkpoint = Checkpoint(Checkpoint.PERSISTENT, self.restore_path)
return checkpoint
@classmethod
def generate_id(cls):