[tune] Tune CLI Fixes (#4659)

What do these changes do?
  Add --limit flag for ls
  Add ordering functionality to --sort flag
  Remove last_result from the names of columns for ls
  Fix weird double quote error messages (\")
This commit is contained in:
Andrew Tan
2019-04-30 18:21:33 -07:00
committed by Richard Liaw
parent 448a7bd08d
commit 23ae73135e
3 changed files with 79 additions and 57 deletions
+44 -30
View File
@@ -27,15 +27,9 @@ EDITOR = os.getenv("EDITOR", "vim")
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"
DEFAULT_EXPERIMENT_INFO_KEYS = (
"trainable_name",
"experiment_tag",
"trial_id",
"status",
"last_update_time",
)
DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS)
DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", "experiment_tag", "trial_id",
"status", "last_update_time",
TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS)
DEFAULT_PROJECT_INFO_KEYS = (
"name",
@@ -46,6 +40,8 @@ DEFAULT_PROJECT_INFO_KEYS = (
"last_updated",
)
UNNEST_KEYS = ("config", "last_result")
try:
TERM_HEIGHT, TERM_WIDTH = subprocess.check_output(["stty", "size"]).split()
TERM_HEIGHT, TERM_WIDTH = int(TERM_HEIGHT), int(TERM_WIDTH)
@@ -130,35 +126,42 @@ def list_trials(experiment_path,
output=None,
filter_op=None,
info_keys=None,
result_keys=None):
limit=None,
desc=False):
"""Lists trials in the directory subtree starting at the given path.
Args:
experiment_path (str): Directory where trials are located.
Corresponds to Experiment.local_dir/Experiment.name.
sort (str): Key to sort by.
sort (list): Keys to sort by.
output (str): Name of file where output is saved.
filter_op (str): Filter operation in the format
"<column> <operator> <value>".
info_keys (list): Keys that are displayed.
result_keys (list): Keys of last result that are displayed.
limit (int): Number of rows to display.
desc (bool): Sort ascending vs. descending.
"""
_check_tabulate()
experiment_state = _get_experiment_state(
experiment_path, exit_on_fail=True)
checkpoint_dicts = experiment_state["checkpoints"]
checkpoint_dicts = [flatten_dict(g) for g in checkpoint_dicts]
checkpoints = experiment_state["checkpoints"]
checkpoint_dicts = []
for g in checkpoints:
for key in UNNEST_KEYS:
if key not in g:
continue
unnest_dict = flatten_dict(g.pop(key))
g.update(unnest_dict)
g = flatten_dict(g)
checkpoint_dicts.append(g)
checkpoints_df = pd.DataFrame(checkpoint_dicts)
if not info_keys:
info_keys = DEFAULT_EXPERIMENT_INFO_KEYS
if not result_keys:
result_keys = DEFAULT_RESULT_KEYS
result_keys = ["last_result:{}".format(k) for k in result_keys]
col_keys = [
k for k in list(info_keys) + result_keys if k in checkpoints_df
]
col_keys = [k for k in list(info_keys) if k in checkpoints_df]
checkpoints_df = checkpoints_df[col_keys]
if "last_update_time" in checkpoints_df:
@@ -183,7 +186,7 @@ def list_trials(experiment_path,
val = str(val)
# TODO(Andrew): add support for datetime and boolean
else:
raise ValueError("Unsupported dtype for \"{}\": {}".format(
raise ValueError("Unsupported dtype for {}: {}".format(
val, col_type))
op = OPERATORS[op]
filtered_index = op(checkpoints_df[col], val)
@@ -191,9 +194,13 @@ def list_trials(experiment_path,
if sort:
if sort not in checkpoints_df:
raise KeyError("Sort Index \"{}\" not in: {}".format(
sort, list(checkpoints_df)))
checkpoints_df = checkpoints_df.sort_values(by=sort)
raise KeyError("{} not in: {}".format(sort, list(checkpoints_df)))
ascending = not desc
checkpoints_df = checkpoints_df.sort_values(
by=sort, ascending=ascending)
if limit:
checkpoints_df = checkpoints_df[:limit]
print_format_output(checkpoints_df)
@@ -212,17 +219,21 @@ def list_experiments(project_path,
sort=None,
output=None,
filter_op=None,
info_keys=None):
info_keys=None,
limit=None,
desc=False):
"""Lists experiments in the directory subtree.
Args:
project_path (str): Directory where experiments are located.
Corresponds to Experiment.local_dir.
sort (str): Key to sort by.
sort (list): Keys to sort by.
output (str): Name of file where output is saved.
filter_op (str): Filter operation in the format
"<column> <operator> <value>".
info_keys (list): Keys that are displayed.
limit (int): Number of rows to display.
desc (bool): Sort ascending vs. descending.
"""
_check_tabulate()
base, experiment_folders, _ = next(os.walk(project_path))
@@ -284,7 +295,7 @@ def list_experiments(project_path,
val = str(val)
# TODO(Andrew): add support for datetime and boolean
else:
raise ValueError("Unsupported dtype for \"{}\": {}".format(
raise ValueError("Unsupported dtype for {}: {}".format(
val, col_type))
op = OPERATORS[op]
filtered_index = op(info_df[col], val)
@@ -292,9 +303,12 @@ def list_experiments(project_path,
if sort:
if sort not in info_df:
raise KeyError("Sort Index \"{}\" not in: {}".format(
sort, list(info_df)))
info_df = info_df.sort_values(by=sort)
raise KeyError("{} not in: {}".format(sort, list(info_df)))
ascending = not desc
info_df = info_df.sort_values(by=sort, ascending=ascending)
if limit:
info_df = info_df[:limit]
print_format_output(info_df)
+23 -13
View File
@@ -24,7 +24,6 @@ def cli():
@click.option(
"--filter",
"filter_op",
nargs=1,
default=None,
type=str,
help="Select filter in the format '<column> <operator> <value>'.")
@@ -34,20 +33,21 @@ def cli():
type=str,
help="Select columns to be displayed.")
@click.option(
"--result-columns",
"result_columns",
"--limit",
default=None,
type=str,
help="Select columns of last result to be displayed.")
def list_trials(experiment_path, sort, output, filter_op, columns,
result_columns):
type=int,
help="Select number of rows to display.")
@click.option(
"--desc", default=False, type=bool, help="Sort ascending vs. descending.")
def list_trials(experiment_path, sort, output, filter_op, columns, limit,
desc):
"""Lists trials in the directory subtree starting at the given path."""
if sort:
sort = sort.split(",")
if columns:
columns = columns.split(",")
if result_columns:
result_columns = result_columns.split(",")
commands.list_trials(experiment_path, sort, output, filter_op, columns,
result_columns)
limit, desc)
@cli.command()
@@ -63,7 +63,6 @@ def list_trials(experiment_path, sort, output, filter_op, columns,
@click.option(
"--filter",
"filter_op",
nargs=1,
default=None,
type=str,
help="Select filter in the format '<column> <operator> <value>'.")
@@ -72,11 +71,22 @@ def list_trials(experiment_path, sort, output, filter_op, columns,
default=None,
type=str,
help="Select columns to be displayed.")
def list_experiments(project_path, sort, output, filter_op, columns):
@click.option(
"--limit",
default=None,
type=int,
help="Select number of rows to display.")
@click.option(
"--desc", default=False, type=bool, help="Sort ascending vs. descending.")
def list_experiments(project_path, sort, output, filter_op, columns, limit,
desc):
"""Lists experiments in the directory subtree."""
if sort:
sort = sort.split(",")
if columns:
columns = columns.split(",")
commands.list_experiments(project_path, sort, output, filter_op, columns)
commands.list_experiments(project_path, sort, output, filter_op, columns,
limit, desc)
@cli.command()
+12 -14
View File
@@ -66,7 +66,7 @@ def test_ls(start_ray, tmpdir):
"""This test captures output of list_trials."""
experiment_name = "test_ls"
experiment_path = os.path.join(str(tmpdir), experiment_name)
num_samples = 2
num_samples = 3
tune.run_experiments({
experiment_name: {
"run": "__fake",
@@ -78,19 +78,14 @@ def test_ls(start_ray, tmpdir):
}
})
with Capturing() as output:
commands.list_trials(
experiment_path,
info_keys=("status", ),
result_keys=(
"episode_reward_mean",
"training_iteration",
))
lines = output.captured
assert sum("TERMINATED" in line for line in lines) == num_samples
columns = ["status", "episode_reward_mean", "training_iteration"]
limit = 2
with Capturing() as output:
commands.list_trials(experiment_path, info_keys=columns, limit=limit)
lines = output.captured
assert all(col in lines[1] for col in columns)
assert lines[1].count("|") == 4
assert lines[1].count("|") == len(columns) + 1
assert len(lines) == 3 + limit + 1
with Capturing() as output:
commands.list_trials(
@@ -99,6 +94,7 @@ def test_ls(start_ray, tmpdir):
filter_op="status == TERMINATED")
lines = output.captured
assert sum("TERMINATED" in line for line in lines) == num_samples
assert len(lines) == 3 + num_samples + 1
def test_lsx(start_ray, tmpdir):
@@ -118,12 +114,14 @@ def test_lsx(start_ray, tmpdir):
}
})
limit = 2
with Capturing() as output:
commands.list_experiments(project_path, info_keys=("total_trials", ))
commands.list_experiments(
project_path, info_keys=("total_trials", ), limit=limit)
lines = output.captured
assert sum("1" in line for line in lines) >= num_experiments
assert "total_trials" in lines[1]
assert lines[1].count("|") == 2
assert len(lines) == 3 + limit + 1
with Capturing() as output:
commands.list_experiments(