mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 06:10:23 +08:00
[tune] auto infer metrics (#10663)
Co-authored-by: Kai Fricke <krfricke@users.noreply.github.com> Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
@@ -66,8 +66,7 @@ def tune_mnist_mxnet(num_samples=10, num_epochs=10):
|
||||
reduction_factor=2)
|
||||
|
||||
reporter = CLIReporter(
|
||||
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
|
||||
metric_columns=["loss", "mean_accuracy", "training_iteration"])
|
||||
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"])
|
||||
|
||||
tune.run(
|
||||
partial(train_mnist_mxnet, mnist=mnist_data, num_epochs=num_epochs),
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
from ray.tune.result import (EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
|
||||
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
|
||||
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
|
||||
AUTO_RESULT_KEYS)
|
||||
from ray.tune.utils import unflattened_lookup
|
||||
|
||||
try:
|
||||
@@ -51,6 +53,10 @@ class ProgressReporter:
|
||||
class TuneReporterBase(ProgressReporter):
|
||||
"""Abstract base class for the default Tune reporters.
|
||||
|
||||
If metric_columns is not overriden, Tune will attempt to automatically
|
||||
infer the metrics being outputted, up to 'infer_limit' number of
|
||||
metrics.
|
||||
|
||||
Args:
|
||||
metric_columns (dict[str, str]|list[str]): Names of metrics to
|
||||
include in progress table. If this is a dict, the keys should
|
||||
@@ -80,17 +86,25 @@ class TuneReporterBase(ProgressReporter):
|
||||
TIMESTEPS_TOTAL: "ts",
|
||||
EPISODE_REWARD_MEAN: "reward",
|
||||
})
|
||||
VALID_SUMMARY_TYPES = {
|
||||
int, float, np.float32, np.float64, np.int32, np.int64,
|
||||
type(None)
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
metric_columns=None,
|
||||
parameter_columns=None,
|
||||
max_progress_rows=20,
|
||||
max_error_rows=20,
|
||||
max_report_frequency=5):
|
||||
max_report_frequency=5,
|
||||
infer_limit=3):
|
||||
self._metrics_override = metric_columns is not None
|
||||
self._inferred_metrics = {}
|
||||
self._metric_columns = metric_columns or self.DEFAULT_COLUMNS.copy()
|
||||
self._parameter_columns = parameter_columns or []
|
||||
self._max_progress_rows = max_progress_rows
|
||||
self._max_error_rows = max_error_rows
|
||||
self._infer_limit = infer_limit
|
||||
|
||||
self._max_report_freqency = max_report_frequency
|
||||
self._last_report_time = 0
|
||||
@@ -110,6 +124,7 @@ class TuneReporterBase(ProgressReporter):
|
||||
representation (str): Representation to use in table. Defaults to
|
||||
`metric`.
|
||||
"""
|
||||
self._metrics_override = True
|
||||
if metric in self._metric_columns:
|
||||
raise ValueError("Column {} already exists.".format(metric))
|
||||
|
||||
@@ -161,6 +176,9 @@ class TuneReporterBase(ProgressReporter):
|
||||
fmt (str): Table format. See `tablefmt` in tabulate API.
|
||||
delim (str): Delimiter between messages.
|
||||
"""
|
||||
if not self._metrics_override:
|
||||
user_metrics = self._infer_user_metrics(trials, self._infer_limit)
|
||||
self._metric_columns.update(user_metrics)
|
||||
messages = ["== Status ==", memory_debug_str(), *sys_info]
|
||||
if done:
|
||||
max_progress = None
|
||||
@@ -178,6 +196,24 @@ class TuneReporterBase(ProgressReporter):
|
||||
messages.append(trial_errors_str(trials, fmt=fmt, max_rows=max_error))
|
||||
return delim.join(messages) + delim
|
||||
|
||||
def _infer_user_metrics(self, trials, limit=4):
|
||||
"""Try to infer the metrics to print out."""
|
||||
if len(self._inferred_metrics) >= limit:
|
||||
return self._inferred_metrics
|
||||
self._inferred_metrics = {}
|
||||
for t in trials:
|
||||
if not t.last_result:
|
||||
continue
|
||||
for metric, value in t.last_result.items():
|
||||
if metric not in self.DEFAULT_COLUMNS:
|
||||
if metric not in AUTO_RESULT_KEYS:
|
||||
if type(value) in self.VALID_SUMMARY_TYPES:
|
||||
self._inferred_metrics[metric] = metric
|
||||
|
||||
if len(self._inferred_metrics) >= limit:
|
||||
return self._inferred_metrics
|
||||
return self._inferred_metrics
|
||||
|
||||
|
||||
class JupyterNotebookReporter(TuneReporterBase):
|
||||
"""Jupyter notebook-friendly Reporter that can update display in-place.
|
||||
|
||||
@@ -29,6 +29,9 @@ EPISODE_REWARD_MEAN = "episode_reward_mean"
|
||||
# (Optional) Mean loss for training iteration
|
||||
MEAN_LOSS = "mean_loss"
|
||||
|
||||
# (Optional) Mean loss for training iteration
|
||||
NEG_MEAN_LOSS = "neg_mean_loss"
|
||||
|
||||
# (Optional) Mean accuracy for training iteration
|
||||
MEAN_ACCURACY = "mean_accuracy"
|
||||
|
||||
@@ -61,6 +64,26 @@ DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", EXPERIMENT_TAG, TRIAL_ID)
|
||||
DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
|
||||
MEAN_ACCURACY, MEAN_LOSS)
|
||||
|
||||
# Make sure this doesn't regress
|
||||
AUTO_RESULT_KEYS = (
|
||||
TRAINING_ITERATION,
|
||||
TIME_TOTAL_S,
|
||||
EPISODES_TOTAL,
|
||||
TIMESTEPS_TOTAL,
|
||||
NODE_IP,
|
||||
HOSTNAME,
|
||||
PID,
|
||||
TIME_TOTAL_S,
|
||||
TIME_THIS_ITER_S,
|
||||
"timestamp",
|
||||
"experiment_id",
|
||||
"date",
|
||||
"time_since_restore",
|
||||
"iterations_since_restore",
|
||||
"timesteps_since_restore",
|
||||
"config",
|
||||
)
|
||||
|
||||
# __duplicate__ is a magic keyword used internally to
|
||||
# avoid double-logging results when using the Function API.
|
||||
RESULT_DUPLICATE = "__duplicate__"
|
||||
|
||||
@@ -3,9 +3,10 @@ import collections
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
from ray import tune
|
||||
from ray.test_utils import run_string_as_driver
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.result import AUTO_RESULT_KEYS
|
||||
from ray.tune.progress_reporter import (CLIReporter, _fair_filter_trials,
|
||||
trial_progress_str)
|
||||
|
||||
@@ -233,6 +234,43 @@ class ProgressReporterTest(unittest.TestCase):
|
||||
reporter.add_metric_column("foo", "bar")
|
||||
self.assertIn("foo", reporter._metric_columns)
|
||||
|
||||
def testInfer(self):
|
||||
reporter = CLIReporter()
|
||||
test_result = dict(foo_result=1, baz_result=4123, bar_result="testme")
|
||||
|
||||
def test(config):
|
||||
for i in range(3):
|
||||
tune.report(**test_result)
|
||||
|
||||
analysis = tune.run(test, num_samples=3)
|
||||
all_trials = analysis.trials
|
||||
inferred_results = reporter._infer_user_metrics(all_trials)
|
||||
for metric in inferred_results:
|
||||
self.assertNotIn(metric, AUTO_RESULT_KEYS)
|
||||
self.assertTrue(metric in test_result)
|
||||
|
||||
class TestReporter(CLIReporter):
|
||||
_output = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._max_report_freqency = 0
|
||||
|
||||
def report(self, *args, **kwargs):
|
||||
progress_str = self._progress_str(*args, **kwargs)
|
||||
self._output.append(progress_str)
|
||||
|
||||
reporter = TestReporter()
|
||||
analysis = tune.run(test, num_samples=3, progress_reporter=reporter)
|
||||
found = {k: False for k in test_result}
|
||||
for output in reporter._output:
|
||||
for key in test_result:
|
||||
if key in output:
|
||||
found[key] = True
|
||||
assert found["foo_result"]
|
||||
assert found["baz_result"]
|
||||
assert not found["bar_result"]
|
||||
|
||||
def testProgressStr(self):
|
||||
trials = []
|
||||
for i in range(5):
|
||||
@@ -285,7 +323,6 @@ class ProgressReporterTest(unittest.TestCase):
|
||||
}, {"a": "A"},
|
||||
fmt="psql",
|
||||
max_rows=3)
|
||||
print(prog3)
|
||||
assert prog3 == EXPECTED_RESULT_3
|
||||
|
||||
def testEndToEndReporting(self):
|
||||
|
||||
Reference in New Issue
Block a user