[tune] auto infer metrics (#10663)

Co-authored-by: Kai Fricke <krfricke@users.noreply.github.com>
Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
Richard Liaw
2020-09-09 09:53:47 -07:00
committed by GitHub
parent 3501ea396c
commit 153813936b
5 changed files with 107 additions and 6 deletions
+1 -2
View File
@@ -66,8 +66,7 @@ def tune_mnist_mxnet(num_samples=10, num_epochs=10):
reduction_factor=2)
reporter = CLIReporter(
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
metric_columns=["loss", "mean_accuracy", "training_iteration"])
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"])
tune.run(
partial(train_mnist_mxnet, mnist=mnist_data, num_epochs=num_epochs),
+38 -2
View File
@@ -1,10 +1,12 @@
from __future__ import print_function
import collections
import numpy as np
import time
from ray.tune.result import (EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
AUTO_RESULT_KEYS)
from ray.tune.utils import unflattened_lookup
try:
@@ -51,6 +53,10 @@ class ProgressReporter:
class TuneReporterBase(ProgressReporter):
"""Abstract base class for the default Tune reporters.
If metric_columns is not overridden, Tune will attempt to automatically
infer the metrics being outputted, up to 'infer_limit' number of
metrics.
Args:
metric_columns (dict[str, str]|list[str]): Names of metrics to
include in progress table. If this is a dict, the keys should
@@ -80,17 +86,25 @@ class TuneReporterBase(ProgressReporter):
TIMESTEPS_TOTAL: "ts",
EPISODE_REWARD_MEAN: "reward",
})
# Value types that qualify a result entry for automatic metric inference.
# _infer_user_metrics tests membership with an exact type() match, so
# values of any other type (e.g. str, or bool even though it subclasses
# int) are not picked up as inferred columns.
VALID_SUMMARY_TYPES = {
int, float, np.float32, np.float64, np.int32, np.int64,
type(None)
}
def __init__(self,
metric_columns=None,
parameter_columns=None,
max_progress_rows=20,
max_error_rows=20,
max_report_frequency=5):
max_report_frequency=5,
infer_limit=3):
self._metrics_override = metric_columns is not None
self._inferred_metrics = {}
self._metric_columns = metric_columns or self.DEFAULT_COLUMNS.copy()
self._parameter_columns = parameter_columns or []
self._max_progress_rows = max_progress_rows
self._max_error_rows = max_error_rows
self._infer_limit = infer_limit
self._max_report_freqency = max_report_frequency
self._last_report_time = 0
@@ -110,6 +124,7 @@ class TuneReporterBase(ProgressReporter):
representation (str): Representation to use in table. Defaults to
`metric`.
"""
self._metrics_override = True
if metric in self._metric_columns:
raise ValueError("Column {} already exists.".format(metric))
@@ -161,6 +176,9 @@ class TuneReporterBase(ProgressReporter):
fmt (str): Table format. See `tablefmt` in tabulate API.
delim (str): Delimiter between messages.
"""
if not self._metrics_override:
user_metrics = self._infer_user_metrics(trials, self._infer_limit)
self._metric_columns.update(user_metrics)
messages = ["== Status ==", memory_debug_str(), *sys_info]
if done:
max_progress = None
@@ -178,6 +196,24 @@ class TuneReporterBase(ProgressReporter):
messages.append(trial_errors_str(trials, fmt=fmt, max_rows=max_error))
return delim.join(messages) + delim
def _infer_user_metrics(self, trials, limit=4):
    """Infer which user-reported metrics to show in the progress table.

    Scans the ``last_result`` of each trial and collects result keys that
    are neither in ``DEFAULT_COLUMNS`` nor in ``AUTO_RESULT_KEYS`` and
    whose value type is one of ``VALID_SUMMARY_TYPES``.

    Args:
        trials: Iterable of trial objects exposing ``last_result``.
        limit (int): Maximum number of metrics to collect.

    Returns:
        dict: Mapping of metric name to itself. This is the same dict
            that is stored on ``self._inferred_metrics``.
    """
    # Once enough metrics have been found, keep reusing the cached result.
    if len(self._inferred_metrics) >= limit:
        return self._inferred_metrics

    # Start over from scratch; `found` aliases the instance attribute so
    # the stored state and the return value are always the same object.
    found = {}
    self._inferred_metrics = found
    for trial in trials:
        if not trial.last_result:
            continue
        for key, val in trial.last_result.items():
            if key in self.DEFAULT_COLUMNS or key in AUTO_RESULT_KEYS:
                continue
            # Exact type match on purpose: subclasses are not accepted.
            if type(val) not in self.VALID_SUMMARY_TYPES:
                continue
            found[key] = key
            if len(found) >= limit:
                return found
    return found
class JupyterNotebookReporter(TuneReporterBase):
"""Jupyter notebook-friendly Reporter that can update display in-place.
+23
View File
@@ -29,6 +29,9 @@ EPISODE_REWARD_MEAN = "episode_reward_mean"
# (Optional) Mean loss for training iteration
MEAN_LOSS = "mean_loss"
# (Optional) Negative mean loss for training iteration
NEG_MEAN_LOSS = "neg_mean_loss"
# (Optional) Mean accuracy for training iteration
MEAN_ACCURACY = "mean_accuracy"
@@ -61,6 +64,26 @@ DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", EXPERIMENT_TAG, TRIAL_ID)
DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
MEAN_ACCURACY, MEAN_LOSS)
# Result keys that the progress reporter skips when inferring user metric
# columns (see TuneReporterBase._infer_user_metrics).
# Make sure this doesn't regress.
AUTO_RESULT_KEYS = (
    TRAINING_ITERATION,
    TIME_TOTAL_S,
    EPISODES_TOTAL,
    TIMESTEPS_TOTAL,
    NODE_IP,
    HOSTNAME,
    PID,
    TIME_THIS_ITER_S,
    "timestamp",
    "experiment_id",
    "date",
    "time_since_restore",
    "iterations_since_restore",
    "timesteps_since_restore",
    "config",
)
# __duplicate__ is a magic keyword used internally to
# avoid double-logging results when using the Function API.
RESULT_DUPLICATE = "__duplicate__"
@@ -3,9 +3,10 @@ import collections
import os
import unittest
from unittest.mock import MagicMock, Mock
from ray import tune
from ray.test_utils import run_string_as_driver
from ray.tune.trial import Trial
from ray.tune.result import AUTO_RESULT_KEYS
from ray.tune.progress_reporter import (CLIReporter, _fair_filter_trials,
trial_progress_str)
@@ -233,6 +234,43 @@ class ProgressReporterTest(unittest.TestCase):
reporter.add_metric_column("foo", "bar")
self.assertIn("foo", reporter._metric_columns)
def testInfer(self):
    """Numeric user metrics are inferred and shown; string ones are not."""
    base_reporter = CLIReporter()
    test_result = dict(foo_result=1, baz_result=4123, bar_result="testme")

    def test(config):
        # Report the same result three times per trial.
        for _ in range(3):
            tune.report(**test_result)

    # Part 1: inference directly on the finished trials.
    finished = tune.run(test, num_samples=3).trials
    for metric in base_reporter._infer_user_metrics(finished):
        self.assertNotIn(metric, AUTO_RESULT_KEYS)
        self.assertIn(metric, test_result)

    # Part 2: inference end-to-end through the progress output.
    class TestReporter(CLIReporter):
        _output = []

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Disable rate limiting so that every update is captured.
            self._max_report_freqency = 0

        def report(self, *args, **kwargs):
            # Capture the rendered progress string instead of printing it.
            self._output.append(self._progress_str(*args, **kwargs))

    reporter = TestReporter()
    tune.run(test, num_samples=3, progress_reporter=reporter)

    found = {
        key: any(key in output for output in reporter._output)
        for key in test_result
    }
    assert found["foo_result"]
    assert found["baz_result"]
    assert not found["bar_result"]
def testProgressStr(self):
trials = []
for i in range(5):
@@ -285,7 +323,6 @@ class ProgressReporterTest(unittest.TestCase):
}, {"a": "A"},
fmt="psql",
max_rows=3)
print(prog3)
assert prog3 == EXPECTED_RESULT_3
def testEndToEndReporting(self):