mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 06:33:06 +08:00
[tune] add filter flag for Tune CLI (#4337)
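## What do these changes do? Adds a filter flag (`--filter`) to the `ls` / `lsx` commands of the Tune CLI. Usage: `tune ls [path] --filter "[column] [operator] [value]"`, e.g. `tune lsx ~/ray_results/my_project --filter "total_trials == 1"`. The filter expression must be passed as a single quoted argument, with the column, operator, and value separated by single spaces.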
This commit is contained in:
@@ -8,9 +8,11 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import operator
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
from pandas.api.types import is_string_dtype, is_numeric_dtype
|
||||
from ray.tune.util import flatten_dict
|
||||
from ray.tune.result import TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS
|
||||
from ray.tune.trial import Trial
|
||||
@@ -21,6 +23,8 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EDITOR = os.getenv("EDITOR", "vim")
|
||||
|
||||
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"
|
||||
|
||||
DEFAULT_EXPERIMENT_INFO_KEYS = (
|
||||
@@ -48,7 +52,14 @@ try:
|
||||
except subprocess.CalledProcessError:
|
||||
TERM_HEIGHT, TERM_WIDTH = 100, 100
|
||||
|
||||
EDITOR = os.getenv("EDITOR", "vim")
|
||||
OPERATORS = {
|
||||
'<': operator.lt,
|
||||
'<=': operator.le,
|
||||
'==': operator.eq,
|
||||
'!=': operator.ne,
|
||||
'>=': operator.ge,
|
||||
'>': operator.gt,
|
||||
}
|
||||
|
||||
|
||||
def _check_tabulate():
|
||||
@@ -117,6 +128,7 @@ def _get_experiment_state(experiment_path, exit_on_fail=False):
|
||||
def list_trials(experiment_path,
|
||||
sort=None,
|
||||
output=None,
|
||||
filter_op=None,
|
||||
info_keys=DEFAULT_EXPERIMENT_INFO_KEYS,
|
||||
result_keys=DEFAULT_RESULT_KEYS):
|
||||
"""Lists trials in the directory subtree starting at the given path.
|
||||
@@ -126,6 +138,8 @@ def list_trials(experiment_path,
|
||||
Corresponds to Experiment.local_dir/Experiment.name.
|
||||
sort (str): Key to sort by.
|
||||
output (str): Name of file where output is saved.
|
||||
filter_op (str): Filter operation in the format
|
||||
"<column> <operator> <value>".
|
||||
info_keys (list): Keys that are displayed.
|
||||
result_keys (list): Keys of last result that are displayed.
|
||||
"""
|
||||
@@ -156,6 +170,21 @@ def list_trials(experiment_path,
|
||||
checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
|
||||
experiment_path, '')
|
||||
|
||||
if filter_op:
|
||||
col, op, val = filter_op.split(' ')
|
||||
col_type = checkpoints_df[col].dtype
|
||||
if is_numeric_dtype(col_type):
|
||||
val = float(val)
|
||||
elif is_string_dtype(col_type):
|
||||
val = str(val)
|
||||
# TODO(Andrew): add support for datetime and boolean
|
||||
else:
|
||||
raise ValueError("Unsupported dtype for '{}': {}".format(
|
||||
val, col_type))
|
||||
op = OPERATORS[op]
|
||||
filtered_index = op(checkpoints_df[col], val)
|
||||
checkpoints_df = checkpoints_df[filtered_index]
|
||||
|
||||
if sort:
|
||||
if sort not in checkpoints_df:
|
||||
raise KeyError("Sort Index '{}' not in: {}".format(
|
||||
@@ -180,6 +209,7 @@ def list_trials(experiment_path,
|
||||
def list_experiments(project_path,
|
||||
sort=None,
|
||||
output=None,
|
||||
filter_op=None,
|
||||
info_keys=DEFAULT_PROJECT_INFO_KEYS):
|
||||
"""Lists experiments in the directory subtree.
|
||||
|
||||
@@ -188,6 +218,8 @@ def list_experiments(project_path,
|
||||
Corresponds to Experiment.local_dir.
|
||||
sort (str): Key to sort by.
|
||||
output (str): Name of file where output is saved.
|
||||
filter_op (str): Filter operation in the format
|
||||
"<column> <operator> <value>".
|
||||
info_keys (list): Keys that are displayed.
|
||||
"""
|
||||
_check_tabulate()
|
||||
@@ -234,13 +266,26 @@ def list_experiments(project_path,
|
||||
|
||||
info_df = pd.DataFrame(experiment_data_collection)
|
||||
col_keys = [k for k in list(info_keys) if k in info_df]
|
||||
|
||||
if not col_keys:
|
||||
print("None of keys {} in experiment data!".format(info_keys))
|
||||
sys.exit(0)
|
||||
|
||||
info_df = info_df[col_keys]
|
||||
|
||||
if filter_op:
|
||||
col, op, val = filter_op.split(' ')
|
||||
col_type = info_df[col].dtype
|
||||
if is_numeric_dtype(col_type):
|
||||
val = float(val)
|
||||
elif is_string_dtype(col_type):
|
||||
val = str(val)
|
||||
# TODO(Andrew): add support for datetime and boolean
|
||||
else:
|
||||
raise ValueError("Unsupported dtype for '{}': {}".format(
|
||||
val, col_type))
|
||||
op = OPERATORS[op]
|
||||
filtered_index = op(info_df[col], val)
|
||||
info_df = info_df[filtered_index]
|
||||
|
||||
if sort:
|
||||
if sort not in info_df:
|
||||
raise KeyError("Sort Index '{}' not in: {}".format(
|
||||
|
||||
@@ -20,10 +20,17 @@ def cli():
|
||||
"-o",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Output information to a pickle file.")
|
||||
def list_trials(experiment_path, sort, output):
|
||||
help="Select file to output information to.")
|
||||
@click.option(
|
||||
"--filter",
|
||||
"filter_op",
|
||||
nargs=1,
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select filter in the format '<column> <operator> <value>'.")
|
||||
def list_trials(experiment_path, sort, output, filter_op):
|
||||
"""Lists trials in the directory subtree starting at the given path."""
|
||||
commands.list_trials(experiment_path, sort, output)
|
||||
commands.list_trials(experiment_path, sort, output, filter_op)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -35,10 +42,17 @@ def list_trials(experiment_path, sort, output):
|
||||
"-o",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select filename to output information to.")
|
||||
def list_experiments(project_path, sort, output):
|
||||
help="Select file to output information to.")
|
||||
@click.option(
|
||||
"--filter",
|
||||
"filter_op",
|
||||
nargs=1,
|
||||
default=None,
|
||||
type=str,
|
||||
help="Select filter in the format '<column> <operator> <value>'.")
|
||||
def list_experiments(project_path, sort, output, filter_op):
|
||||
"""Lists experiments in the directory subtree."""
|
||||
commands.list_experiments(project_path, sort, output)
|
||||
commands.list_experiments(project_path, sort, output, filter_op)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -58,6 +58,14 @@ def test_ls(start_ray, tmpdir):
|
||||
lines = output.captured
|
||||
assert sum("TERMINATED" in line for line in lines) == num_samples
|
||||
|
||||
with Capturing() as output:
|
||||
commands.list_trials(
|
||||
experiment_path,
|
||||
info_keys=("status", ),
|
||||
filter_op="status == TERMINATED")
|
||||
lines = output.captured
|
||||
assert sum("TERMINATED" in line for line in lines) == num_samples
|
||||
|
||||
|
||||
def test_lsx(start_ray, tmpdir):
|
||||
"""This test captures output of list_experiments."""
|
||||
@@ -79,4 +87,13 @@ def test_lsx(start_ray, tmpdir):
|
||||
with Capturing() as output:
|
||||
commands.list_experiments(project_path, info_keys=("total_trials", ))
|
||||
lines = output.captured
|
||||
assert sum("1" in line for line in lines) >= 3
|
||||
assert sum("1" in line for line in lines) >= num_experiments
|
||||
|
||||
with Capturing() as output:
|
||||
commands.list_experiments(
|
||||
project_path,
|
||||
info_keys=("total_trials", ),
|
||||
filter_op="total_trials == 1")
|
||||
lines = output.captured
|
||||
assert sum("1" in line for line in lines) >= num_experiments
|
||||
assert len(lines) == 3 + num_experiments + 1
|
||||
|
||||
Reference in New Issue
Block a user