diff --git a/python/ray/tune/commands.py b/python/ray/tune/commands.py index bf3651c44..a610424da 100644 --- a/python/ray/tune/commands.py +++ b/python/ray/tune/commands.py @@ -27,15 +27,9 @@ EDITOR = os.getenv("EDITOR", "vim") TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)" -DEFAULT_EXPERIMENT_INFO_KEYS = ( - "trainable_name", - "experiment_tag", - "trial_id", - "status", - "last_update_time", -) - -DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS) +DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", "experiment_tag", "trial_id", + "status", "last_update_time", + TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS) DEFAULT_PROJECT_INFO_KEYS = ( "name", @@ -46,6 +40,8 @@ DEFAULT_PROJECT_INFO_KEYS = ( "last_updated", ) +UNNEST_KEYS = ("config", "last_result") + try: TERM_HEIGHT, TERM_WIDTH = subprocess.check_output(["stty", "size"]).split() TERM_HEIGHT, TERM_WIDTH = int(TERM_HEIGHT), int(TERM_WIDTH) @@ -130,35 +126,42 @@ def list_trials(experiment_path, output=None, filter_op=None, info_keys=None, - result_keys=None): + limit=None, + desc=False): """Lists trials in the directory subtree starting at the given path. Args: experiment_path (str): Directory where trials are located. Corresponds to Experiment.local_dir/Experiment.name. - sort (str): Key to sort by. + sort (list): Keys to sort by. output (str): Name of file where output is saved. filter_op (str): Filter operation in the format " ". info_keys (list): Keys that are displayed. - result_keys (list): Keys of last result that are displayed. + limit (int): Number of rows to display. + desc (bool): Sort ascending vs. descending. """ _check_tabulate() experiment_state = _get_experiment_state( experiment_path, exit_on_fail=True) - checkpoint_dicts = experiment_state["checkpoints"] - checkpoint_dicts = [flatten_dict(g) for g in checkpoint_dicts] + checkpoints = experiment_state["checkpoints"] + + checkpoint_dicts = [] + for g in checkpoints: + for key in UNNEST_KEYS: + if key not in g: + continue + unnest_dict = flatten_dict(g.pop(key)) + g.update(unnest_dict) + g = flatten_dict(g) + checkpoint_dicts.append(g) + checkpoints_df = pd.DataFrame(checkpoint_dicts) if not info_keys: info_keys = DEFAULT_EXPERIMENT_INFO_KEYS - if not result_keys: - result_keys = DEFAULT_RESULT_KEYS - result_keys = ["last_result:{}".format(k) for k in result_keys] - col_keys = [ - k for k in list(info_keys) + result_keys if k in checkpoints_df - ] + col_keys = [k for k in list(info_keys) if k in checkpoints_df] checkpoints_df = checkpoints_df[col_keys] if "last_update_time" in checkpoints_df: @@ -183,7 +186,7 @@ def list_trials(experiment_path, val = str(val) # TODO(Andrew): add support for datetime and boolean else: - raise ValueError("Unsupported dtype for \"{}\": {}".format( + raise ValueError("Unsupported dtype for {}: {}".format( val, col_type)) op = OPERATORS[op] filtered_index = op(checkpoints_df[col], val) @@ -191,9 +194,13 @@ def list_trials(experiment_path, if sort: if sort not in checkpoints_df: - raise KeyError("Sort Index \"{}\" not in: {}".format( - sort, list(checkpoints_df))) - checkpoints_df = checkpoints_df.sort_values(by=sort) + raise KeyError("{} not in: {}".format(sort, list(checkpoints_df))) + ascending = not desc + checkpoints_df = checkpoints_df.sort_values( + by=sort, ascending=ascending) + + if limit: + checkpoints_df = checkpoints_df[:limit] print_format_output(checkpoints_df) @@ -212,17 +219,21 @@ def list_experiments(project_path, sort=None, output=None, filter_op=None, - info_keys=None): + info_keys=None, + limit=None, + desc=False): """Lists experiments in the directory subtree. Args: project_path (str): Directory where experiments are located. Corresponds to Experiment.local_dir. - sort (str): Key to sort by. + sort (list): Keys to sort by. output (str): Name of file where output is saved. filter_op (str): Filter operation in the format " ". info_keys (list): Keys that are displayed. + limit (int): Number of rows to display. + desc (bool): Sort ascending vs. descending. """ _check_tabulate() base, experiment_folders, _ = next(os.walk(project_path)) @@ -284,7 +295,7 @@ def list_experiments(project_path, val = str(val) # TODO(Andrew): add support for datetime and boolean else: - raise ValueError("Unsupported dtype for \"{}\": {}".format( + raise ValueError("Unsupported dtype for {}: {}".format( val, col_type)) op = OPERATORS[op] filtered_index = op(info_df[col], val) @@ -292,9 +303,12 @@ def list_experiments(project_path, if sort: if sort not in info_df: - raise KeyError("Sort Index \"{}\" not in: {}".format( - sort, list(info_df))) - info_df = info_df.sort_values(by=sort) + raise KeyError("{} not in: {}".format(sort, list(info_df))) + ascending = not desc + info_df = info_df.sort_values(by=sort, ascending=ascending) + + if limit: + info_df = info_df[:limit] print_format_output(info_df) diff --git a/python/ray/tune/scripts.py b/python/ray/tune/scripts.py index 8eb166c84..344087c22 100644 --- a/python/ray/tune/scripts.py +++ b/python/ray/tune/scripts.py @@ -24,7 +24,6 @@ def cli(): @click.option( "--filter", "filter_op", - nargs=1, default=None, type=str, help="Select filter in the format ' '.") @@ -34,20 +33,21 @@ def cli(): type=str, help="Select columns to be displayed.") @click.option( - "--result-columns", - "result_columns", + "--limit", default=None, - type=str, - help="Select columns of last result to be displayed.") -def list_trials(experiment_path, sort, output, filter_op, columns, - result_columns): + type=int, + help="Select number of rows to display.") +@click.option( + "--desc", default=False, type=bool, help="Sort ascending vs. descending.") +def list_trials(experiment_path, sort, output, filter_op, columns, limit, + desc): """Lists trials in the directory subtree starting at the given path.""" + if sort: + sort = sort.split(",") if columns: columns = columns.split(",") - if result_columns: - result_columns = result_columns.split(",") commands.list_trials(experiment_path, sort, output, filter_op, columns, - result_columns) + limit, desc) @cli.command() @@ -63,7 +63,6 @@ def list_trials(experiment_path, sort, output, filter_op, columns, @click.option( "--filter", "filter_op", - nargs=1, default=None, type=str, help="Select filter in the format ' '.") @@ -72,11 +71,22 @@ def list_trials(experiment_path, sort, output, filter_op, columns, default=None, type=str, help="Select columns to be displayed.") -def list_experiments(project_path, sort, output, filter_op, columns): +@click.option( + "--limit", + default=None, + type=int, + help="Select number of rows to display.") +@click.option( + "--desc", default=False, type=bool, help="Sort ascending vs. descending.") +def list_experiments(project_path, sort, output, filter_op, columns, limit, + desc): """Lists experiments in the directory subtree.""" + if sort: + sort = sort.split(",") if columns: columns = columns.split(",") - commands.list_experiments(project_path, sort, output, filter_op, columns) + commands.list_experiments(project_path, sort, output, filter_op, columns, + limit, desc) @cli.command() diff --git a/python/ray/tune/tests/test_commands.py b/python/ray/tune/tests/test_commands.py index 2f3e1e214..ab27cc65d 100644 --- a/python/ray/tune/tests/test_commands.py +++ b/python/ray/tune/tests/test_commands.py @@ -66,7 +66,7 @@ def test_ls(start_ray, tmpdir): """This test captures output of list_trials.""" experiment_name = "test_ls" experiment_path = os.path.join(str(tmpdir), experiment_name) - num_samples = 2 + num_samples = 3 tune.run_experiments({ experiment_name: { "run": "__fake", @@ -78,19 +78,14 @@ def test_ls(start_ray, tmpdir): } }) - with Capturing() as output: - commands.list_trials( - experiment_path, - info_keys=("status", ), - result_keys=( - "episode_reward_mean", - "training_iteration", - )) - lines = output.captured - assert sum("TERMINATED" in line for line in lines) == num_samples columns = ["status", "episode_reward_mean", "training_iteration"] + limit = 2 + with Capturing() as output: + commands.list_trials(experiment_path, info_keys=columns, limit=limit) + lines = output.captured assert all(col in lines[1] for col in columns) - assert lines[1].count("|") == 4 + assert lines[1].count("|") == len(columns) + 1 + assert len(lines) == 3 + limit + 1 with Capturing() as output: commands.list_trials( @@ -99,6 +94,7 @@ def test_ls(start_ray, tmpdir): filter_op="status == TERMINATED") lines = output.captured assert sum("TERMINATED" in line for line in lines) == num_samples + assert len(lines) == 3 + num_samples + 1 def test_lsx(start_ray, tmpdir): @@ -118,12 +114,14 @@ def test_lsx(start_ray, tmpdir): } }) + limit = 2 with Capturing() as output: - commands.list_experiments(project_path, info_keys=("total_trials", )) + commands.list_experiments( + project_path, info_keys=("total_trials", ), limit=limit) lines = output.captured - assert sum("1" in line for line in lines) >= num_experiments assert "total_trials" in lines[1] assert lines[1].count("|") == 2 + assert len(lines) == 3 + limit + 1 with Capturing() as output: commands.list_experiments(