diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index cbac45287..aa19dd103 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -465,7 +465,7 @@ Tune CLI (Experimental) Here are a few examples of command line calls. -- ``tune list-trials``: List tabular information about trials within an experiment. Add the ``--sort`` flag to sort the output by specific columns. +- ``tune list-trials``: List tabular information about trials within an experiment. Add the ``--sort`` flag to sort the output by specific columns. Add the ``--filter`` flag to filter the output in the format ``"<column> <operator> <value>"``. .. code-block:: bash @@ -482,7 +482,16 @@ Here are a few examples of command line calls. +------------------+-----------------------+------------+ Dropped columns: ['status', 'last_update_time'] -- ``tune list-experiments``: List tabular information about experiments within a project. Add the ``--sort`` flag to sort the output by specific columns. + $ tune list-trials [EXPERIMENT_DIR] --filter "trial_id == 7b99a28a" + + +------------------+-----------------------+------------+ + | trainable_name | experiment_tag | trial_id | + |------------------+-----------------------+------------| + | MyTrainableClass | 3_height=54,width=21 | 7b99a28a | + +------------------+-----------------------+------------+ + Dropped columns: ['status', 'last_update_time'] + +- ``tune list-experiments``: List tabular information about experiments within a project. Add the ``--sort`` flag to sort the output by specific columns. Add the ``--filter`` flag to filter the output in the format ``"<column> <operator> <value>"``. .. code-block:: bash @@ -497,6 +506,16 @@ Here are a few examples of command line calls. 
+----------------------+----------------+------------------+---------------------+ Dropped columns: ['error_trials', 'last_updated'] + $ tune list-experiments [PROJECT_DIR] --filter "total_trials <= 1" --sort name + + +----------------------+----------------+------------------+---------------------+ + | name | total_trials | running_trials | terminated_trials | + |----------------------+----------------+------------------+---------------------| + | hyperband_test | 1 | 0 | 1 | + | test | 1 | 0 | 0 | + +----------------------+----------------+------------------+---------------------+ + Dropped columns: ['error_trials', 'last_updated'] + Further Questions or Issues? ---------------------------- diff --git a/python/ray/tune/commands.py b/python/ray/tune/commands.py index add8cf25a..590a00adc 100644 --- a/python/ray/tune/commands.py +++ b/python/ray/tune/commands.py @@ -8,9 +8,11 @@ import logging import os import sys import subprocess +import operator from datetime import datetime import pandas as pd +from pandas.api.types import is_string_dtype, is_numeric_dtype from ray.tune.util import flatten_dict from ray.tune.result import TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS from ray.tune.trial import Trial @@ -21,6 +23,8 @@ except ImportError: logger = logging.getLogger(__name__) +EDITOR = os.getenv("EDITOR", "vim") + TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)" DEFAULT_EXPERIMENT_INFO_KEYS = ( @@ -48,7 +52,14 @@ try: except subprocess.CalledProcessError: TERM_HEIGHT, TERM_WIDTH = 100, 100 -EDITOR = os.getenv("EDITOR", "vim") +OPERATORS = { + '<': operator.lt, + '<=': operator.le, + '==': operator.eq, + '!=': operator.ne, + '>=': operator.ge, + '>': operator.gt, +} def _check_tabulate(): @@ -117,6 +128,7 @@ def _get_experiment_state(experiment_path, exit_on_fail=False): def list_trials(experiment_path, sort=None, output=None, + filter_op=None, info_keys=DEFAULT_EXPERIMENT_INFO_KEYS, result_keys=DEFAULT_RESULT_KEYS): """Lists trials in the directory subtree starting at 
the given path. @@ -126,6 +138,8 @@ def list_trials(experiment_path, Corresponds to Experiment.local_dir/Experiment.name. sort (str): Key to sort by. output (str): Name of file where output is saved. + filter_op (str): Filter operation in the format + "<column> <operator> <value>". info_keys (list): Keys that are displayed. result_keys (list): Keys of last result that are displayed. """ @@ -156,6 +170,21 @@ def list_trials(experiment_path, checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace( experiment_path, '') + if filter_op: + col, op, val = filter_op.split(' ') + col_type = checkpoints_df[col].dtype + if is_numeric_dtype(col_type): + val = float(val) + elif is_string_dtype(col_type): + val = str(val) + # TODO(Andrew): add support for datetime and boolean + else: + raise ValueError("Unsupported dtype for '{}': {}".format( + val, col_type)) + op = OPERATORS[op] + filtered_index = op(checkpoints_df[col], val) + checkpoints_df = checkpoints_df[filtered_index] + if sort: if sort not in checkpoints_df: raise KeyError("Sort Index '{}' not in: {}".format( @@ -180,6 +209,7 @@ def list_trials(experiment_path, def list_experiments(project_path, sort=None, output=None, + filter_op=None, info_keys=DEFAULT_PROJECT_INFO_KEYS): """Lists experiments in the directory subtree. @@ -188,6 +218,8 @@ def list_experiments(project_path, Corresponds to Experiment.local_dir. sort (str): Key to sort by. output (str): Name of file where output is saved. + filter_op (str): Filter operation in the format + "<column> <operator> <value>". info_keys (list): Keys that are displayed. 
""" _check_tabulate() @@ -234,13 +266,26 @@ def list_experiments(project_path, info_df = pd.DataFrame(experiment_data_collection) col_keys = [k for k in list(info_keys) if k in info_df] - if not col_keys: print("None of keys {} in experiment data!".format(info_keys)) sys.exit(0) - info_df = info_df[col_keys] + if filter_op: + col, op, val = filter_op.split(' ') + col_type = info_df[col].dtype + if is_numeric_dtype(col_type): + val = float(val) + elif is_string_dtype(col_type): + val = str(val) + # TODO(Andrew): add support for datetime and boolean + else: + raise ValueError("Unsupported dtype for '{}': {}".format( + val, col_type)) + op = OPERATORS[op] + filtered_index = op(info_df[col], val) + info_df = info_df[filtered_index] + if sort: if sort not in info_df: raise KeyError("Sort Index '{}' not in: {}".format( diff --git a/python/ray/tune/scripts.py b/python/ray/tune/scripts.py index 97527e5ea..cdf9f7b2a 100644 --- a/python/ray/tune/scripts.py +++ b/python/ray/tune/scripts.py @@ -20,10 +20,17 @@ def cli(): "-o", default=None, type=str, - help="Output information to a pickle file.") -def list_trials(experiment_path, sort, output): + help="Select file to output information to.") +@click.option( + "--filter", + "filter_op", + nargs=1, + default=None, + type=str, + help="Select filter in the format ' '.") +def list_trials(experiment_path, sort, output, filter_op): """Lists trials in the directory subtree starting at the given path.""" - commands.list_trials(experiment_path, sort, output) + commands.list_trials(experiment_path, sort, output, filter_op) @cli.command() @@ -35,10 +42,17 @@ def list_trials(experiment_path, sort, output): "-o", default=None, type=str, - help="Select filename to output information to.") -def list_experiments(project_path, sort, output): + help="Select file to output information to.") +@click.option( + "--filter", + "filter_op", + nargs=1, + default=None, + type=str, + help="Select filter in the format ' '.") +def 
list_experiments(project_path, sort, output, filter_op): """Lists experiments in the directory subtree.""" - commands.list_experiments(project_path, sort, output) + commands.list_experiments(project_path, sort, output, filter_op) @cli.command() diff --git a/python/ray/tune/tests/test_commands.py b/python/ray/tune/tests/test_commands.py index 9b58d3090..a8d76ee22 100644 --- a/python/ray/tune/tests/test_commands.py +++ b/python/ray/tune/tests/test_commands.py @@ -58,6 +58,14 @@ def test_ls(start_ray, tmpdir): lines = output.captured assert sum("TERMINATED" in line for line in lines) == num_samples + with Capturing() as output: + commands.list_trials( + experiment_path, + info_keys=("status", ), + filter_op="status == TERMINATED") + lines = output.captured + assert sum("TERMINATED" in line for line in lines) == num_samples + def test_lsx(start_ray, tmpdir): """This test captures output of list_experiments.""" @@ -79,4 +87,13 @@ def test_lsx(start_ray, tmpdir): with Capturing() as output: commands.list_experiments(project_path, info_keys=("total_trials", )) lines = output.captured - assert sum("1" in line for line in lines) >= 3 + assert sum("1" in line for line in lines) >= num_experiments + + with Capturing() as output: + commands.list_experiments( + project_path, + info_keys=("total_trials", ), + filter_op="total_trials == 1") + lines = output.captured + assert sum("1" in line for line in lines) >= num_experiments + assert len(lines) == 3 + num_experiments + 1