mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:23:17 +08:00
[tune] Tune CLI Fixes (#4659)
What do these changes do? Add --limit flag for ls Add ordering functionality to --sort flag Remove last_result from the names of columns for ls Fix weird double quote error messages (\")
This commit is contained in:
+44
-30
@@ -27,15 +27,9 @@ EDITOR = os.getenv("EDITOR", "vim")
|
||||
|
||||
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"
|
||||
|
||||
DEFAULT_EXPERIMENT_INFO_KEYS = (
|
||||
"trainable_name",
|
||||
"experiment_tag",
|
||||
"trial_id",
|
||||
"status",
|
||||
"last_update_time",
|
||||
)
|
||||
|
||||
DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS)
|
||||
DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", "experiment_tag", "trial_id",
|
||||
"status", "last_update_time",
|
||||
TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS)
|
||||
|
||||
DEFAULT_PROJECT_INFO_KEYS = (
|
||||
"name",
|
||||
@@ -46,6 +40,8 @@ DEFAULT_PROJECT_INFO_KEYS = (
|
||||
"last_updated",
|
||||
)
|
||||
|
||||
UNNEST_KEYS = ("config", "last_result")
|
||||
|
||||
try:
|
||||
TERM_HEIGHT, TERM_WIDTH = subprocess.check_output(["stty", "size"]).split()
|
||||
TERM_HEIGHT, TERM_WIDTH = int(TERM_HEIGHT), int(TERM_WIDTH)
|
||||
@@ -130,35 +126,42 @@ def list_trials(experiment_path,
|
||||
output=None,
|
||||
filter_op=None,
|
||||
info_keys=None,
|
||||
result_keys=None):
|
||||
limit=None,
|
||||
desc=False):
|
||||
"""Lists trials in the directory subtree starting at the given path.
|
||||
|
||||
Args:
|
||||
experiment_path (str): Directory where trials are located.
|
||||
Corresponds to Experiment.local_dir/Experiment.name.
|
||||
sort (str): Key to sort by.
|
||||
sort (list): Keys to sort by.
|
||||
output (str): Name of file where output is saved.
|
||||
filter_op (str): Filter operation in the format
|
||||
"<column> <operator> <value>".
|
||||
info_keys (list): Keys that are displayed.
|
||||
result_keys (list): Keys of last result that are displayed.
|
||||
limit (int): Number of rows to display.
|
||||
desc (bool): Sort ascending vs. descending.
|
||||
"""
|
||||
_check_tabulate()
|
||||
experiment_state = _get_experiment_state(
|
||||
experiment_path, exit_on_fail=True)
|
||||
|
||||
checkpoint_dicts = experiment_state["checkpoints"]
|
||||
checkpoint_dicts = [flatten_dict(g) for g in checkpoint_dicts]
|
||||
checkpoints = experiment_state["checkpoints"]
|
||||
|
||||
checkpoint_dicts = []
|
||||
for g in checkpoints:
|
||||
for key in UNNEST_KEYS:
|
||||
if key not in g:
|
||||
continue
|
||||
unnest_dict = flatten_dict(g.pop(key))
|
||||
g.update(unnest_dict)
|
||||
g = flatten_dict(g)
|
||||
checkpoint_dicts.append(g)
|
||||
|
||||
checkpoints_df = pd.DataFrame(checkpoint_dicts)
|
||||
|
||||
if not info_keys:
|
||||
info_keys = DEFAULT_EXPERIMENT_INFO_KEYS
|
||||
if not result_keys:
|
||||
result_keys = DEFAULT_RESULT_KEYS
|
||||
result_keys = ["last_result:{}".format(k) for k in result_keys]
|
||||
col_keys = [
|
||||
k for k in list(info_keys) + result_keys if k in checkpoints_df
|
||||
]
|
||||
col_keys = [k for k in list(info_keys) if k in checkpoints_df]
|
||||
checkpoints_df = checkpoints_df[col_keys]
|
||||
|
||||
if "last_update_time" in checkpoints_df:
|
||||
@@ -183,7 +186,7 @@ def list_trials(experiment_path,
|
||||
val = str(val)
|
||||
# TODO(Andrew): add support for datetime and boolean
|
||||
else:
|
||||
raise ValueError("Unsupported dtype for \"{}\": {}".format(
|
||||
raise ValueError("Unsupported dtype for {}: {}".format(
|
||||
val, col_type))
|
||||
op = OPERATORS[op]
|
||||
filtered_index = op(checkpoints_df[col], val)
|
||||
@@ -191,9 +194,13 @@ def list_trials(experiment_path,
|
||||
|
||||
if sort:
|
||||
if sort not in checkpoints_df:
|
||||
raise KeyError("Sort Index \"{}\" not in: {}".format(
|
||||
sort, list(checkpoints_df)))
|
||||
checkpoints_df = checkpoints_df.sort_values(by=sort)
|
||||
raise KeyError("{} not in: {}".format(sort, list(checkpoints_df)))
|
||||
ascending = not desc
|
||||
checkpoints_df = checkpoints_df.sort_values(
|
||||
by=sort, ascending=ascending)
|
||||
|
||||
if limit:
|
||||
checkpoints_df = checkpoints_df[:limit]
|
||||
|
||||
print_format_output(checkpoints_df)
|
||||
|
||||
@@ -212,17 +219,21 @@ def list_experiments(project_path,
|
||||
sort=None,
|
||||
output=None,
|
||||
filter_op=None,
|
||||
info_keys=None):
|
||||
info_keys=None,
|
||||
limit=None,
|
||||
desc=False):
|
||||
"""Lists experiments in the directory subtree.
|
||||
|
||||
Args:
|
||||
project_path (str): Directory where experiments are located.
|
||||
Corresponds to Experiment.local_dir.
|
||||
sort (str): Key to sort by.
|
||||
sort (list): Keys to sort by.
|
||||
output (str): Name of file where output is saved.
|
||||
filter_op (str): Filter operation in the format
|
||||
"<column> <operator> <value>".
|
||||
info_keys (list): Keys that are displayed.
|
||||
limit (int): Number of rows to display.
|
||||
desc (bool): Sort ascending vs. descending.
|
||||
"""
|
||||
_check_tabulate()
|
||||
base, experiment_folders, _ = next(os.walk(project_path))
|
||||
@@ -284,7 +295,7 @@ def list_experiments(project_path,
|
||||
val = str(val)
|
||||
# TODO(Andrew): add support for datetime and boolean
|
||||
else:
|
||||
raise ValueError("Unsupported dtype for \"{}\": {}".format(
|
||||
raise ValueError("Unsupported dtype for {}: {}".format(
|
||||
val, col_type))
|
||||
op = OPERATORS[op]
|
||||
filtered_index = op(info_df[col], val)
|
||||
@@ -292,9 +303,12 @@ def list_experiments(project_path,
|
||||
|
||||
if sort:
|
||||
if sort not in info_df:
|
||||
raise KeyError("Sort Index \"{}\" not in: {}".format(
|
||||
sort, list(info_df)))
|
||||
info_df = info_df.sort_values(by=sort)
|
||||
raise KeyError("{} not in: {}".format(sort, list(info_df)))
|
||||
ascending = not desc
|
||||
info_df = info_df.sort_values(by=sort, ascending=ascending)
|
||||
|
||||
if limit:
|
||||
info_df = info_df[:limit]
|
||||
|
||||
print_format_output(info_df)
|
||||
|
||||
|
||||
+23
-13
@@ -24,7 +24,6 @@ def cli():
|
||||
@click.option(
|
||||
"--filter",
|
||||
"filter_op",
|
||||
nargs=1,
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select filter in the format '<column> <operator> <value>'.")
|
||||
@@ -34,20 +33,21 @@ def cli():
|
||||
type=str,
|
||||
help="Select columns to be displayed.")
|
||||
@click.option(
|
||||
"--result-columns",
|
||||
"result_columns",
|
||||
"--limit",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select columns of last result to be displayed.")
|
||||
def list_trials(experiment_path, sort, output, filter_op, columns,
|
||||
result_columns):
|
||||
type=int,
|
||||
help="Select number of rows to display.")
|
||||
@click.option(
|
||||
"--desc", default=False, type=bool, help="Sort ascending vs. descending.")
|
||||
def list_trials(experiment_path, sort, output, filter_op, columns, limit,
|
||||
desc):
|
||||
"""Lists trials in the directory subtree starting at the given path."""
|
||||
if sort:
|
||||
sort = sort.split(",")
|
||||
if columns:
|
||||
columns = columns.split(",")
|
||||
if result_columns:
|
||||
result_columns = result_columns.split(",")
|
||||
commands.list_trials(experiment_path, sort, output, filter_op, columns,
|
||||
result_columns)
|
||||
limit, desc)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -63,7 +63,6 @@ def list_trials(experiment_path, sort, output, filter_op, columns,
|
||||
@click.option(
|
||||
"--filter",
|
||||
"filter_op",
|
||||
nargs=1,
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select filter in the format '<column> <operator> <value>'.")
|
||||
@@ -72,11 +71,22 @@ def list_trials(experiment_path, sort, output, filter_op, columns,
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select columns to be displayed.")
|
||||
def list_experiments(project_path, sort, output, filter_op, columns):
|
||||
@click.option(
|
||||
"--limit",
|
||||
default=None,
|
||||
type=int,
|
||||
help="Select number of rows to display.")
|
||||
@click.option(
|
||||
"--desc", default=False, type=bool, help="Sort ascending vs. descending.")
|
||||
def list_experiments(project_path, sort, output, filter_op, columns, limit,
|
||||
desc):
|
||||
"""Lists experiments in the directory subtree."""
|
||||
if sort:
|
||||
sort = sort.split(",")
|
||||
if columns:
|
||||
columns = columns.split(",")
|
||||
commands.list_experiments(project_path, sort, output, filter_op, columns)
|
||||
commands.list_experiments(project_path, sort, output, filter_op, columns,
|
||||
limit, desc)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -66,7 +66,7 @@ def test_ls(start_ray, tmpdir):
|
||||
"""This test captures output of list_trials."""
|
||||
experiment_name = "test_ls"
|
||||
experiment_path = os.path.join(str(tmpdir), experiment_name)
|
||||
num_samples = 2
|
||||
num_samples = 3
|
||||
tune.run_experiments({
|
||||
experiment_name: {
|
||||
"run": "__fake",
|
||||
@@ -78,19 +78,14 @@ def test_ls(start_ray, tmpdir):
|
||||
}
|
||||
})
|
||||
|
||||
with Capturing() as output:
|
||||
commands.list_trials(
|
||||
experiment_path,
|
||||
info_keys=("status", ),
|
||||
result_keys=(
|
||||
"episode_reward_mean",
|
||||
"training_iteration",
|
||||
))
|
||||
lines = output.captured
|
||||
assert sum("TERMINATED" in line for line in lines) == num_samples
|
||||
columns = ["status", "episode_reward_mean", "training_iteration"]
|
||||
limit = 2
|
||||
with Capturing() as output:
|
||||
commands.list_trials(experiment_path, info_keys=columns, limit=limit)
|
||||
lines = output.captured
|
||||
assert all(col in lines[1] for col in columns)
|
||||
assert lines[1].count("|") == 4
|
||||
assert lines[1].count("|") == len(columns) + 1
|
||||
assert len(lines) == 3 + limit + 1
|
||||
|
||||
with Capturing() as output:
|
||||
commands.list_trials(
|
||||
@@ -99,6 +94,7 @@ def test_ls(start_ray, tmpdir):
|
||||
filter_op="status == TERMINATED")
|
||||
lines = output.captured
|
||||
assert sum("TERMINATED" in line for line in lines) == num_samples
|
||||
assert len(lines) == 3 + num_samples + 1
|
||||
|
||||
|
||||
def test_lsx(start_ray, tmpdir):
|
||||
@@ -118,12 +114,14 @@ def test_lsx(start_ray, tmpdir):
|
||||
}
|
||||
})
|
||||
|
||||
limit = 2
|
||||
with Capturing() as output:
|
||||
commands.list_experiments(project_path, info_keys=("total_trials", ))
|
||||
commands.list_experiments(
|
||||
project_path, info_keys=("total_trials", ), limit=limit)
|
||||
lines = output.captured
|
||||
assert sum("1" in line for line in lines) >= num_experiments
|
||||
assert "total_trials" in lines[1]
|
||||
assert lines[1].count("|") == 2
|
||||
assert len(lines) == 3 + limit + 1
|
||||
|
||||
with Capturing() as output:
|
||||
commands.list_experiments(
|
||||
|
||||
Reference in New Issue
Block a user