[tune] add filter flag for Tune CLI (#4337)

## What do these changes do?

Adds a filter flag (`--filter`) to the `ls` / `lsx` commands of the Tune CLI.

Usage: `tune ls [path] --filter "[column] [operator] [value]"` (quote the expression so it is passed as a single argument)
e.g. `tune lsx ~/ray_results/my_project --filter "total_trials == 1"`
This commit is contained in:
Andrew Tan
2019-03-27 11:19:25 -07:00
committed by Richard Liaw
parent c6f12e5219
commit 12db684f72
4 changed files with 107 additions and 12 deletions
+48 -3
View File
@@ -8,9 +8,11 @@ import logging
import os
import sys
import subprocess
import operator
from datetime import datetime
import pandas as pd
from pandas.api.types import is_string_dtype, is_numeric_dtype
from ray.tune.util import flatten_dict
from ray.tune.result import TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS
from ray.tune.trial import Trial
@@ -21,6 +23,8 @@ except ImportError:
logger = logging.getLogger(__name__)
EDITOR = os.getenv("EDITOR", "vim")
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"
DEFAULT_EXPERIMENT_INFO_KEYS = (
@@ -48,7 +52,14 @@ try:
except subprocess.CalledProcessError:
TERM_HEIGHT, TERM_WIDTH = 100, 100
EDITOR = os.getenv("EDITOR", "vim")
# Comparison operators accepted in a filter expression, keyed by the literal
# operator token and mapped to the matching ``operator`` module function.
OPERATORS = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}
def _check_tabulate():
@@ -117,6 +128,7 @@ def _get_experiment_state(experiment_path, exit_on_fail=False):
def list_trials(experiment_path,
sort=None,
output=None,
filter_op=None,
info_keys=DEFAULT_EXPERIMENT_INFO_KEYS,
result_keys=DEFAULT_RESULT_KEYS):
"""Lists trials in the directory subtree starting at the given path.
@@ -126,6 +138,8 @@ def list_trials(experiment_path,
Corresponds to Experiment.local_dir/Experiment.name.
sort (str): Key to sort by.
output (str): Name of file where output is saved.
filter_op (str): Filter operation in the format
"<column> <operator> <value>".
info_keys (list): Keys that are displayed.
result_keys (list): Keys of last result that are displayed.
"""
@@ -156,6 +170,21 @@ def list_trials(experiment_path,
checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
experiment_path, '')
if filter_op:
    # Expected format: "<column> <operator> <value>". Strip and split on
    # arbitrary whitespace at most twice, so repeated/trailing spaces are
    # tolerated and the value itself may contain spaces.
    filter_parts = filter_op.strip().split(None, 2)
    if len(filter_parts) != 3:
        raise ValueError(
            "Filter '{}' is not in the format "
            "'<column> <operator> <value>'.".format(filter_op))
    col, op, val = filter_parts
    if col not in checkpoints_df:
        raise KeyError("Filter column '{}' not in: {}".format(
            col, list(checkpoints_df)))
    if op not in OPERATORS:
        raise KeyError("Filter operator '{}' not in: {}".format(
            op, list(OPERATORS)))
    # Coerce the textual value to the column's dtype before comparing.
    col_type = checkpoints_df[col].dtype
    if is_numeric_dtype(col_type):
        val = float(val)
    elif is_string_dtype(col_type):
        val = str(val)
    # TODO(Andrew): add support for datetime and boolean
    else:
        # Report the column name: the dtype belongs to the column.
        raise ValueError("Unsupported dtype for '{}': {}".format(
            col, col_type))
    # The comparison yields a boolean mask used to select matching rows.
    filtered_index = OPERATORS[op](checkpoints_df[col], val)
    checkpoints_df = checkpoints_df[filtered_index]
if sort:
if sort not in checkpoints_df:
raise KeyError("Sort Index '{}' not in: {}".format(
@@ -180,6 +209,7 @@ def list_trials(experiment_path,
def list_experiments(project_path,
sort=None,
output=None,
filter_op=None,
info_keys=DEFAULT_PROJECT_INFO_KEYS):
"""Lists experiments in the directory subtree.
@@ -188,6 +218,8 @@ def list_experiments(project_path,
Corresponds to Experiment.local_dir.
sort (str): Key to sort by.
output (str): Name of file where output is saved.
filter_op (str): Filter operation in the format
"<column> <operator> <value>".
info_keys (list): Keys that are displayed.
"""
_check_tabulate()
@@ -234,13 +266,26 @@ def list_experiments(project_path,
info_df = pd.DataFrame(experiment_data_collection)
col_keys = [k for k in list(info_keys) if k in info_df]
if not col_keys:
print("None of keys {} in experiment data!".format(info_keys))
sys.exit(0)
info_df = info_df[col_keys]
if filter_op:
    # Expected format: "<column> <operator> <value>". Strip and split on
    # arbitrary whitespace at most twice, so repeated/trailing spaces are
    # tolerated and the value itself may contain spaces.
    filter_parts = filter_op.strip().split(None, 2)
    if len(filter_parts) != 3:
        raise ValueError(
            "Filter '{}' is not in the format "
            "'<column> <operator> <value>'.".format(filter_op))
    col, op, val = filter_parts
    if col not in info_df:
        raise KeyError("Filter column '{}' not in: {}".format(
            col, list(info_df)))
    if op not in OPERATORS:
        raise KeyError("Filter operator '{}' not in: {}".format(
            op, list(OPERATORS)))
    # Coerce the textual value to the column's dtype before comparing.
    col_type = info_df[col].dtype
    if is_numeric_dtype(col_type):
        val = float(val)
    elif is_string_dtype(col_type):
        val = str(val)
    # TODO(Andrew): add support for datetime and boolean
    else:
        # Report the column name: the dtype belongs to the column.
        raise ValueError("Unsupported dtype for '{}': {}".format(
            col, col_type))
    # The comparison yields a boolean mask used to select matching rows.
    filtered_index = OPERATORS[op](info_df[col], val)
    info_df = info_df[filtered_index]
if sort:
if sort not in info_df:
raise KeyError("Sort Index '{}' not in: {}".format(
+20 -6
View File
@@ -20,10 +20,17 @@ def cli():
"-o",
default=None,
type=str,
help="Output information to a pickle file.")
def list_trials(experiment_path, sort, output):
help="Select file to output information to.")
@click.option(
"--filter",
"filter_op",
nargs=1,
default=None,
type=str,
help="Select filter in the format '<column> <operator> <value>'.")
def list_trials(experiment_path, sort, output, filter_op):
"""Lists trials in the directory subtree starting at the given path."""
commands.list_trials(experiment_path, sort, output)
commands.list_trials(experiment_path, sort, output, filter_op)
@cli.command()
@@ -35,10 +42,17 @@ def list_trials(experiment_path, sort, output):
"-o",
default=None,
type=str,
help="Select filename to output information to.")
def list_experiments(project_path, sort, output):
help="Select file to output information to.")
@click.option(
"--filter",
"filter_op",
nargs=1,
default=None,
type=str,
help="Select filter in the format '<column> <operator> <value>'.")
def list_experiments(project_path, sort, output, filter_op):
"""Lists experiments in the directory subtree."""
commands.list_experiments(project_path, sort, output)
commands.list_experiments(project_path, sort, output, filter_op)
@cli.command()
+18 -1
View File
@@ -58,6 +58,14 @@ def test_ls(start_ray, tmpdir):
lines = output.captured
assert sum("TERMINATED" in line for line in lines) == num_samples
with Capturing() as output:
commands.list_trials(
experiment_path,
info_keys=("status", ),
filter_op="status == TERMINATED")
lines = output.captured
assert sum("TERMINATED" in line for line in lines) == num_samples
def test_lsx(start_ray, tmpdir):
"""This test captures output of list_experiments."""
@@ -79,4 +87,13 @@ def test_lsx(start_ray, tmpdir):
with Capturing() as output:
commands.list_experiments(project_path, info_keys=("total_trials", ))
lines = output.captured
assert sum("1" in line for line in lines) >= 3
assert sum("1" in line for line in lines) >= num_experiments
with Capturing() as output:
commands.list_experiments(
project_path,
info_keys=("total_trials", ),
filter_op="total_trials == 1")
lines = output.captured
assert sum("1" in line for line in lines) >= num_experiments
assert len(lines) == 3 + num_experiments + 1