[tune] Expose progress reporter to users (#6915)

* Pluggable progress reporter

* Fix types

* Fix bug, address comments

* lint

* Add convenience function and test

* lint

* Use trials instead of trial_runner

* Add docs

* Update docs

* Fix doc examples

* More doc updates

* Address comments, add configurable frequency

* use reward
This commit is contained in:
Ujval Misra
2020-01-25 12:28:05 -08:00
committed by Richard Liaw
parent 2e88e2e773
commit ed9de8b2fa
5 changed files with 339 additions and 73 deletions
+5
View File
@@ -6,6 +6,8 @@ from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.suggest import grid_search
from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
JupyterNotebookReporter)
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
randn, loguniform)
@@ -29,4 +31,7 @@ __all__ = [
"loguniform",
"ExperimentAnalysis",
"Analysis",
"CLIReporter",
"JupyterNotebookReporter",
"ProgressReporter",
]
+208 -60
View File
@@ -1,10 +1,11 @@
from __future__ import print_function
import collections
import time
from ray.tune.result import (DEFAULT_RESULT_KEYS, CONFIG_PREFIX,
EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
from ray.tune.result import (CONFIG_PREFIX, EPISODE_REWARD_MEAN, MEAN_ACCURACY,
MEAN_LOSS, TRAINING_ITERATION, TIME_TOTAL_S,
TIMESTEPS_TOTAL)
from ray.tune.utils import flatten_dict
try:
@@ -14,67 +15,202 @@ except ImportError:
"Please re-run 'pip install ray[tune]' or "
"'pip install ray[rllib]'.")
DEFAULT_PROGRESS_KEYS = DEFAULT_RESULT_KEYS + (EPISODE_REWARD_MEAN, )
# Truncated representations of column names (to accommodate small screens).
REPORTED_REPRESENTATIONS = {
EPISODE_REWARD_MEAN: "reward",
MEAN_ACCURACY: "acc",
MEAN_LOSS: "loss",
TIME_TOTAL_S: "total time (s)",
TIMESTEPS_TOTAL: "timesteps",
TRAINING_ITERATION: "iter",
}
class ProgressReporter:
# TODO(ujvl): Expose ProgressReporter in tune.run for custom reporting.
"""Abstract class for experiment progress reporting.
def report(self, trial_runner):
"""Reports progress across all trials of the trial runner.
`should_report()` is called to determine whether or not `report()` should
be called. Tune will call these functions after trial state transitions,
receiving training results, and so on.
"""
def should_report(self, trials, done=False):
"""Returns whether or not progress should be reported.
Args:
trial_runner: Trial runner to report on.
trials (list[Trial]): Trials to report on.
done (bool): Whether this is the last progress report attempt.
"""
raise NotImplementedError
def report(self, trials, *sys_info):
"""Reports progress across trials.
Args:
trials (list[Trial]): Trials to report on.
sys_info: System info.
"""
raise NotImplementedError
class JupyterNotebookReporter(ProgressReporter):
def __init__(self, overwrite):
class TuneReporterBase(ProgressReporter):
"""Abstract base class for the default Tune reporters."""
# Truncated representations of column names (to accommodate small screens).
DEFAULT_COLUMNS = {
EPISODE_REWARD_MEAN: "reward",
MEAN_ACCURACY: "acc",
MEAN_LOSS: "loss",
TIME_TOTAL_S: "total time (s)",
TIMESTEPS_TOTAL: "ts",
TRAINING_ITERATION: "iter",
}
def __init__(self,
metric_columns=None,
max_progress_rows=20,
max_error_rows=20,
max_report_frequency=5):
"""Initializes a new TuneReporterBase.
Args:
metric_columns (dict[str, str]|list[str]): Names of metrics to
include in progress table. If this is a dict, the keys should
be metric names and the values should be the displayed names.
If this is a list, the metric name is used directly.
max_progress_rows (int): Maximum number of rows to print
in the progress table. The progress table describes the
progress of each trial. Defaults to 20.
max_error_rows (int): Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_report_frequency (int): Maximum report frequency in seconds.
Defaults to 5s.
"""
self._metric_columns = metric_columns or self.DEFAULT_COLUMNS
self._max_progress_rows = max_progress_rows
self._max_error_rows = max_error_rows
self._max_report_freqency = max_report_frequency
self._last_report_time = 0
def should_report(self, trials, done=False):
if time.time() - self._last_report_time > self._max_report_freqency:
self._last_report_time = time.time()
return True
return done
def add_metric_column(self, metric, representation=None):
"""Adds a metric to the existing columns.
Args:
metric (str): Metric to add. This must be a metric being returned
in training step results.
representation (str): Representation to use in table. Defaults to
`metric`.
"""
if metric in self._metric_columns:
raise ValueError("Column {} already exists.".format(metric))
if isinstance(self._metric_columns, collections.Mapping):
representation = representation or metric
self._metric_columns[metric] = representation
else:
if representation is not None and representation != metric:
raise ValueError(
"`representation` cannot differ from `metric` "
"if this reporter was initialized with a list "
"of metric columns.")
self._metric_columns.append(metric)
def _progress_str(self, trials, *sys_info, fmt="psql", delim="\n"):
"""Returns full progress string.
This string contains a progress table and error table. The progress
table describes the progress of each trial. The error table lists
the error file, if any, corresponding to each trial. The latter only
exists if errors have occurred.
Args:
trials (list[Trial]): Trials to report on.
fmt (str): Table format. See `tablefmt` in tabulate API.
delim (str): Delimiter between messages.
"""
messages = ["== Status ==", memory_debug_str(), *sys_info]
if self._max_progress_rows > 0:
messages.append(
trial_progress_str(
trials,
metric_columns=self._metric_columns,
fmt=fmt,
max_rows=self._max_progress_rows))
if self._max_error_rows > 0:
messages.append(
trial_errors_str(
trials, fmt=fmt, max_rows=self._max_error_rows))
return delim.join(messages) + delim
class JupyterNotebookReporter(TuneReporterBase):
"""Jupyter notebook-friendly Reporter that can update display in-place."""
def __init__(self,
overwrite,
metric_columns=None,
max_progress_rows=20,
max_error_rows=20,
max_report_frequency=5):
"""Initializes a new JupyterNotebookReporter.
Args:
overwrite (bool): Flag for overwriting the last reported progress.
metric_columns (dict[str, str]|list[str]): Names of metrics to
include in progress table. If this is a dict, the keys should
be metric names and the values should be the displayed names.
If this is a list, the metric name is used directly.
max_progress_rows (int): Maximum number of rows to print
in the progress table. The progress table describes the
progress of each trial. Defaults to 20.
max_error_rows (int): Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_report_frequency (int): Maximum report frequency in seconds.
Defaults to 5s.
"""
self.overwrite = overwrite
super(JupyterNotebookReporter,
self).__init__(metric_columns, max_progress_rows, max_error_rows,
max_report_frequency)
self._overwrite = overwrite
def report(self, trial_runner):
delim = "<br>"
messages = [
"== Status ==",
memory_debug_str(),
trial_runner.scheduler_alg.debug_string(),
trial_runner.trial_executor.debug_string(),
trial_progress_str(trial_runner.get_trials(), fmt="html"),
trial_errors_str(trial_runner.get_trials(), fmt="html"),
]
def report(self, trials, *sys_info):
from IPython.display import clear_output
from IPython.core.display import display, HTML
if self.overwrite:
if self._overwrite:
clear_output(wait=True)
display(HTML(delim.join(messages) + delim))
progress_str = self._progress_str(
trials, *sys_info, fmt="html", delim="<br>")
display(HTML(progress_str))
class CLIReporter(ProgressReporter):
def report(self, trial_runner):
messages = [
"== Status ==",
memory_debug_str(),
trial_runner.scheduler_alg.debug_string(),
trial_runner.trial_executor.debug_string(),
trial_progress_str(trial_runner.get_trials()),
trial_errors_str(trial_runner.get_trials()),
]
print("\n".join(messages) + "\n")
class CLIReporter(TuneReporterBase):
"""Command-line reporter"""
def __init__(self,
metric_columns=None,
max_progress_rows=20,
max_error_rows=20,
max_report_frequency=5):
"""Initializes a CLIReporter.
Args:
metric_columns (dict[str, str]|list[str]): Names of metrics to
include in progress table. If this is a dict, the keys should
be metric names and the values should be the displayed names.
If this is a list, the metric name is used directly.
max_progress_rows (int): Maximum number of rows to print
in the progress table. The progress table describes the
progress of each trial. Defaults to 20.
max_error_rows (int): Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_report_frequency (int): Maximum report frequency in seconds.
Defaults to 5s.
"""
super(CLIReporter, self).__init__(metric_columns, max_progress_rows,
max_error_rows, max_report_frequency)
def report(self, trials, *sys_info):
print(self._progress_str(trials, *sys_info))
def memory_debug_str():
@@ -98,18 +234,21 @@ def memory_debug_str():
"(or ray[debug]) to resolve)")
def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=20):
def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
"""Returns a human readable message for printing to the console.
This contains a table where each row represents a trial, its parameters
and the current values of its metrics.
Args:
trials (List[Trial]): List of trials to get progress string for.
metrics (List[str]): Names of metrics to include. Defaults to
metrics defined in DEFAULT_RESULT_KEYS.
trials (list[Trial]): List of trials to get progress string for.
metric_columns (dict[str, str]|list[str]): Names of metrics to include.
If this is a dict, the keys are metric names and the values are
the names to use in the message. If this is a list, the metric
name is used in the message directly.
fmt (str): Output format (see tablefmt in tabulate API).
max_rows (int): Maximum number of rows in the trial table.
max_rows (int): Maximum number of rows in the trial table. Defaults to
unlimited.
"""
messages = []
delim = "<br>" if fmt == "html" else "\n"
@@ -131,6 +270,7 @@ def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=20):
messages.append("Number of trials: {} ({})".format(
num_trials, ", ".join(num_trials_strs)))
max_rows = max_rows or float("inf")
if num_trials > max_rows:
# TODO(ujvl): suggestion for users to view more rows.
trials_by_state_trunc = _fair_filter_trials(trials_by_state, max_rows)
@@ -148,33 +288,41 @@ def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=20):
"shown.".format(max_rows, overflow, overflow_str))
# Pre-process trials to figure out what columns to show.
keys = list(metrics or DEFAULT_PROGRESS_KEYS)
if isinstance(metric_columns, collections.Mapping):
keys = list(metric_columns.keys())
else:
keys = metric_columns
keys = [k for k in keys if any(t.last_result.get(k) for t in trials)]
# Build trial rows.
params = list(set().union(*[t.evaluated_params for t in trials]))
trial_table = [_get_trial_info(trial, params, keys) for trial in trials]
# Parse columns.
parsed_columns = [REPORTED_REPRESENTATIONS.get(k, k) for k in keys]
columns = ["Trial name", "status", "loc"]
columns += params + parsed_columns
# Format column headings
if isinstance(metric_columns, collections.Mapping):
formatted_columns = [metric_columns[k] for k in keys]
else:
formatted_columns = keys
columns = ["Trial name", "status", "loc"] + params + formatted_columns
# Tabulate.
messages.append(
tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False))
return delim.join(messages)
def trial_errors_str(trials, fmt="psql", max_rows=20):
def trial_errors_str(trials, fmt="psql", max_rows=None):
"""Returns a readable message regarding trial errors.
Args:
trials (List[Trial]): List of trials to get progress string for.
trials (list[Trial]): List of trials to get progress string for.
fmt (str): Output format (see tablefmt in tabulate API).
max_rows (int): Maximum number of rows in the error table.
max_rows (int): Maximum number of rows in the error table. Defaults to
unlimited.
"""
messages = []
failed = [t for t in trials if t.error_file]
num_failed = len(failed)
if num_failed > 0:
messages.append("Number of errored trials: {}".format(num_failed))
max_rows = max_rows or float("inf")
if num_failed > max_rows:
messages.append("Table truncated to {} rows ({} overflow)".format(
max_rows, num_failed - max_rows))
@@ -196,7 +344,7 @@ def _fair_filter_trials(trials_by_state, max_trials):
The oldest trials are truncated if necessary.
Args:
trials_by_state (Dict[str, List[Trial]]: Trials by state.
trials_by_state (dict[str, list[Trial]]: Trials by state.
max_trials (int): Maximum number of trials to return.
Returns:
Dict mapping state to List of fairly represented trials.
@@ -234,8 +382,8 @@ def _get_trial_info(trial, parameters, metrics):
Args:
trial (Trial): Trial to get information for.
parameters (List[str]): Names of trial parameters to include.
metrics (List[str]): Names of metrics to include.
parameters (list[str]): Names of trial parameters to include.
metrics (list[str]): Names of metrics to include.
"""
result = flatten_dict(trial.last_result)
trial_info = [str(trial), trial.status, str(trial.location)]
@@ -4,7 +4,7 @@ import unittest
from unittest.mock import MagicMock
from ray.tune.trial import Trial
from ray.tune.progress_reporter import _fair_filter_trials
from ray.tune.progress_reporter import CLIReporter, _fair_filter_trials
class ProgressReporterTest(unittest.TestCase):
@@ -48,3 +48,22 @@ class ProgressReporterTest(unittest.TestCase):
for i in range(len(state_trials) - 1):
self.assertGreaterEqual(state_trials[i].start_time,
state_trials[i + 1].start_time)
def testAddMetricColumn(self):
"""Tests edge cases of add_metric_column."""
# Test list-initialized metric columns.
reporter = CLIReporter(metric_columns=["foo", "bar"])
with self.assertRaises(ValueError):
reporter.add_metric_column("bar")
with self.assertRaises(ValueError):
reporter.add_metric_column("baz", "qux")
reporter.add_metric_column("baz")
self.assertIn("baz", reporter._metric_columns)
# Test default-initialized (dict) metric columns.
reporter = CLIReporter()
reporter.add_metric_column("foo", "bar")
self.assertIn("foo", reporter._metric_columns)
+32 -12
View File
@@ -1,12 +1,11 @@
import logging
import time
import six
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list, Experiment
from ray.tune.analysis import ExperimentAnalysis
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
from ray.tune.trial import Trial
from ray.tune.trainable import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.registry import get_trainable_cls
@@ -51,6 +50,21 @@ def _check_default_resources_override(run_identifier):
Trainable.default_resource_request.__code__)
def _report_progress(runner, reporter, done=False):
"""Reports experiment progress.
Args:
runner (TrialRunner): Trial runner to report on.
reporter (ProgressReporter): Progress reporter.
done (bool): Whether this is the last progress report attempt.
"""
trials = runner.get_trials()
if reporter.should_report(trials, done=done):
sched_debug_str = runner.scheduler_alg.debug_string()
executor_debug_str = runner.trial_executor.debug_string()
reporter.report(trials, sched_debug_str, executor_debug_str)
def run(run_or_experiment,
name=None,
stop=None,
@@ -77,6 +91,7 @@ def run(run_or_experiment,
with_server=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=2,
progress_reporter=None,
resume=False,
queue_trials=False,
reuse_actors=False,
@@ -169,6 +184,10 @@ def run(run_or_experiment,
server_port (int): Port number for launching TuneServer.
verbose (int): 0, 1, or 2. Verbosity mode. 0 = silent,
1 = only status updates, 2 = status and trial results.
progress_reporter (ProgressReporter): Progress reporter for reporting
intermediate experiment progress. Defaults to CLIReporter if
running in command-line, or JupyterNotebookReporter if running in
a Jupyter notebook.
resume (str|bool): One of "LOCAL", "REMOTE", "PROMPT", or bool.
LOCAL/True restores the checkpoint from the local_checkpoint_dir.
REMOTE restores the checkpoint from remote_checkpoint_dir.
@@ -272,10 +291,11 @@ def run(run_or_experiment,
for exp in experiments:
runner.add_experiment(exp)
if IS_NOTEBOOK:
reporter = JupyterNotebookReporter(overwrite=verbose < 2)
else:
reporter = CLIReporter()
if progress_reporter is None:
if IS_NOTEBOOK:
progress_reporter = JupyterNotebookReporter(overwrite=verbose < 2)
else:
progress_reporter = CLIReporter()
# User Warning for GPUs
if trial_executor.has_gpus():
@@ -295,13 +315,10 @@ def run(run_or_experiment,
"`Trainable.default_resource_request` if using the "
"Trainable API.")
last_debug = 0
while not runner.is_finished():
runner.step()
if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
if verbose:
reporter.report(runner)
last_debug = time.time()
if verbose:
_report_progress(runner, progress_reporter)
try:
runner.checkpoint(force=True)
@@ -309,7 +326,7 @@ def run(run_or_experiment,
logger.exception("Trial Runner checkpointing failed.")
if verbose:
reporter.report(runner)
_report_progress(runner, progress_reporter, done=True)
wait_for_sync()
@@ -339,6 +356,7 @@ def run_experiments(experiments,
with_server=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=2,
progress_reporter=None,
resume=False,
queue_trials=False,
reuse_actors=False,
@@ -380,6 +398,7 @@ def run_experiments(experiments,
with_server=with_server,
server_port=server_port,
verbose=verbose,
progress_reporter=progress_reporter,
resume=resume,
queue_trials=queue_trials,
reuse_actors=reuse_actors,
@@ -396,6 +415,7 @@ def run_experiments(experiments,
with_server=with_server,
server_port=server_port,
verbose=verbose,
progress_reporter=progress_reporter,
resume=resume,
queue_trials=queue_trials,
reuse_actors=reuse_actors,