mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 18:22:26 +08:00
[tune] sort running trials to top in status table (#10926)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -9,6 +9,7 @@ import time
|
||||
from ray.tune.result import (EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
|
||||
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
|
||||
AUTO_RESULT_KEYS)
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.utils import unflattened_lookup
|
||||
|
||||
try:
|
||||
@@ -77,6 +78,11 @@ class TuneReporterBase(ProgressReporter):
|
||||
corresponding to each trial. Defaults to 20.
|
||||
max_report_frequency (int): Maximum report frequency in seconds.
|
||||
Defaults to 5s.
|
||||
infer_limit (int): Maximum number of metrics to automatically infer
|
||||
from tune results.
|
||||
metric (str): Metric used to determine best current trial.
|
||||
mode (str): One of [min, max]. Determines whether objective is
|
||||
minimizing or maximizing the metric attribute.
|
||||
"""
|
||||
|
||||
# Truncated representations of column names (to accommodate small screens).
|
||||
@@ -100,7 +106,9 @@ class TuneReporterBase(ProgressReporter):
|
||||
max_progress_rows=20,
|
||||
max_error_rows=20,
|
||||
max_report_frequency=5,
|
||||
infer_limit=3):
|
||||
infer_limit=3,
|
||||
metric=None,
|
||||
mode=None):
|
||||
self._total_samples = total_samples
|
||||
self._metrics_override = metric_columns is not None
|
||||
self._inferred_metrics = {}
|
||||
@@ -113,6 +121,22 @@ class TuneReporterBase(ProgressReporter):
|
||||
self._max_report_freqency = max_report_frequency
|
||||
self._last_report_time = 0
|
||||
|
||||
self._metric = metric
|
||||
self._mode = mode
|
||||
|
||||
def set_search_properties(self, metric, mode):
    """Adopt ``metric``/``mode`` unless they clash with stored values.

    Args:
        metric (str): Metric name to adopt, or a falsy value to leave the
            stored metric untouched.
        mode (str): Mode to adopt, or a falsy value to leave the stored
            mode untouched.

    Returns:
        bool: False (with nothing modified) when a truthy ``metric`` is
            passed while ``self._metric`` is already truthy, or likewise
            for ``mode``; True otherwise.
    """
    requested = (("_metric", metric), ("_mode", mode))

    # First pass: refuse to override anything that is already configured.
    for attr_name, new_value in requested:
        if getattr(self, attr_name) and new_value:
            return False

    # Second pass: only truthy values are written.
    for attr_name, new_value in requested:
        if new_value:
            setattr(self, attr_name, new_value)

    return True
|
||||
|
||||
def set_total_samples(self, total_samples):
    """Store ``total_samples`` on the reporter as ``_total_samples``."""
    setattr(self, "_total_samples", total_samples)
|
||||
|
||||
@@ -193,6 +217,13 @@ class TuneReporterBase(ProgressReporter):
|
||||
else:
|
||||
max_progress = self._max_progress_rows
|
||||
max_error = self._max_error_rows
|
||||
|
||||
current_best_trial, metric = self._current_best_trial(trials)
|
||||
if current_best_trial:
|
||||
messages.append(
|
||||
best_trial_str(current_best_trial, metric,
|
||||
self._parameter_columns))
|
||||
|
||||
messages.append(
|
||||
trial_progress_str(
|
||||
trials,
|
||||
@@ -202,6 +233,7 @@ class TuneReporterBase(ProgressReporter):
|
||||
fmt=fmt,
|
||||
max_rows=max_progress))
|
||||
messages.append(trial_errors_str(trials, fmt=fmt, max_rows=max_error))
|
||||
|
||||
return delim.join(messages) + delim
|
||||
|
||||
def _infer_user_metrics(self, trials, limit=4):
|
||||
@@ -222,6 +254,34 @@ class TuneReporterBase(ProgressReporter):
|
||||
return self._inferred_metrics
|
||||
return self._inferred_metrics
|
||||
|
||||
def _current_best_trial(self, trials):
|
||||
if not trials:
|
||||
return None, None
|
||||
|
||||
metric, mode = self._metric, self._mode
|
||||
# If no metric has been set, see if exactly one has been reported
|
||||
# and use that one. `mode` must still be set.
|
||||
if not metric:
|
||||
if len(self._inferred_metrics) == 1:
|
||||
metric = list(self._inferred_metrics.keys())[0]
|
||||
|
||||
if not metric or not mode:
|
||||
return None, metric
|
||||
|
||||
metric_op = 1. if mode == "max" else -1.
|
||||
best_metric = float("-inf")
|
||||
best_trial = None
|
||||
for t in trials:
|
||||
if not t.last_result:
|
||||
continue
|
||||
if metric not in t.last_result:
|
||||
continue
|
||||
if not best_metric or \
|
||||
t.last_result[metric] * metric_op > best_metric:
|
||||
best_metric = t.last_result[metric]
|
||||
best_trial = t
|
||||
return best_trial, metric
|
||||
|
||||
|
||||
class JupyterNotebookReporter(TuneReporterBase):
|
||||
"""Jupyter notebook-friendly Reporter that can update display in-place.
|
||||
@@ -254,10 +314,14 @@ class JupyterNotebookReporter(TuneReporterBase):
|
||||
total_samples=None,
|
||||
max_progress_rows=20,
|
||||
max_error_rows=20,
|
||||
max_report_frequency=5):
|
||||
super(JupyterNotebookReporter, self).__init__(
|
||||
metric_columns, parameter_columns, total_samples,
|
||||
max_progress_rows, max_error_rows, max_report_frequency)
|
||||
max_report_frequency=5,
|
||||
infer_limit=3,
|
||||
metric=None,
|
||||
mode=None):
|
||||
super(JupyterNotebookReporter,
|
||||
self).__init__(metric_columns, parameter_columns, total_samples,
|
||||
max_progress_rows, max_error_rows,
|
||||
max_report_frequency, infer_limit, metric, mode)
|
||||
self._overwrite = overwrite
|
||||
|
||||
def report(self, trials, done, *sys_info):
|
||||
@@ -299,11 +363,15 @@ class CLIReporter(TuneReporterBase):
|
||||
total_samples=None,
|
||||
max_progress_rows=20,
|
||||
max_error_rows=20,
|
||||
max_report_frequency=5):
|
||||
max_report_frequency=5,
|
||||
infer_limit=3,
|
||||
metric=None,
|
||||
mode=None):
|
||||
|
||||
super(CLIReporter, self).__init__(metric_columns, parameter_columns,
|
||||
total_samples, max_progress_rows,
|
||||
max_error_rows, max_report_frequency)
|
||||
super(CLIReporter,
|
||||
self).__init__(metric_columns, parameter_columns, total_samples,
|
||||
max_progress_rows, max_error_rows,
|
||||
max_report_frequency, infer_limit, metric, mode)
|
||||
|
||||
def report(self, trials, done, *sys_info):
|
||||
print(self._progress_str(trials, done, *sys_info))
|
||||
@@ -376,13 +444,20 @@ def trial_progress_str(trials,
|
||||
for state in sorted(trials_by_state)
|
||||
]
|
||||
|
||||
state_tbl_order = [
|
||||
Trial.RUNNING, Trial.PAUSED, Trial.PENDING, Trial.TERMINATED,
|
||||
Trial.ERROR
|
||||
]
|
||||
|
||||
max_rows = max_rows or float("inf")
|
||||
if num_trials > max_rows:
|
||||
# TODO(ujvl): suggestion for users to view more rows.
|
||||
trials_by_state_trunc = _fair_filter_trials(trials_by_state, max_rows)
|
||||
trials = []
|
||||
overflow_strs = []
|
||||
for state in sorted(trials_by_state):
|
||||
for state in state_tbl_order:
|
||||
if state not in trials_by_state:
|
||||
continue
|
||||
trials += trials_by_state_trunc[state]
|
||||
num = len(trials_by_state[state]) - len(
|
||||
trials_by_state_trunc[state])
|
||||
@@ -393,8 +468,13 @@ def trial_progress_str(trials,
|
||||
overflow_str = ", ".join(overflow_strs)
|
||||
else:
|
||||
overflow = False
|
||||
trials = []
|
||||
for state in state_tbl_order:
|
||||
if state not in trials_by_state:
|
||||
continue
|
||||
trials += trials_by_state[state]
|
||||
|
||||
if total_samples >= sys.maxsize:
|
||||
if total_samples and total_samples >= sys.maxsize:
|
||||
total_samples = "infinite"
|
||||
|
||||
messages.append("Number of trials: {}{} ({})".format(
|
||||
@@ -474,6 +554,18 @@ def trial_errors_str(trials, fmt="psql", max_rows=None):
|
||||
return delim.join(messages)
|
||||
|
||||
|
||||
def best_trial_str(trial, metric, parameter_columns=None):
    """Build the one-line "Current best trial" status message.

    Args:
        trial: Object exposing ``trial_id`` and a ``last_result`` mapping
            that contains ``metric`` and, optionally, a ``"config"`` dict.
        metric (str): Key of the metric value to display.
        parameter_columns (list|dict): Config keys to include. When a
            mapping is given only its keys are used; when empty or None
            every key of the trial's config is shown.

    Returns:
        str: The formatted message.
    """
    val = trial.last_result[metric]
    trial_config = trial.last_result.get("config", {})

    if parameter_columns:
        selected = parameter_columns
    else:
        selected = list(trial_config.keys())
    if isinstance(selected, Mapping):
        selected = selected.keys()

    # Keys missing from the config are reported with a value of None.
    params = {}
    for key in selected:
        params[key] = trial_config.get(key)

    return f"Current best trial: {trial.trial_id} with {metric}={val} and " \
        f"parameters={params}"
|
||||
|
||||
|
||||
def _fair_filter_trials(trials_by_state, max_trials):
|
||||
"""Filters trials such that each state is represented fairly.
|
||||
|
||||
|
||||
@@ -8,15 +8,15 @@ from ray.test_utils import run_string_as_driver
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.result import AUTO_RESULT_KEYS
|
||||
from ray.tune.progress_reporter import (CLIReporter, _fair_filter_trials,
|
||||
trial_progress_str)
|
||||
best_trial_str, trial_progress_str)
|
||||
|
||||
EXPECTED_RESULT_1 = """Result logdir: /foo
|
||||
Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
|
||||
+--------------+------------+-------+-----+-----+------------+
|
||||
| Trial name | status | loc | a | b | metric_1 |
|
||||
|--------------+------------+-------+-----+-----+------------|
|
||||
| 00001 | PENDING | here | 1 | 2 | 0.5 |
|
||||
| 00002 | RUNNING | here | 2 | 4 | 1 |
|
||||
| 00001 | PENDING | here | 1 | 2 | 0.5 |
|
||||
| 00000 | TERMINATED | here | 0 | 0 | 0 |
|
||||
+--------------+------------+-------+-----+-----+------------+
|
||||
... 2 more trials not shown (2 RUNNING)"""
|
||||
@@ -26,11 +26,11 @@ Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
|
||||
+--------------+------------+-------+-----+-----+---------+---------+
|
||||
| Trial name | status | loc | a | b | n/k/0 | n/k/1 |
|
||||
|--------------+------------+-------+-----+-----+---------+---------|
|
||||
| 00000 | TERMINATED | here | 0 | 0 | 0 | 0 |
|
||||
| 00001 | PENDING | here | 1 | 2 | 1 | 2 |
|
||||
| 00002 | RUNNING | here | 2 | 4 | 2 | 4 |
|
||||
| 00003 | RUNNING | here | 3 | 6 | 3 | 6 |
|
||||
| 00004 | RUNNING | here | 4 | 8 | 4 | 8 |
|
||||
| 00001 | PENDING | here | 1 | 2 | 1 | 2 |
|
||||
| 00000 | TERMINATED | here | 0 | 0 | 0 | 0 |
|
||||
+--------------+------------+-------+-----+-----+---------+---------+"""
|
||||
|
||||
EXPECTED_RESULT_3 = """Result logdir: /foo
|
||||
@@ -38,8 +38,8 @@ Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
|
||||
+--------------+------------+-------+-----+------------+------------+
|
||||
| Trial name | status | loc | A | Metric 1 | Metric 2 |
|
||||
|--------------+------------+-------+-----+------------+------------|
|
||||
| 00001 | PENDING | here | 1 | 0.5 | 0.25 |
|
||||
| 00002 | RUNNING | here | 2 | 1 | 0.5 |
|
||||
| 00001 | PENDING | here | 1 | 0.5 | 0.25 |
|
||||
| 00000 | TERMINATED | here | 0 | 0 | 0 |
|
||||
+--------------+------------+-------+-----+------------+------------+
|
||||
... 2 more trials not shown (2 RUNNING)"""
|
||||
@@ -154,6 +154,12 @@ EXPECTED_END_TO_END_AC = """Number of trials: 30/30 (30 TERMINATED)
|
||||
| f_xxxxx_00029 | TERMINATED | | | | 9 |
|
||||
+---------------+------------+-------+-----+-----+-----+"""
|
||||
|
||||
EXPECTED_BEST_1 = "Current best trial: 00001 with metric_1=0.5 and " \
|
||||
"parameters={'a': 1, 'b': 2, 'n': {'k': [1, 2]}}"
|
||||
|
||||
EXPECTED_BEST_2 = "Current best trial: 00004 with metric_1=2.0 and " \
|
||||
"parameters={'a': 4}"
|
||||
|
||||
|
||||
class ProgressReporterTest(unittest.TestCase):
|
||||
def mock_trial(self, status, i):
|
||||
@@ -305,8 +311,43 @@ class ProgressReporterTest(unittest.TestCase):
|
||||
}, {"a": "A"},
|
||||
fmt="psql",
|
||||
max_rows=3)
|
||||
print(prog3)
|
||||
assert prog3 == EXPECTED_RESULT_3
|
||||
|
||||
# Current best trial
|
||||
best1 = best_trial_str(trials[1], "metric_1")
|
||||
assert best1 == EXPECTED_BEST_1
|
||||
|
||||
def testCurrentBestTrial(self):
    """The progress output names the best trial when only `mode` is set."""

    def make_trial(idx):
        # Mocked running trial whose single reported metric is idx / 2.
        trial = Mock()
        trial.status = "RUNNING"
        trial.trial_id = "%05d" % idx
        trial.local_dir = "/foo"
        trial.location = "here"
        trial.config = {"a": idx, "b": idx * 2, "n": {"k": [idx, 2 * idx]}}
        trial.evaluated_params = {"a": idx}
        trial.last_result = {"config": {"a": idx}, "metric_1": idx / 2}
        trial.__str__ = lambda self: self.trial_id
        return trial

    trials = [make_trial(idx) for idx in range(5)]

    class TestReporter(CLIReporter):
        # Collects rendered progress strings instead of printing them.
        _output = []

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._max_report_freqency = 0

        def report(self, *args, **kwargs):
            rendered = self._progress_str(*args, **kwargs)
            self._output.append(rendered)

    reporter = TestReporter(mode="max")
    reporter.report(trials, done=False)

    first_report = reporter._output[0]
    assert EXPECTED_BEST_2 in first_report
|
||||
|
||||
def testEndToEndReporting(self):
|
||||
try:
|
||||
os.environ["_TEST_TUNE_TRIAL_UUID"] = "xxxxx"
|
||||
|
||||
@@ -388,6 +388,12 @@ def run(
|
||||
else:
|
||||
progress_reporter = CLIReporter()
|
||||
|
||||
if not progress_reporter.set_search_properties(metric, mode):
|
||||
raise ValueError(
|
||||
"You passed a `metric` or `mode` argument to `tune.run()`, but "
|
||||
"the reporter you are using was already instantiated with their "
|
||||
"own `metric` and `mode` parameters. Either remove the arguments "
|
||||
"from your reporter or from your call to `tune.run()`")
|
||||
progress_reporter.set_total_samples(search_alg.total_samples)
|
||||
|
||||
# User Warning for GPUs
|
||||
|
||||
Reference in New Issue
Block a user