From ed9de8b2fa269d2866c83e1e04bcdb0019d7d625 Mon Sep 17 00:00:00 2001 From: Ujval Misra Date: Sat, 25 Jan 2020 12:28:05 -0800 Subject: [PATCH] [tune] Expose progress reporter to users (#6915) * Pluggable progress reporter * Fix types * Fix bug, address comments * lint * Add convenience function and test * lint * Use trials instead of trial_runner * Add docs * Update docs * Fix doc examples * More doc updates * Address comments, add configurable frequency * use reward --- doc/source/tune-usage.rst | 74 +++++ python/ray/tune/__init__.py | 5 + python/ray/tune/progress_reporter.py | 268 ++++++++++++++---- .../ray/tune/tests/test_progress_reporter.py | 21 +- python/ray/tune/tune.py | 44 ++- 5 files changed, 339 insertions(+), 73 deletions(-) diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index 03c06a398..516435749 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -708,6 +708,80 @@ By default, Tune will run hyperparameter evaluations on multiple processes. Howe Note that some behavior such as writing to files by depending on the current working directory in a Trainable and setting global process variables may not work as expected. Local mode with multiple configuration evaluations will interleave computation, so it is most naturally used when running a single configuration evaluation. +CLI Progress Reporting +---------------------- + +By default, Tune reports experiment progress periodically to the command-line as follows. + +.. code-block:: bash + + == Status == + Memory usage on this node: 11.4/16.0 GiB + Using FIFO scheduling algorithm. 
+ Resources requested: 4/12 CPUs, 0/0 GPUs, 0.0/3.17 GiB heap, 0.0/1.07 GiB objects + Result logdir: /Users/foo/ray_results/myexp + Number of trials: 4 (4 RUNNING) + +----------------------+----------+---------------------+-----------+--------+--------+--------+--------+------------------+-------+ + | Trial name | status | loc | param1 | param2 | param3 | acc | loss | total time (s) | iter | + |----------------------+----------+---------------------+-----------+--------+--------+--------+--------+------------------+-------| + | MyTrainable_a826033a | RUNNING | 10.234.98.164:31115 | 0.303706 | 0.0761 | 0.4328 | 0.1289 | 1.8572 | 7.54952 | 15 | + | MyTrainable_a8263fc6 | RUNNING | 10.234.98.164:31117 | 0.929276 | 0.158 | 0.3417 | 0.4865 | 1.6307 | 7.0501 | 14 | + | MyTrainable_a8267914 | RUNNING | 10.234.98.164:31111 | 0.068426 | 0.0319 | 0.1147 | 0.9585 | 1.9603 | 7.0477 | 14 | + | MyTrainable_a826b7bc | RUNNING | 10.234.98.164:31112 | 0.729127 | 0.0748 | 0.1784 | 0.1797 | 1.7161 | 7.05715 | 14 | + +----------------------+----------+---------------------+-----------+--------+--------+--------+--------+------------------+-------+ + +Note that columns will be hidden if they are completely empty. The output can be configured in various ways by instantiating a ``CLIReporter`` instance (or ``JupyterNotebookReporter`` if you're using a Jupyter notebook). Here's an example: + +.. code-block:: python + + from ray.tune import CLIReporter + + # Limit the number of rows. + reporter = CLIReporter(max_progress_rows=10) + # Add a custom metric column, in addition to the default metrics. + # Note that this must be a metric that is returned in your training results. + reporter.add_metric_column("custom_metric") + tune.run(my_trainable, progress_reporter=reporter) + +Extending ``CLIReporter`` lets you control reporting frequency. For example: + +.. 
code-block:: python + + class ExperimentTerminationReporter(CLIReporter): + def should_report(self, trials, done=False): + """Reports only on experiment termination.""" + return done + + tune.run(my_trainable, progress_reporter=ExperimentTerminationReporter()) + + class TrialTerminationReporter(CLIReporter): + def __init__(self): + self.num_terminated = 0 + + def should_report(self, trials, done=False): + """Reports only on trial termination events.""" + old_num_terminated = self.num_terminated + self.num_terminated = len([t for t in trials if t.status == Trial.TERMINATED]) + return self.num_terminated > old_num_terminated + + tune.run(my_trainable, progress_reporter=TrialTerminationReporter()) + +The default reporting style can also be overridden more broadly by extending the ``ProgressReporter`` interface directly. Note that you can print to any output stream, file, etc. + +.. code-block:: python + + from ray.tune import ProgressReporter + + class CustomReporter(ProgressReporter): + + def should_report(self, trials, done=False): + return True + + def report(self, trials, *sys_info): + print(*sys_info) + print("\n".join([str(trial) for trial in trials])) + + tune.run(my_trainable, progress_reporter=CustomReporter()) Tune CLI (Experimental) ----------------------- diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index 4c6bfd36c..132a6edd7 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -6,6 +6,8 @@ from ray.tune.registry import register_env, register_trainable from ray.tune.trainable import Trainable from ray.tune.durable_trainable import DurableTrainable from ray.tune.suggest import grid_search +from ray.tune.progress_reporter import (ProgressReporter, CLIReporter, + JupyterNotebookReporter) from ray.tune.sample import (function, sample_from, uniform, choice, randint, randn, loguniform) @@ -29,4 +31,7 @@ __all__ = [ "loguniform", "ExperimentAnalysis", "Analysis", + "CLIReporter", + "JupyterNotebookReporter", + 
"ProgressReporter", ] diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index 2a5e2d04c..c88676b56 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -1,10 +1,11 @@ from __future__ import print_function import collections +import time -from ray.tune.result import (DEFAULT_RESULT_KEYS, CONFIG_PREFIX, - EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS, - TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL) +from ray.tune.result import (CONFIG_PREFIX, EPISODE_REWARD_MEAN, MEAN_ACCURACY, + MEAN_LOSS, TRAINING_ITERATION, TIME_TOTAL_S, + TIMESTEPS_TOTAL) from ray.tune.utils import flatten_dict try: @@ -14,67 +15,202 @@ except ImportError: "Please re-run 'pip install ray[tune]' or " "'pip install ray[rllib]'.") -DEFAULT_PROGRESS_KEYS = DEFAULT_RESULT_KEYS + (EPISODE_REWARD_MEAN, ) -# Truncated representations of column names (to accommodate small screens). -REPORTED_REPRESENTATIONS = { - EPISODE_REWARD_MEAN: "reward", - MEAN_ACCURACY: "acc", - MEAN_LOSS: "loss", - TIME_TOTAL_S: "total time (s)", - TIMESTEPS_TOTAL: "timesteps", - TRAINING_ITERATION: "iter", -} - class ProgressReporter: - # TODO(ujvl): Expose ProgressReporter in tune.run for custom reporting. + """Abstract class for experiment progress reporting. - def report(self, trial_runner): - """Reports progress across all trials of the trial runner. + `should_report()` is called to determine whether or not `report()` should + be called. Tune will call these functions after trial state transitions, + receiving training results, and so on. + """ + + def should_report(self, trials, done=False): + """Returns whether or not progress should be reported. Args: - trial_runner: Trial runner to report on. + trials (list[Trial]): Trials to report on. + done (bool): Whether this is the last progress report attempt. + """ + raise NotImplementedError + + def report(self, trials, *sys_info): + """Reports progress across trials. 
+ + Args: + trials (list[Trial]): Trials to report on. + sys_info: System info. """ raise NotImplementedError -class JupyterNotebookReporter(ProgressReporter): - def __init__(self, overwrite): +class TuneReporterBase(ProgressReporter): + """Abstract base class for the default Tune reporters.""" + + # Truncated representations of column names (to accommodate small screens). + DEFAULT_COLUMNS = { + EPISODE_REWARD_MEAN: "reward", + MEAN_ACCURACY: "acc", + MEAN_LOSS: "loss", + TIME_TOTAL_S: "total time (s)", + TIMESTEPS_TOTAL: "ts", + TRAINING_ITERATION: "iter", + } + + def __init__(self, + metric_columns=None, + max_progress_rows=20, + max_error_rows=20, + max_report_frequency=5): + """Initializes a new TuneReporterBase. + + Args: + metric_columns (dict[str, str]|list[str]): Names of metrics to + include in progress table. If this is a dict, the keys should + be metric names and the values should be the displayed names. + If this is a list, the metric name is used directly. + max_progress_rows (int): Maximum number of rows to print + in the progress table. The progress table describes the + progress of each trial. Defaults to 20. + max_error_rows (int): Maximum number of rows to print in the + error table. The error table lists the error file, if any, + corresponding to each trial. Defaults to 20. + max_report_frequency (int): Maximum report frequency in seconds. + Defaults to 5s. + """ + self._metric_columns = metric_columns or self.DEFAULT_COLUMNS + self._max_progress_rows = max_progress_rows + self._max_error_rows = max_error_rows + + self._max_report_freqency = max_report_frequency + self._last_report_time = 0 + + def should_report(self, trials, done=False): + if time.time() - self._last_report_time > self._max_report_freqency: + self._last_report_time = time.time() + return True + return done + + def add_metric_column(self, metric, representation=None): + """Adds a metric to the existing columns. + + Args: + metric (str): Metric to add. 
This must be a metric being returned + in training step results. + representation (str): Representation to use in table. Defaults to + `metric`. + """ + if metric in self._metric_columns: + raise ValueError("Column {} already exists.".format(metric)) + + if isinstance(self._metric_columns, collections.Mapping): + representation = representation or metric + self._metric_columns[metric] = representation + else: + if representation is not None and representation != metric: + raise ValueError( + "`representation` cannot differ from `metric` " + "if this reporter was initialized with a list " + "of metric columns.") + self._metric_columns.append(metric) + + def _progress_str(self, trials, *sys_info, fmt="psql", delim="\n"): + """Returns full progress string. + + This string contains a progress table and error table. The progress + table describes the progress of each trial. The error table lists + the error file, if any, corresponding to each trial. The latter only + exists if errors have occurred. + + Args: + trials (list[Trial]): Trials to report on. + fmt (str): Table format. See `tablefmt` in tabulate API. + delim (str): Delimiter between messages. + """ + messages = ["== Status ==", memory_debug_str(), *sys_info] + if self._max_progress_rows > 0: + messages.append( + trial_progress_str( + trials, + metric_columns=self._metric_columns, + fmt=fmt, + max_rows=self._max_progress_rows)) + if self._max_error_rows > 0: + messages.append( + trial_errors_str( + trials, fmt=fmt, max_rows=self._max_error_rows)) + return delim.join(messages) + delim + + +class JupyterNotebookReporter(TuneReporterBase): + """Jupyter notebook-friendly Reporter that can update display in-place.""" + + def __init__(self, + overwrite, + metric_columns=None, + max_progress_rows=20, + max_error_rows=20, + max_report_frequency=5): """Initializes a new JupyterNotebookReporter. Args: overwrite (bool): Flag for overwriting the last reported progress. 
+ metric_columns (dict[str, str]|list[str]): Names of metrics to + include in progress table. If this is a dict, the keys should + be metric names and the values should be the displayed names. + If this is a list, the metric name is used directly. + max_progress_rows (int): Maximum number of rows to print + in the progress table. The progress table describes the + progress of each trial. Defaults to 20. + max_error_rows (int): Maximum number of rows to print in the + error table. The error table lists the error file, if any, + corresponding to each trial. Defaults to 20. + max_report_frequency (int): Maximum report frequency in seconds. + Defaults to 5s. """ - self.overwrite = overwrite + super(JupyterNotebookReporter, + self).__init__(metric_columns, max_progress_rows, max_error_rows, + max_report_frequency) + self._overwrite = overwrite - def report(self, trial_runner): - delim = "
" - messages = [ - "== Status ==", - memory_debug_str(), - trial_runner.scheduler_alg.debug_string(), - trial_runner.trial_executor.debug_string(), - trial_progress_str(trial_runner.get_trials(), fmt="html"), - trial_errors_str(trial_runner.get_trials(), fmt="html"), - ] + def report(self, trials, *sys_info): from IPython.display import clear_output from IPython.core.display import display, HTML - if self.overwrite: + if self._overwrite: clear_output(wait=True) - display(HTML(delim.join(messages) + delim)) + progress_str = self._progress_str( + trials, *sys_info, fmt="html", delim="
") + display(HTML(progress_str)) -class CLIReporter(ProgressReporter): - def report(self, trial_runner): - messages = [ - "== Status ==", - memory_debug_str(), - trial_runner.scheduler_alg.debug_string(), - trial_runner.trial_executor.debug_string(), - trial_progress_str(trial_runner.get_trials()), - trial_errors_str(trial_runner.get_trials()), - ] - print("\n".join(messages) + "\n") +class CLIReporter(TuneReporterBase): + """Command-line reporter""" + + def __init__(self, + metric_columns=None, + max_progress_rows=20, + max_error_rows=20, + max_report_frequency=5): + """Initializes a CLIReporter. + + Args: + metric_columns (dict[str, str]|list[str]): Names of metrics to + include in progress table. If this is a dict, the keys should + be metric names and the values should be the displayed names. + If this is a list, the metric name is used directly. + max_progress_rows (int): Maximum number of rows to print + in the progress table. The progress table describes the + progress of each trial. Defaults to 20. + max_error_rows (int): Maximum number of rows to print in the + error table. The error table lists the error file, if any, + corresponding to each trial. Defaults to 20. + max_report_frequency (int): Maximum report frequency in seconds. + Defaults to 5s. + """ + super(CLIReporter, self).__init__(metric_columns, max_progress_rows, + max_error_rows, max_report_frequency) + + def report(self, trials, *sys_info): + print(self._progress_str(trials, *sys_info)) def memory_debug_str(): @@ -98,18 +234,21 @@ def memory_debug_str(): "(or ray[debug]) to resolve)") -def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=20): +def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None): """Returns a human readable message for printing to the console. This contains a table where each row represents a trial, its parameters and the current values of its metrics. Args: - trials (List[Trial]): List of trials to get progress string for. 
- metrics (List[str]): Names of metrics to include. Defaults to - metrics defined in DEFAULT_RESULT_KEYS. + trials (list[Trial]): List of trials to get progress string for. + metric_columns (dict[str, str]|list[str]): Names of metrics to include. + If this is a dict, the keys are metric names and the values are + the names to use in the message. If this is a list, the metric + name is used in the message directly. fmt (str): Output format (see tablefmt in tabulate API). - max_rows (int): Maximum number of rows in the trial table. + max_rows (int): Maximum number of rows in the trial table. Defaults to + unlimited. """ messages = [] delim = "
" if fmt == "html" else "\n" @@ -131,6 +270,7 @@ def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=20): messages.append("Number of trials: {} ({})".format( num_trials, ", ".join(num_trials_strs))) + max_rows = max_rows or float("inf") if num_trials > max_rows: # TODO(ujvl): suggestion for users to view more rows. trials_by_state_trunc = _fair_filter_trials(trials_by_state, max_rows) @@ -148,33 +288,41 @@ def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=20): "shown.".format(max_rows, overflow, overflow_str)) # Pre-process trials to figure out what columns to show. - keys = list(metrics or DEFAULT_PROGRESS_KEYS) + if isinstance(metric_columns, collections.Mapping): + keys = list(metric_columns.keys()) + else: + keys = metric_columns keys = [k for k in keys if any(t.last_result.get(k) for t in trials)] # Build trial rows. params = list(set().union(*[t.evaluated_params for t in trials])) trial_table = [_get_trial_info(trial, params, keys) for trial in trials] - # Parse columns. - parsed_columns = [REPORTED_REPRESENTATIONS.get(k, k) for k in keys] - columns = ["Trial name", "status", "loc"] - columns += params + parsed_columns + # Format column headings + if isinstance(metric_columns, collections.Mapping): + formatted_columns = [metric_columns[k] for k in keys] + else: + formatted_columns = keys + columns = ["Trial name", "status", "loc"] + params + formatted_columns + # Tabulate. messages.append( tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False)) return delim.join(messages) -def trial_errors_str(trials, fmt="psql", max_rows=20): +def trial_errors_str(trials, fmt="psql", max_rows=None): """Returns a readable message regarding trial errors. Args: - trials (List[Trial]): List of trials to get progress string for. + trials (list[Trial]): List of trials to get progress string for. fmt (str): Output format (see tablefmt in tabulate API). - max_rows (int): Maximum number of rows in the error table. 
+ max_rows (int): Maximum number of rows in the error table. Defaults to + unlimited. """ messages = [] failed = [t for t in trials if t.error_file] num_failed = len(failed) if num_failed > 0: messages.append("Number of errored trials: {}".format(num_failed)) + max_rows = max_rows or float("inf") if num_failed > max_rows: messages.append("Table truncated to {} rows ({} overflow)".format( max_rows, num_failed - max_rows)) @@ -196,7 +344,7 @@ def _fair_filter_trials(trials_by_state, max_trials): The oldest trials are truncated if necessary. Args: - trials_by_state (Dict[str, List[Trial]]: Trials by state. + trials_by_state (dict[str, list[Trial]]: Trials by state. max_trials (int): Maximum number of trials to return. Returns: Dict mapping state to List of fairly represented trials. @@ -234,8 +382,8 @@ def _get_trial_info(trial, parameters, metrics): Args: trial (Trial): Trial to get information for. - parameters (List[str]): Names of trial parameters to include. - metrics (List[str]): Names of metrics to include. + parameters (list[str]): Names of trial parameters to include. + metrics (list[str]): Names of metrics to include. 
""" result = flatten_dict(trial.last_result) trial_info = [str(trial), trial.status, str(trial.location)] diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py index fd3031599..95ff33cd5 100644 --- a/python/ray/tune/tests/test_progress_reporter.py +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -4,7 +4,7 @@ import unittest from unittest.mock import MagicMock from ray.tune.trial import Trial -from ray.tune.progress_reporter import _fair_filter_trials +from ray.tune.progress_reporter import CLIReporter, _fair_filter_trials class ProgressReporterTest(unittest.TestCase): @@ -48,3 +48,22 @@ class ProgressReporterTest(unittest.TestCase): for i in range(len(state_trials) - 1): self.assertGreaterEqual(state_trials[i].start_time, state_trials[i + 1].start_time) + + def testAddMetricColumn(self): + """Tests edge cases of add_metric_column.""" + + # Test list-initialized metric columns. + reporter = CLIReporter(metric_columns=["foo", "bar"]) + with self.assertRaises(ValueError): + reporter.add_metric_column("bar") + + with self.assertRaises(ValueError): + reporter.add_metric_column("baz", "qux") + + reporter.add_metric_column("baz") + self.assertIn("baz", reporter._metric_columns) + + # Test default-initialized (dict) metric columns. 
+ reporter = CLIReporter() + reporter.add_metric_column("foo", "bar") + self.assertIn("foo", reporter._metric_columns) diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 89a5451cd..202c6eb41 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -1,12 +1,11 @@ import logging -import time import six from ray.tune.error import TuneError from ray.tune.experiment import convert_to_experiment_list, Experiment from ray.tune.analysis import ExperimentAnalysis from ray.tune.suggest import BasicVariantGenerator -from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL +from ray.tune.trial import Trial from ray.tune.trainable import Trainable from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.registry import get_trainable_cls @@ -51,6 +50,21 @@ def _check_default_resources_override(run_identifier): Trainable.default_resource_request.__code__) +def _report_progress(runner, reporter, done=False): + """Reports experiment progress. + + Args: + runner (TrialRunner): Trial runner to report on. + reporter (ProgressReporter): Progress reporter. + done (bool): Whether this is the last progress report attempt. + """ + trials = runner.get_trials() + if reporter.should_report(trials, done=done): + sched_debug_str = runner.scheduler_alg.debug_string() + executor_debug_str = runner.trial_executor.debug_string() + reporter.report(trials, sched_debug_str, executor_debug_str) + + def run(run_or_experiment, name=None, stop=None, @@ -77,6 +91,7 @@ def run(run_or_experiment, with_server=False, server_port=TuneServer.DEFAULT_PORT, verbose=2, + progress_reporter=None, resume=False, queue_trials=False, reuse_actors=False, @@ -169,6 +184,10 @@ def run(run_or_experiment, server_port (int): Port number for launching TuneServer. verbose (int): 0, 1, or 2. Verbosity mode. 0 = silent, 1 = only status updates, 2 = status and trial results. + progress_reporter (ProgressReporter): Progress reporter for reporting + intermediate experiment progress. 
Defaults to CLIReporter if + running in command-line, or JupyterNotebookReporter if running in + a Jupyter notebook. resume (str|bool): One of "LOCAL", "REMOTE", "PROMPT", or bool. LOCAL/True restores the checkpoint from the local_checkpoint_dir. REMOTE restores the checkpoint from remote_checkpoint_dir. @@ -272,10 +291,11 @@ def run(run_or_experiment, for exp in experiments: runner.add_experiment(exp) - if IS_NOTEBOOK: - reporter = JupyterNotebookReporter(overwrite=verbose < 2) - else: - reporter = CLIReporter() + if progress_reporter is None: + if IS_NOTEBOOK: + progress_reporter = JupyterNotebookReporter(overwrite=verbose < 2) + else: + progress_reporter = CLIReporter() # User Warning for GPUs if trial_executor.has_gpus(): @@ -295,13 +315,10 @@ def run(run_or_experiment, "`Trainable.default_resource_request` if using the " "Trainable API.") - last_debug = 0 while not runner.is_finished(): runner.step() - if time.time() - last_debug > DEBUG_PRINT_INTERVAL: - if verbose: - reporter.report(runner) - last_debug = time.time() + if verbose: + _report_progress(runner, progress_reporter) try: runner.checkpoint(force=True) @@ -309,7 +326,7 @@ def run(run_or_experiment, logger.exception("Trial Runner checkpointing failed.") if verbose: - reporter.report(runner) + _report_progress(runner, progress_reporter, done=True) wait_for_sync() @@ -339,6 +356,7 @@ def run_experiments(experiments, with_server=False, server_port=TuneServer.DEFAULT_PORT, verbose=2, + progress_reporter=None, resume=False, queue_trials=False, reuse_actors=False, @@ -380,6 +398,7 @@ def run_experiments(experiments, with_server=with_server, server_port=server_port, verbose=verbose, + progress_reporter=progress_reporter, resume=resume, queue_trials=queue_trials, reuse_actors=reuse_actors, @@ -396,6 +415,7 @@ def run_experiments(experiments, with_server=with_server, server_port=server_port, verbose=verbose, + progress_reporter=progress_reporter, resume=resume, queue_trials=queue_trials, 
reuse_actors=reuse_actors,