diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index 2621c1415..854ec5307 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -7,7 +7,6 @@ try: except ImportError: pd = None -from ray.tune.checkpoint_manager import Checkpoint from ray.tune.error import TuneError from ray.tune.result import EXPR_PROGRESS_FILE, EXPR_PARAM_FILE,\ CONFIG_PREFIX, TRAINING_ITERATION @@ -43,7 +42,6 @@ class Analysis: metric (str): Key for trial info to order on. If None, uses last result. mode (str): One of [min, max]. - """ rows = self._retrieve_rows(metric=metric, mode=mode) all_configs = self.get_all_configs(prefix=True) @@ -59,7 +57,6 @@ class Analysis: Args: metric (str): Key for trial info to order on. mode (str): One of [min, max]. - """ rows = self._retrieve_rows(metric=metric, mode=mode) all_configs = self.get_all_configs() @@ -73,7 +70,6 @@ class Analysis: Args: metric (str): Key for trial info to order on. mode (str): One of [min, max]. - """ df = self.dataframe(metric=metric, mode=mode) if mode == "max": @@ -98,7 +94,7 @@ class Analysis: def get_all_configs(self, prefix=False): """Returns a list of all configurations. - Parameters: + Args: prefix (bool): If True, flattens the config dict and prepends `config/`. """ @@ -120,32 +116,30 @@ class Analysis: return self._configs def get_trial_checkpoints_paths(self, trial, metric=TRAINING_ITERATION): - """Returns a list of [path, metric] lists for all disk checkpoints of - a trial. + """Gets paths and metrics of all persistent checkpoints of a trial. - Arguments: - trial(Trial): The log directory of a trial, or a trial instance. + Args: + trial (Trial): The log directory of a trial, or a trial instance. metric (str): key for trial info to return, e.g. "mean_accuracy". "training_iteration" is used by default. - """ + Returns: + A list of [path, metric] lists for all persistent checkpoints of + the trial. + """ if isinstance(trial, str): trial_dir = os.path.expanduser(trial) - - # get checkpoints from logdir + # Get checkpoints from logdir. chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir) - # join with trial dataframe to get metrics + # Join with trial dataframe to get metrics. trial_df = self.trial_dataframes[trial_dir] path_metric_df = chkpt_df.merge( trial_df, on="training_iteration", how="inner") return path_metric_df[["chkpt_path", metric]].values.tolist() elif isinstance(trial, Trial): checkpoints = trial.checkpoint_manager.best_checkpoints() - # TODO(ujvl): Remove condition once the checkpoint manager is - # modified to only track PERSISTENT checkpoints. - return [[c.value, c.result[metric]] for c in checkpoints - if c.storage == Checkpoint.PERSISTENT] + return [[c.value, c.result[metric]] for c in checkpoints] else: raise ValueError("trial should be a string or a Trial instance.") @@ -198,7 +192,8 @@ class ExperimentAnalysis(Analysis): """Initializer. Args: - experiment_path (str): Path to where experiment is located. + experiment_checkpoint_path (str): Path to where experiment is + located. trials (list|None): List of trials that can be accessed via `analysis.trials`. """ diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index e02b20eea..11e04457d 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -239,7 +239,6 @@ class TrialRunnerTest2(unittest.TestCase): runner.step() # Start trial self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) - # checkpoint = runner.trial_executor.save(trials[0]) runner.step() # Process result, dispatch save runner.step() # Process save runner.trial_executor.stop_trial(trials[0]) diff --git a/python/ray/tune/tests/test_tune_restore.py b/python/ray/tune/tests/test_tune_restore.py index 726a82e98..257a20897 100644 --- a/python/ray/tune/tests/test_tune_restore.py +++ b/python/ray/tune/tests/test_tune_restore.py @@ -58,6 +58,22 @@ class TuneRestoreTest(unittest.TestCase): }, ) + def testPostRestoreCheckpointExistence(self): + """Tests that checkpoint restored from is not deleted post-restore.""" + self.assertTrue(os.path.isfile(self.checkpoint_path)) + tune.run( + "PG", + name="TuneRestoreTest", + stop={"training_iteration": 2}, + checkpoint_freq=1, + keep_checkpoints_num=1, + restore=self.checkpoint_path, + config={ + "env": "CartPole-v0", + }, + ) + self.assertTrue(os.path.isfile(self.checkpoint_path)) + class TuneExampleTest(unittest.TestCase): def setUp(self): diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 575e49cc0..63a11001d 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -205,10 +205,9 @@ class Trial: self.checkpoint_manager = CheckpointManager( keep_checkpoints_num, checkpoint_score_attr, checkpoint_deleter(self._trainable_name(), self.runner)) - checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path) - self.checkpoint_manager.newest_persistent_checkpoint = checkpoint # Restoration fields + self.restore_path = restore_path self.restoring_from = None self.num_failures = 0 @@ -243,8 +242,10 @@ class Trial: if self.status == Trial.PAUSED: assert self.checkpoint_manager.newest_memory_checkpoint.value return self.checkpoint_manager.newest_memory_checkpoint - else: - return self.checkpoint_manager.newest_persistent_checkpoint + checkpoint = self.checkpoint_manager.newest_persistent_checkpoint + if checkpoint.value is None: + checkpoint = Checkpoint(Checkpoint.PERSISTENT, self.restore_path) + return checkpoint @classmethod def generate_id(cls):