diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index 16942a73f..c5fb221a0 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -72,6 +72,13 @@ py_test( tags = ["jenkins_only"], ) +py_test( + name = "test_progress_reporter", + size = "small", + srcs = ["tests/test_progress_reporter.py"], + deps = [":tune_lib"], +) + py_test( name = "test_ray_trial_executor", size = "medium", diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index 6496847a4..832473993 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -1,5 +1,7 @@ from __future__ import print_function +import collections + from ray.tune.result import (DEFAULT_RESULT_KEYS, CONFIG_PREFIX, EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS, TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL) @@ -25,6 +27,8 @@ REPORTED_REPRESENTATIONS = { class ProgressReporter(object): + # TODO(ujvl): Expose ProgressReporter in tune.run for custom reporting. + def report(self, trial_runner): """Reports progress across all trials of the trial runner. @@ -49,7 +53,8 @@ class JupyterNotebookReporter(ProgressReporter): "== Status ==", memory_debug_str(), trial_runner.debug_string(delim=delim), - trial_progress_str(trial_runner.get_trials(), fmt="html") + trial_progress_str(trial_runner.get_trials(), fmt="html"), + trial_errors_str(trial_runner.get_trials(), fmt="html"), ] from IPython.display import clear_output from IPython.core.display import display, HTML @@ -64,7 +69,8 @@ class CLIReporter(ProgressReporter): "== Status ==", memory_debug_str(), trial_runner.debug_string(), - trial_progress_str(trial_runner.get_trials()) + trial_progress_str(trial_runner.get_trials()), + trial_errors_str(trial_runner.get_trials()), ] print("\n".join(messages) + "\n") @@ -90,7 +96,7 @@ def memory_debug_str(): "(or ray[debug]) to resolve)") -def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=100): +def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=20): """Returns a human readable message for printing to the console. This contains a table where each row represents a trial, its parameters @@ -109,52 +115,116 @@ def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=100): return delim.join(messages) num_trials = len(trials) - trials_per_state = {} + trials_by_state = collections.defaultdict(list) for t in trials: - trials_per_state[t.status] = trials_per_state.get(t.status, 0) + 1 - messages.append("Number of trials: {} ({})".format(num_trials, - trials_per_state)) + trials_by_state[t.status].append(t) + for local_dir in sorted({t.local_dir for t in trials}): messages.append("Result logdir: {}".format(local_dir)) + num_trials_strs = [ + "{} {}".format(len(trials_by_state[state]), state) + for state in trials_by_state + ] + messages.append("Number of trials: {} ({})".format( + num_trials, ", ".join(num_trials_strs))) + if num_trials > max_rows: - overflow = num_trials - max_rows # TODO(ujvl): suggestion for users to view more rows. - messages.append("Table truncated to {} rows ({} overflow).".format( - max_rows, overflow)) + trials_by_state_trunc = _fair_filter_trials(trials_by_state, max_rows) + trials = [] + overflow_strs = [] + for state in trials_by_state: + trials += trials_by_state_trunc[state] + overflow = len(trials_by_state[state]) - len( + trials_by_state_trunc[state]) + overflow_strs.append("{} {}".format(overflow, state)) + # Build overflow string. + overflow = num_trials - max_rows + overflow_str = ", ".join(overflow_strs) + messages.append("Table truncated to {} rows. {} trials ({}) not " + "shown.".format(max_rows, overflow, overflow_str)) # Pre-process trials to figure out what columns to show. keys = list(metrics or DEFAULT_PROGRESS_KEYS) keys = [k for k in keys if any(t.last_result.get(k) for t in trials)] - # Build trial rows. - trial_table = [] params = list(set().union(*[t.evaluated_params for t in trials])) - for trial in trials[:min(num_trials, max_rows)]: - trial_table.append(_get_trial_info(trial, params, keys)) + trial_table = [_get_trial_info(trial, params, keys) for trial in trials] # Parse columns. parsed_columns = [REPORTED_REPRESENTATIONS.get(k, k) for k in keys] columns = ["Trial name", "status", "loc"] columns += params + parsed_columns messages.append( tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False)) + return delim.join(messages) - # Build trial error rows. + +def trial_errors_str(trials, fmt="psql", max_rows=20): + """Returns a readable message regarding trial errors. + + Args: + trials (List[Trial]): List of trials to get progress string for. + fmt (str): Output format (see tablefmt in tabulate API). + max_rows (int): Maximum number of rows in the error table. + """ + messages = [] failed = [t for t in trials if t.error_file] - if len(failed) > 0: - messages.append("Number of errored trials: {}".format(len(failed))) + num_failed = len(failed) + if num_failed > 0: + messages.append("Number of errored trials: {}".format(num_failed)) + if num_failed > max_rows: + messages.append("Table truncated to {} rows ({} overflow)".format( + max_rows, num_failed - max_rows)) error_table = [] - for trial in failed: + for trial in failed[:max_rows]: row = [str(trial), trial.num_failures, trial.error_file] error_table.append(row) columns = ["Trial name", "# failures", "error file"] messages.append( tabulate( error_table, headers=columns, tablefmt=fmt, showindex=False)) - + delim = "
" if fmt == "html" else "\n" return delim.join(messages) +def _fair_filter_trials(trials_by_state, max_trials): + """Filters trials such that each state is represented fairly. + + The oldest trials are truncated if necessary. + + Args: + trials_by_state (Dict[str, List[Trial]]: Trials by state. + max_trials (int): Maximum number of trials to return. + Returns: + Dict mapping state to List of fairly represented trials. + """ + num_trials_by_state = collections.defaultdict(int) + no_change = False + # Determine number of trials to keep per state. + while max_trials > 0 and not no_change: + no_change = True + for state in trials_by_state: + if num_trials_by_state[state] < len(trials_by_state[state]): + no_change = False + max_trials -= 1 + num_trials_by_state[state] += 1 + # Sort by start time, descending. + sorted_trials_by_state = { + state: sorted( + trials_by_state[state], + reverse=True, + key=lambda t: t.start_time if t.start_time else float("-inf")) + for state in trials_by_state + } + # Truncate oldest trials. + filtered_trials = { + state: sorted_trials_by_state[state][:num_trials_by_state[state]] + for state in trials_by_state + } + return filtered_trials + + def _get_trial_info(trial, parameters, metrics): """Returns the following information about a trial: diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py new file mode 100644 index 000000000..c8407ae5f --- /dev/null +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -0,0 +1,59 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import sys +import time +import unittest + +from ray.tune.trial import Trial +from ray.tune.progress_reporter import _fair_filter_trials + +if sys.version_info >= (3, 3): + from unittest.mock import MagicMock +else: + from mock import MagicMock + + +class ProgressReporterTest(unittest.TestCase): + def mock_trial(self, status, start_time): + mock = MagicMock() + mock.status = status + mock.start_time = start_time + return mock + + def testFairFilterTrials(self): + """Tests that trials are represented fairly.""" + trials_by_state = collections.defaultdict(list) + # States for which trials are under and overrepresented + states_under = (Trial.PAUSED, Trial.ERROR) + states_over = (Trial.PENDING, Trial.RUNNING, Trial.TERMINATED) + + max_trials = 13 + num_trials_under = 2 # num of trials for each underrepresented state + num_trials_over = 10 # num of trials for each overrepresented state + + for state in states_under: + for _ in range(num_trials_under): + trials_by_state[state].append( + self.mock_trial(state, time.time())) + for state in states_over: + for _ in range(num_trials_over): + trials_by_state[state].append( + self.mock_trial(state, time.time())) + + filtered_trials_by_state = _fair_filter_trials( + trials_by_state, max_trials=max_trials) + for state in trials_by_state: + if state in states_under: + expected_num_trials = num_trials_under + else: + expected_num_trials = (max_trials - num_trials_under * + len(states_under)) / len(states_over) + state_trials = filtered_trials_by_state[state] + self.assertEqual(len(state_trials), expected_num_trials) + # Make sure trials are sorted newest-first within state. + for i in range(len(state_trials) - 1): + self.assertGreaterEqual(state_trials[i].start_time, + state_trials[i + 1].start_time) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 1798caa31..7ee547921 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -162,6 +162,7 @@ class Trial(object): self.export_formats = export_formats self.status = Trial.PENDING + self.start_time = None self.logdir = None self.runner = None self.result_logger = None @@ -251,6 +252,12 @@ class Trial(object): """Sets the location of the trial.""" self.address = location + def set_status(self, status): + """Sets the status of the trial.""" + if status == Trial.RUNNING and self.start_time is None: + self.start_time = time.time() + self.status = status + def close_logger(self): """Closes logger.""" if self.result_logger: diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index d8d35e023..859cb23eb 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -41,7 +41,7 @@ class TrialExecutor(object): """ logger.debug("Trial %s: Changing status from %s to %s.", trial, trial.status, status) - trial.status = status + trial.set_status(status) if status in [Trial.TERMINATED, Trial.ERROR]: self.try_checkpoint_metadata(trial)