mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 09:45:24 +08:00
[tune] ExperimentalAnalysis in-memory cache (#5962)
This commit is contained in:
committed by
Richard Liaw
parent
7d33e9949b
commit
7aa06fb25c
@@ -183,6 +183,89 @@ class ExperimentAnalysis(Analysis):
|
||||
super(ExperimentAnalysis, self).__init__(
|
||||
os.path.dirname(experiment_checkpoint_path))
|
||||
|
||||
def get_best_trial(self, metric, mode="max", scope="all"):
|
||||
"""Retrieve the best trial object.
|
||||
|
||||
Compares all trials' scores on `metric`.
|
||||
|
||||
Args:
|
||||
metric (str): Key for trial info to order on.
|
||||
mode (str): One of [min, max].
|
||||
scope (str): One of [all, last]. If `scope=last`, only look at
|
||||
each trial's final step for `metric`, and compare across
|
||||
trials based on `mode=[min,max]`. If `scope=all`, find each
|
||||
trial's min/max score for `metric` based on `mode`, and
|
||||
compare trials based on `mode=[min,max]`.
|
||||
"""
|
||||
if mode not in ["max", "min"]:
|
||||
raise ValueError(
|
||||
"ExperimentAnalysis: attempting to get best trial for "
|
||||
"metric {} for mode {} not in [\"max\", \"min\"]".format(
|
||||
metric, mode))
|
||||
if scope not in ["all", "last"]:
|
||||
raise ValueError(
|
||||
"ExperimentAnalysis: attempting to get best trial for "
|
||||
"metric {} for scope {} not in [\"all\", \"last\"]".format(
|
||||
metric, scope))
|
||||
best_trial = None
|
||||
best_metric_score = None
|
||||
for trial in self.trials:
|
||||
if metric not in trial.metric_analysis:
|
||||
continue
|
||||
|
||||
if scope == "last":
|
||||
metric_score = trial.metric_analysis[metric]["last"]
|
||||
else:
|
||||
metric_score = trial.metric_analysis[metric][mode]
|
||||
|
||||
if best_metric_score is None:
|
||||
best_metric_score = metric_score
|
||||
best_trial = trial
|
||||
continue
|
||||
|
||||
if (mode == "max") and (best_metric_score < metric_score):
|
||||
best_metric_score = metric_score
|
||||
best_trial = trial
|
||||
elif (mode == "min") and (best_metric_score > metric_score):
|
||||
best_metric_score = metric_score
|
||||
best_trial = trial
|
||||
|
||||
return best_trial
|
||||
|
||||
def get_best_config(self, metric, mode="max", scope="all"):
|
||||
"""Retrieve the best config corresponding to the trial.
|
||||
|
||||
Compares all trials' scores on `metric`.
|
||||
|
||||
Args:
|
||||
metric (str): Key for trial info to order on.
|
||||
mode (str): One of [min, max].
|
||||
scope (str): One of [all, last]. If `scope=last`, only look at
|
||||
each trial's final step for `metric`, and compare across
|
||||
trials based on `mode=[min,max]`. If `scope=all`, find each
|
||||
trial's min/max score for `metric` based on `mode`, and
|
||||
compare trials based on `mode=[min,max]`.
|
||||
"""
|
||||
best_trial = self.get_best_trial(metric, mode, scope)
|
||||
return best_trial.config if best_trial else None
|
||||
|
||||
def get_best_logdir(self, metric, mode="max", scope="all"):
|
||||
"""Retrieve the logdir corresponding to the best trial.
|
||||
|
||||
Compares all trials' scores on `metric`.
|
||||
|
||||
Args:
|
||||
metric (str): Key for trial info to order on.
|
||||
mode (str): One of [min, max].
|
||||
scope (str): One of [all, last]. If `scope=last`, only look at
|
||||
each trial's final step for `metric`, and compare across
|
||||
trials based on `mode=[min,max]`. If `scope=all`, find each
|
||||
trial's min/max score for `metric` based on `mode`, and
|
||||
compare trials based on `mode=[min,max]`.
|
||||
"""
|
||||
best_trial = self.get_best_trial(metric, mode, scope)
|
||||
return best_trial.logdir if best_trial else None
|
||||
|
||||
def stats(self):
|
||||
"""Returns a dictionary of the statistics of the experiment."""
|
||||
return self._experiment_state.get("stats")
|
||||
|
||||
@@ -10,13 +10,70 @@ import os
|
||||
import pandas as pd
|
||||
|
||||
import ray
|
||||
from ray.tune import run, sample_from, Analysis
|
||||
from ray.tune import run, Trainable, sample_from, Analysis, grid_search
|
||||
from ray.tune.examples.async_hyperband_example import MyTrainableClass
|
||||
|
||||
|
||||
class ExperimentAnalysisInMemorySuite(unittest.TestCase):
|
||||
def setUp(self):
|
||||
class MockTrainable(Trainable):
|
||||
def _setup(self, config):
|
||||
self.id = config["id"]
|
||||
self.idx = 0
|
||||
self.scores_dict = {
|
||||
0: [5, 0],
|
||||
1: [4, 1],
|
||||
2: [2, 8],
|
||||
3: [9, 6],
|
||||
4: [7, 3]
|
||||
}
|
||||
|
||||
def _train(self):
|
||||
val = self.scores_dict[self.id][self.idx]
|
||||
self.idx += 1
|
||||
return {"score": val}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
pass
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
pass
|
||||
|
||||
self.MockTrainable = MockTrainable
|
||||
ray.init(local_mode=False, num_cpus=1)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir, ignore_errors=True)
|
||||
ray.shutdown()
|
||||
|
||||
def testCompareTrials(self):
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
scores_all = [5, 4, 2, 9, 7, 0, 1, 8, 6, 3]
|
||||
scores_last = scores_all[5:]
|
||||
|
||||
ea = run(
|
||||
self.MockTrainable,
|
||||
name="analysis_exp",
|
||||
local_dir=self.test_dir,
|
||||
stop={"training_iteration": 2},
|
||||
num_samples=1,
|
||||
config={"id": grid_search(list(range(5)))})
|
||||
|
||||
max_all = ea.get_best_trial("score",
|
||||
"max").metric_analysis["score"]["max"]
|
||||
min_all = ea.get_best_trial("score",
|
||||
"min").metric_analysis["score"]["min"]
|
||||
max_last = ea.get_best_trial("score", "max",
|
||||
"last").metric_analysis["score"]["last"]
|
||||
self.assertEqual(max_all, max(scores_all))
|
||||
self.assertEqual(min_all, min(scores_all))
|
||||
self.assertEqual(max_last, max(scores_last))
|
||||
self.assertNotEqual(max_last, max(scores_all))
|
||||
|
||||
|
||||
class ExperimentAnalysisSuite(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(local_mode=True)
|
||||
ray.init(local_mode=False)
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
self.test_name = "analysis_exp"
|
||||
self.num_samples = 10
|
||||
|
||||
@@ -10,8 +10,10 @@ import uuid
|
||||
import time
|
||||
import tempfile
|
||||
import os
|
||||
from numbers import Number
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import pretty_print, UnifiedLogger
|
||||
from ray.tune.util import flatten_dict
|
||||
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
|
||||
# need because there are cyclic imports that may cause specific names to not
|
||||
# have been defined yet. See https://github.com/ray-project/ray/issues/1716.
|
||||
@@ -156,6 +158,9 @@ class Trial(object):
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.checkpoint_at_end = checkpoint_at_end
|
||||
|
||||
# stores in memory max/min/last result for each metric by trial
|
||||
self.metric_analysis = {}
|
||||
|
||||
self.history = []
|
||||
self.keep_checkpoints_num = keep_checkpoints_num
|
||||
self._cmp_greater = not checkpoint_score_attr.startswith("min-")
|
||||
@@ -325,6 +330,20 @@ class Trial(object):
|
||||
self.last_result = result
|
||||
self.last_update_time = time.time()
|
||||
self.result_logger.on_result(self.last_result)
|
||||
for metric, value in flatten_dict(result).items():
|
||||
if isinstance(value, Number):
|
||||
if metric not in self.metric_analysis:
|
||||
self.metric_analysis[metric] = {
|
||||
"max": value,
|
||||
"min": value,
|
||||
"last": value
|
||||
}
|
||||
else:
|
||||
self.metric_analysis[metric]["max"] = max(
|
||||
value, self.metric_analysis[metric]["max"])
|
||||
self.metric_analysis[metric]["min"] = min(
|
||||
value, self.metric_analysis[metric]["min"])
|
||||
self.metric_analysis[metric]["last"] = value
|
||||
|
||||
def compare_checkpoints(self, attr_mean):
|
||||
"""Compares two checkpoints based on the attribute attr_mean param.
|
||||
|
||||
Reference in New Issue
Block a user