mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 17:23:09 +08:00
[tune] Add --columns flag for CLI (#4564)
This commit is contained in:
@@ -129,8 +129,8 @@ def list_trials(experiment_path,
|
||||
sort=None,
|
||||
output=None,
|
||||
filter_op=None,
|
||||
info_keys=DEFAULT_EXPERIMENT_INFO_KEYS,
|
||||
result_keys=DEFAULT_RESULT_KEYS):
|
||||
info_keys=None,
|
||||
result_keys=None):
|
||||
"""Lists trials in the directory subtree starting at the given path.
|
||||
|
||||
Args:
|
||||
@@ -151,6 +151,10 @@ def list_trials(experiment_path,
|
||||
checkpoint_dicts = [flatten_dict(g) for g in checkpoint_dicts]
|
||||
checkpoints_df = pd.DataFrame(checkpoint_dicts)
|
||||
|
||||
if not info_keys:
|
||||
info_keys = DEFAULT_EXPERIMENT_INFO_KEYS
|
||||
if not result_keys:
|
||||
result_keys = DEFAULT_RESULT_KEYS
|
||||
result_keys = ["last_result:{}".format(k) for k in result_keys]
|
||||
col_keys = [
|
||||
k for k in list(info_keys) + result_keys if k in checkpoints_df
|
||||
@@ -208,7 +212,7 @@ def list_experiments(project_path,
|
||||
sort=None,
|
||||
output=None,
|
||||
filter_op=None,
|
||||
info_keys=DEFAULT_PROJECT_INFO_KEYS):
|
||||
info_keys=None):
|
||||
"""Lists experiments in the directory subtree.
|
||||
|
||||
Args:
|
||||
@@ -263,6 +267,8 @@ def list_experiments(project_path,
|
||||
sys.exit(0)
|
||||
|
||||
info_df = pd.DataFrame(experiment_data_collection)
|
||||
if not info_keys:
|
||||
info_keys = DEFAULT_PROJECT_INFO_KEYS
|
||||
col_keys = [k for k in list(info_keys) if k in info_df]
|
||||
if not col_keys:
|
||||
print("None of keys {} in experiment data!".format(info_keys))
|
||||
|
||||
@@ -28,9 +28,26 @@ def cli():
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select filter in the format '<column> <operator> <value>'.")
|
||||
def list_trials(experiment_path, sort, output, filter_op):
|
||||
@click.option(
|
||||
"--columns",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select columns to be displayed.")
|
||||
@click.option(
|
||||
"--result-columns",
|
||||
"result_columns",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select columns of last result to be displayed.")
|
||||
def list_trials(experiment_path, sort, output, filter_op, columns,
|
||||
result_columns):
|
||||
"""Lists trials in the directory subtree starting at the given path."""
|
||||
commands.list_trials(experiment_path, sort, output, filter_op)
|
||||
if columns:
|
||||
columns = columns.split(',')
|
||||
if result_columns:
|
||||
result_columns = result_columns.split(',')
|
||||
commands.list_trials(experiment_path, sort, output, filter_op, columns,
|
||||
result_columns)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -50,9 +67,16 @@ def list_trials(experiment_path, sort, output, filter_op):
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select filter in the format '<column> <operator> <value>'.")
|
||||
def list_experiments(project_path, sort, output, filter_op):
|
||||
@click.option(
|
||||
"--columns",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select columns to be displayed.")
|
||||
def list_experiments(project_path, sort, output, filter_op, columns):
|
||||
"""Lists experiments in the directory subtree."""
|
||||
commands.list_experiments(project_path, sort, output, filter_op)
|
||||
if columns:
|
||||
columns = columns.split(',')
|
||||
commands.list_experiments(project_path, sort, output, filter_op, columns)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -79,9 +79,18 @@ def test_ls(start_ray, tmpdir):
|
||||
})
|
||||
|
||||
with Capturing() as output:
|
||||
commands.list_trials(experiment_path, info_keys=("status", ))
|
||||
commands.list_trials(
|
||||
experiment_path,
|
||||
info_keys=("status", ),
|
||||
result_keys=(
|
||||
"episode_reward_mean",
|
||||
"training_iteration",
|
||||
))
|
||||
lines = output.captured
|
||||
assert sum("TERMINATED" in line for line in lines) == num_samples
|
||||
columns = ["status", "episode_reward_mean", "training_iteration"]
|
||||
assert all(col in lines[1] for col in columns)
|
||||
assert lines[1].count('|') == 4
|
||||
|
||||
with Capturing() as output:
|
||||
commands.list_trials(
|
||||
@@ -113,6 +122,8 @@ def test_lsx(start_ray, tmpdir):
|
||||
commands.list_experiments(project_path, info_keys=("total_trials", ))
|
||||
lines = output.captured
|
||||
assert sum("1" in line for line in lines) >= num_experiments
|
||||
assert "total_trials" in lines[1]
|
||||
assert lines[1].count('|') == 2
|
||||
|
||||
with Capturing() as output:
|
||||
commands.list_experiments(
|
||||
|
||||
Reference in New Issue
Block a user