[tune] ExperimentalAnalysis in-memory cache (#5962)

This commit is contained in:
Hersh Godse
2019-11-15 12:47:50 -08:00
committed by Richard Liaw
parent 7d33e9949b
commit 7aa06fb25c
3 changed files with 161 additions and 2 deletions
@@ -183,6 +183,89 @@ class ExperimentAnalysis(Analysis):
super(ExperimentAnalysis, self).__init__(
os.path.dirname(experiment_checkpoint_path))
def get_best_trial(self, metric, mode="max", scope="all"):
"""Retrieve the best trial object.
Compares all trials' scores on `metric`.
Args:
metric (str): Key for trial info to order on.
mode (str): One of [min, max].
scope (str): One of [all, last]. If `scope=last`, only look at
each trial's final step for `metric`, and compare across
trials based on `mode=[min,max]`. If `scope=all`, find each
trial's min/max score for `metric` based on `mode`, and
compare trials based on `mode=[min,max]`.
"""
if mode not in ["max", "min"]:
raise ValueError(
"ExperimentAnalysis: attempting to get best trial for "
"metric {} for mode {} not in [\"max\", \"min\"]".format(
metric, mode))
if scope not in ["all", "last"]:
raise ValueError(
"ExperimentAnalysis: attempting to get best trial for "
"metric {} for scope {} not in [\"all\", \"last\"]".format(
metric, scope))
best_trial = None
best_metric_score = None
for trial in self.trials:
if metric not in trial.metric_analysis:
continue
if scope == "last":
metric_score = trial.metric_analysis[metric]["last"]
else:
metric_score = trial.metric_analysis[metric][mode]
if best_metric_score is None:
best_metric_score = metric_score
best_trial = trial
continue
if (mode == "max") and (best_metric_score < metric_score):
best_metric_score = metric_score
best_trial = trial
elif (mode == "min") and (best_metric_score > metric_score):
best_metric_score = metric_score
best_trial = trial
return best_trial
def get_best_config(self, metric, mode="max", scope="all"):
"""Retrieve the best config corresponding to the trial.
Compares all trials' scores on `metric`.
Args:
metric (str): Key for trial info to order on.
mode (str): One of [min, max].
scope (str): One of [all, last]. If `scope=last`, only look at
each trial's final step for `metric`, and compare across
trials based on `mode=[min,max]`. If `scope=all`, find each
trial's min/max score for `metric` based on `mode`, and
compare trials based on `mode=[min,max]`.
"""
best_trial = self.get_best_trial(metric, mode, scope)
return best_trial.config if best_trial else None
def get_best_logdir(self, metric, mode="max", scope="all"):
"""Retrieve the logdir corresponding to the best trial.
Compares all trials' scores on `metric`.
Args:
metric (str): Key for trial info to order on.
mode (str): One of [min, max].
scope (str): One of [all, last]. If `scope=last`, only look at
each trial's final step for `metric`, and compare across
trials based on `mode=[min,max]`. If `scope=all`, find each
trial's min/max score for `metric` based on `mode`, and
compare trials based on `mode=[min,max]`.
"""
best_trial = self.get_best_trial(metric, mode, scope)
return best_trial.logdir if best_trial else None
def stats(self):
"""Returns a dictionary of the statistics of the experiment."""
return self._experiment_state.get("stats")
@@ -10,13 +10,70 @@ import os
import pandas as pd
import ray
from ray.tune import run, sample_from, Analysis
from ray.tune import run, Trainable, sample_from, Analysis, grid_search
from ray.tune.examples.async_hyperband_example import MyTrainableClass
class ExperimentAnalysisInMemorySuite(unittest.TestCase):
def setUp(self):
class MockTrainable(Trainable):
def _setup(self, config):
self.id = config["id"]
self.idx = 0
self.scores_dict = {
0: [5, 0],
1: [4, 1],
2: [2, 8],
3: [9, 6],
4: [7, 3]
}
def _train(self):
val = self.scores_dict[self.id][self.idx]
self.idx += 1
return {"score": val}
def _save(self, checkpoint_dir):
pass
def _restore(self, checkpoint_path):
pass
self.MockTrainable = MockTrainable
ray.init(local_mode=False, num_cpus=1)
def tearDown(self):
shutil.rmtree(self.test_dir, ignore_errors=True)
ray.shutdown()
def testCompareTrials(self):
self.test_dir = tempfile.mkdtemp()
scores_all = [5, 4, 2, 9, 7, 0, 1, 8, 6, 3]
scores_last = scores_all[5:]
ea = run(
self.MockTrainable,
name="analysis_exp",
local_dir=self.test_dir,
stop={"training_iteration": 2},
num_samples=1,
config={"id": grid_search(list(range(5)))})
max_all = ea.get_best_trial("score",
"max").metric_analysis["score"]["max"]
min_all = ea.get_best_trial("score",
"min").metric_analysis["score"]["min"]
max_last = ea.get_best_trial("score", "max",
"last").metric_analysis["score"]["last"]
self.assertEqual(max_all, max(scores_all))
self.assertEqual(min_all, min(scores_all))
self.assertEqual(max_last, max(scores_last))
self.assertNotEqual(max_last, max(scores_all))
class ExperimentAnalysisSuite(unittest.TestCase):
def setUp(self):
ray.init(local_mode=True)
ray.init(local_mode=False)
self.test_dir = tempfile.mkdtemp()
self.test_name = "analysis_exp"
self.num_samples = 10
+19
View File
@@ -10,8 +10,10 @@ import uuid
import time
import tempfile
import os
from numbers import Number
from ray.tune import TuneError
from ray.tune.logger import pretty_print, UnifiedLogger
from ray.tune.util import flatten_dict
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
# need because there are cyclic imports that may cause specific names to not
# have been defined yet. See https://github.com/ray-project/ray/issues/1716.
@@ -156,6 +158,9 @@ class Trial(object):
self.checkpoint_freq = checkpoint_freq
self.checkpoint_at_end = checkpoint_at_end
# stores in memory max/min/last result for each metric by trial
self.metric_analysis = {}
self.history = []
self.keep_checkpoints_num = keep_checkpoints_num
self._cmp_greater = not checkpoint_score_attr.startswith("min-")
@@ -325,6 +330,20 @@ class Trial(object):
self.last_result = result
self.last_update_time = time.time()
self.result_logger.on_result(self.last_result)
for metric, value in flatten_dict(result).items():
if isinstance(value, Number):
if metric not in self.metric_analysis:
self.metric_analysis[metric] = {
"max": value,
"min": value,
"last": value
}
else:
self.metric_analysis[metric]["max"] = max(
value, self.metric_analysis[metric]["max"])
self.metric_analysis[metric]["min"] = min(
value, self.metric_analysis[metric]["min"])
self.metric_analysis[metric]["last"] = value
def compare_checkpoints(self, attr_mean):
"""Compares two checkpoints based on the attribute attr_mean param.