mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 07:41:09 +08:00
[tune] Prevent deletion of checkpoint from user-initiated resto… (#7501)
* Fix restore bug * Add test * Lint * Indent
This commit is contained in:
@@ -7,7 +7,6 @@ try:
|
||||
except ImportError:
|
||||
pd = None
|
||||
|
||||
from ray.tune.checkpoint_manager import Checkpoint
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import EXPR_PROGRESS_FILE, EXPR_PARAM_FILE,\
|
||||
CONFIG_PREFIX, TRAINING_ITERATION
|
||||
@@ -43,7 +42,6 @@ class Analysis:
|
||||
metric (str): Key for trial info to order on.
|
||||
If None, uses last result.
|
||||
mode (str): One of [min, max].
|
||||
|
||||
"""
|
||||
rows = self._retrieve_rows(metric=metric, mode=mode)
|
||||
all_configs = self.get_all_configs(prefix=True)
|
||||
@@ -59,7 +57,6 @@ class Analysis:
|
||||
Args:
|
||||
metric (str): Key for trial info to order on.
|
||||
mode (str): One of [min, max].
|
||||
|
||||
"""
|
||||
rows = self._retrieve_rows(metric=metric, mode=mode)
|
||||
all_configs = self.get_all_configs()
|
||||
@@ -73,7 +70,6 @@ class Analysis:
|
||||
Args:
|
||||
metric (str): Key for trial info to order on.
|
||||
mode (str): One of [min, max].
|
||||
|
||||
"""
|
||||
df = self.dataframe(metric=metric, mode=mode)
|
||||
if mode == "max":
|
||||
@@ -98,7 +94,7 @@ class Analysis:
|
||||
def get_all_configs(self, prefix=False):
|
||||
"""Returns a list of all configurations.
|
||||
|
||||
Parameters:
|
||||
Args:
|
||||
prefix (bool): If True, flattens the config dict
|
||||
and prepends `config/`.
|
||||
"""
|
||||
@@ -120,32 +116,30 @@ class Analysis:
|
||||
return self._configs
|
||||
|
||||
def get_trial_checkpoints_paths(self, trial, metric=TRAINING_ITERATION):
|
||||
"""Returns a list of [path, metric] lists for all disk checkpoints of
|
||||
a trial.
|
||||
"""Gets paths and metrics of all persistent checkpoints of a trial.
|
||||
|
||||
Arguments:
|
||||
trial(Trial): The log directory of a trial, or a trial instance.
|
||||
Args:
|
||||
trial (Trial): The log directory of a trial, or a trial instance.
|
||||
metric (str): key for trial info to return, e.g. "mean_accuracy".
|
||||
"training_iteration" is used by default.
|
||||
"""
|
||||
|
||||
Returns:
|
||||
A list of [path, metric] lists for all persistent checkpoints of
|
||||
the trial.
|
||||
"""
|
||||
if isinstance(trial, str):
|
||||
trial_dir = os.path.expanduser(trial)
|
||||
|
||||
# get checkpoints from logdir
|
||||
# Get checkpoints from logdir.
|
||||
chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)
|
||||
|
||||
# join with trial dataframe to get metrics
|
||||
# Join with trial dataframe to get metrics.
|
||||
trial_df = self.trial_dataframes[trial_dir]
|
||||
path_metric_df = chkpt_df.merge(
|
||||
trial_df, on="training_iteration", how="inner")
|
||||
return path_metric_df[["chkpt_path", metric]].values.tolist()
|
||||
elif isinstance(trial, Trial):
|
||||
checkpoints = trial.checkpoint_manager.best_checkpoints()
|
||||
# TODO(ujvl): Remove condition once the checkpoint manager is
|
||||
# modified to only track PERSISTENT checkpoints.
|
||||
return [[c.value, c.result[metric]] for c in checkpoints
|
||||
if c.storage == Checkpoint.PERSISTENT]
|
||||
return [[c.value, c.result[metric]] for c in checkpoints]
|
||||
else:
|
||||
raise ValueError("trial should be a string or a Trial instance.")
|
||||
|
||||
@@ -198,7 +192,8 @@ class ExperimentAnalysis(Analysis):
|
||||
"""Initializer.
|
||||
|
||||
Args:
|
||||
experiment_path (str): Path to where experiment is located.
|
||||
experiment_checkpoint_path (str): Path to where experiment is
|
||||
located.
|
||||
trials (list|None): List of trials that can be accessed via
|
||||
`analysis.trials`.
|
||||
"""
|
||||
|
||||
@@ -239,7 +239,6 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
runner.step() # Start trial
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
|
||||
# checkpoint = runner.trial_executor.save(trials[0])
|
||||
runner.step() # Process result, dispatch save
|
||||
runner.step() # Process save
|
||||
runner.trial_executor.stop_trial(trials[0])
|
||||
|
||||
@@ -58,6 +58,22 @@ class TuneRestoreTest(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def testPostRestoreCheckpointExistence(self):
|
||||
"""Tests that checkpoint restored from is not deleted post-restore."""
|
||||
self.assertTrue(os.path.isfile(self.checkpoint_path))
|
||||
tune.run(
|
||||
"PG",
|
||||
name="TuneRestoreTest",
|
||||
stop={"training_iteration": 2},
|
||||
checkpoint_freq=1,
|
||||
keep_checkpoints_num=1,
|
||||
restore=self.checkpoint_path,
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
},
|
||||
)
|
||||
self.assertTrue(os.path.isfile(self.checkpoint_path))
|
||||
|
||||
|
||||
class TuneExampleTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
@@ -205,10 +205,9 @@ class Trial:
|
||||
self.checkpoint_manager = CheckpointManager(
|
||||
keep_checkpoints_num, checkpoint_score_attr,
|
||||
checkpoint_deleter(self._trainable_name(), self.runner))
|
||||
checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
|
||||
self.checkpoint_manager.newest_persistent_checkpoint = checkpoint
|
||||
|
||||
# Restoration fields
|
||||
self.restore_path = restore_path
|
||||
self.restoring_from = None
|
||||
self.num_failures = 0
|
||||
|
||||
@@ -243,8 +242,10 @@ class Trial:
|
||||
if self.status == Trial.PAUSED:
|
||||
assert self.checkpoint_manager.newest_memory_checkpoint.value
|
||||
return self.checkpoint_manager.newest_memory_checkpoint
|
||||
else:
|
||||
return self.checkpoint_manager.newest_persistent_checkpoint
|
||||
checkpoint = self.checkpoint_manager.newest_persistent_checkpoint
|
||||
if checkpoint.value is None:
|
||||
checkpoint = Checkpoint(Checkpoint.PERSISTENT, self.restore_path)
|
||||
return checkpoint
|
||||
|
||||
@classmethod
|
||||
def generate_id(cls):
|
||||
|
||||
Reference in New Issue
Block a user