mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 02:00:46 +08:00
[tune] get checkpoints paths for a trial after tuning (#6643)
This commit is contained in:
@@ -7,8 +7,12 @@ try:
|
||||
except ImportError:
|
||||
pd = None
|
||||
|
||||
from ray.tune.checkpoint_manager import Checkpoint
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import EXPR_PROGRESS_FILE, EXPR_PARAM_FILE, CONFIG_PREFIX
|
||||
from ray.tune.result import EXPR_PROGRESS_FILE, EXPR_PARAM_FILE,\
|
||||
CONFIG_PREFIX, TRAINING_ITERATION
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -115,6 +119,36 @@ class Analysis:
|
||||
"Couldn't read config from {} paths".format(fail_count))
|
||||
return self._configs
|
||||
|
||||
def get_trial_checkpoints_paths(self, trial, metric=TRAINING_ITERATION):
|
||||
"""Returns a list of [path, metric] lists for all disk checkpoints of
|
||||
a trial.
|
||||
|
||||
Arguments:
|
||||
trial(Trial): The log directory of a trial, or a trial instance.
|
||||
metric (str): key for trial info to return, e.g. "mean_accuracy".
|
||||
"training_iteration" is used by default.
|
||||
"""
|
||||
|
||||
if isinstance(trial, str):
|
||||
trial_dir = os.path.expanduser(trial)
|
||||
|
||||
# get checkpoints from logdir
|
||||
chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)
|
||||
|
||||
# join with trial dataframe to get metrics
|
||||
trial_df = self.trial_dataframes[trial_dir]
|
||||
path_metric_df = chkpt_df.merge(
|
||||
trial_df, on="training_iteration", how="inner")
|
||||
return path_metric_df[["chkpt_path", metric]].values.tolist()
|
||||
elif isinstance(trial, Trial):
|
||||
checkpoints = trial.checkpoint_manager.best_checkpoints()
|
||||
# TODO(ujvl): Remove condition once the checkpoint manager is
|
||||
# modified to only track PERSISTENT checkpoints.
|
||||
return [[c.value, c.result[metric]] for c in checkpoints
|
||||
if c.storage == Checkpoint.PERSISTENT]
|
||||
else:
|
||||
raise ValueError("trial should be a string or a Trial instance.")
|
||||
|
||||
def _retrieve_rows(self, metric=None, mode=None):
|
||||
assert mode is None or mode in ["max", "min"]
|
||||
rows = {}
|
||||
|
||||
@@ -118,9 +118,23 @@ if __name__ == "__main__":
|
||||
verbose=1,
|
||||
stop=stopper.stop,
|
||||
export_formats=[ExportFormat.MODEL],
|
||||
checkpoint_score_attr="mean_accuracy",
|
||||
checkpoint_freq=5,
|
||||
keep_checkpoints_num=4,
|
||||
num_samples=4,
|
||||
config={
|
||||
"lr": tune.uniform(0.001, 1),
|
||||
"momentum": tune.uniform(0.001, 1),
|
||||
})
|
||||
# __tune_end__
|
||||
|
||||
best_trial = analysis.get_best_trial("mean_accuracy")
|
||||
best_checkpoint = max(
|
||||
analysis.get_trial_checkpoints_paths(best_trial, "mean_accuracy"))
|
||||
restored_trainable = PytorchTrainble()
|
||||
restored_trainable.restore(best_checkpoint[0])
|
||||
best_model = restored_trainable.model
|
||||
# Note that test only runs on a small random set of the test data, thus the
|
||||
# accuracy may be different from metrics shown in tuning process.
|
||||
test_acc = test(best_model, get_data_loaders()[1])
|
||||
print("best model accuracy: ", test_acc)
|
||||
|
||||
@@ -30,6 +30,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
||||
name=self.test_name,
|
||||
local_dir=self.test_dir,
|
||||
stop={"training_iteration": 1},
|
||||
checkpoint_freq=1,
|
||||
num_samples=self.num_samples,
|
||||
config={
|
||||
"width": sample_from(
|
||||
@@ -69,6 +70,37 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
||||
self.assertTrue(logdir2.startswith(self.test_path))
|
||||
self.assertNotEquals(logdir, logdir2)
|
||||
|
||||
def testGetTrialCheckpointsPathsByTrial(self):
|
||||
best_trial = self.ea.get_best_trial(self.metric)
|
||||
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(best_trial)
|
||||
logdir = self.ea.get_best_logdir(self.metric)
|
||||
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
|
||||
assert checkpoints_metrics[0][0] == expected_path
|
||||
assert checkpoints_metrics[0][1] == 1
|
||||
|
||||
def testGetTrialCheckpointsPathsByPath(self):
|
||||
logdir = self.ea.get_best_logdir(self.metric)
|
||||
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(logdir)
|
||||
expected_path = os.path.join(logdir, "checkpoint_1/", "checkpoint")
|
||||
assert checkpoints_metrics[0][0] == expected_path
|
||||
assert checkpoints_metrics[0][1] == 1
|
||||
|
||||
def testGetTrialCheckpointsPathsWithMetricByTrial(self):
|
||||
best_trial = self.ea.get_best_trial(self.metric)
|
||||
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
|
||||
logdir = self.ea.get_best_logdir(self.metric)
|
||||
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
|
||||
assert paths[0][0] == expected_path
|
||||
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]
|
||||
|
||||
def testGetTrialCheckpointsPathsWithMetricByPath(self):
|
||||
best_trial = self.ea.get_best_trial(self.metric)
|
||||
logdir = self.ea.get_best_logdir(self.metric)
|
||||
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
|
||||
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
|
||||
assert paths[0][0] == expected_path
|
||||
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]
|
||||
|
||||
def testAllDataframes(self):
|
||||
dataframes = self.ea.trial_dataframes
|
||||
self.assertTrue(len(dataframes) == self.num_samples)
|
||||
|
||||
@@ -3,8 +3,10 @@ from datetime import datetime
|
||||
import copy
|
||||
import io
|
||||
import logging
|
||||
import glob
|
||||
import os
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from six import string_types
|
||||
import shutil
|
||||
import tempfile
|
||||
@@ -73,6 +75,36 @@ class TrainableUtil:
|
||||
# Drop marker in directory to identify it as a checkpoint dir.
|
||||
open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()
|
||||
|
||||
@staticmethod
|
||||
def get_checkpoints_paths(logdir):
|
||||
""" Finds the checkpoints within a specific folder.
|
||||
|
||||
Returns a pandas DataFrame of training iterations and checkpoint
|
||||
paths within a specific folder.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError if the directory is not found.
|
||||
"""
|
||||
marker_paths = glob.glob(
|
||||
os.path.join(logdir, "checkpoint_*/.is_checkpoint"))
|
||||
iter_chkpt_pairs = []
|
||||
for marker_path in marker_paths:
|
||||
chkpt_dir = os.path.dirname(marker_path)
|
||||
metadata_file = glob.glob(
|
||||
os.path.join(chkpt_dir, "*.tune_metadata"))
|
||||
if len(metadata_file) != 1:
|
||||
raise ValueError(
|
||||
"{} has zero or more than one tune_metadata.".format(
|
||||
chkpt_dir))
|
||||
|
||||
chkpt_path = metadata_file[0][:-len(".tune_metadata")]
|
||||
chkpt_iter = int(chkpt_dir[chkpt_dir.rfind("_") + 1:])
|
||||
iter_chkpt_pairs.append([chkpt_iter, chkpt_path])
|
||||
|
||||
chkpt_df = pd.DataFrame(
|
||||
iter_chkpt_pairs, columns=["training_iteration", "chkpt_path"])
|
||||
return chkpt_df
|
||||
|
||||
|
||||
class Trainable:
|
||||
"""Abstract class for trainable models, functions, etc.
|
||||
|
||||
Reference in New Issue
Block a user