mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 12:37:14 +08:00
d7c7aba99c
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
301 lines
10 KiB
Python
301 lines
10 KiB
Python
import click
|
|
import logging
|
|
import os
|
|
import subprocess
|
|
import operator
|
|
from datetime import datetime
|
|
|
|
import pandas as pd
|
|
from pandas.api.types import is_string_dtype, is_numeric_dtype
|
|
from ray.tune.result import (DEFAULT_EXPERIMENT_INFO_KEYS, DEFAULT_RESULT_KEYS,
|
|
CONFIG_PREFIX)
|
|
from ray.tune.analysis import Analysis
|
|
from ray.tune import TuneError
|
|
try:
|
|
from tabulate import tabulate
|
|
except ImportError:
|
|
tabulate = None
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
EDITOR = os.getenv("EDITOR", "vim")
|
|
|
|
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"
|
|
|
|
DEFAULT_CLI_KEYS = DEFAULT_EXPERIMENT_INFO_KEYS + DEFAULT_RESULT_KEYS
|
|
|
|
DEFAULT_PROJECT_INFO_KEYS = (
|
|
"name",
|
|
"total_trials",
|
|
"last_updated",
|
|
)
|
|
|
|
try:
|
|
TERM_HEIGHT, TERM_WIDTH = subprocess.check_output(["stty", "size"]).split()
|
|
TERM_HEIGHT, TERM_WIDTH = int(TERM_HEIGHT), int(TERM_WIDTH)
|
|
except subprocess.CalledProcessError:
|
|
TERM_HEIGHT, TERM_WIDTH = 100, 100
|
|
|
|
OPERATORS = {
|
|
"<": operator.lt,
|
|
"<=": operator.le,
|
|
"==": operator.eq,
|
|
"!=": operator.ne,
|
|
">=": operator.ge,
|
|
">": operator.gt,
|
|
}
|
|
|
|
|
|
def _check_tabulate():
|
|
"""Checks whether tabulate is installed."""
|
|
if tabulate is None:
|
|
raise ImportError(
|
|
"Tabulate not installed. Please run `pip install tabulate`.")
|
|
|
|
|
|
def print_format_output(dataframe):
|
|
"""Prints output of given dataframe to fit into terminal.
|
|
|
|
Returns:
|
|
table (pd.DataFrame): Final outputted dataframe.
|
|
dropped_cols (list): Columns dropped due to terminal size.
|
|
empty_cols (list): Empty columns (dropped on default).
|
|
"""
|
|
print_df = pd.DataFrame()
|
|
dropped_cols = []
|
|
empty_cols = []
|
|
# column display priority is based on the info_keys passed in
|
|
for i, col in enumerate(dataframe):
|
|
if dataframe[col].isnull().all():
|
|
# Don't add col to print_df if is fully empty
|
|
empty_cols += [col]
|
|
continue
|
|
|
|
print_df[col] = dataframe[col]
|
|
test_table = tabulate(print_df, headers="keys", tablefmt="psql")
|
|
if str(test_table).index("\n") > TERM_WIDTH:
|
|
# Drop all columns beyond terminal width
|
|
print_df.drop(col, axis=1, inplace=True)
|
|
dropped_cols += list(dataframe.columns)[i:]
|
|
break
|
|
|
|
table = tabulate(
|
|
print_df, headers="keys", tablefmt="psql", showindex="never")
|
|
|
|
print(table)
|
|
if dropped_cols:
|
|
click.secho("Dropped columns: {}".format(dropped_cols), fg="yellow")
|
|
click.secho("Please increase your terminal size "
|
|
"to view remaining columns.")
|
|
if empty_cols:
|
|
click.secho("Empty columns: {}".format(empty_cols), fg="yellow")
|
|
|
|
return table, dropped_cols, empty_cols
|
|
|
|
|
|
def list_trials(experiment_path,
|
|
sort=None,
|
|
output=None,
|
|
filter_op=None,
|
|
info_keys=None,
|
|
limit=None,
|
|
desc=False):
|
|
"""Lists trials in the directory subtree starting at the given path.
|
|
|
|
Args:
|
|
experiment_path (str): Directory where trials are located.
|
|
Like Experiment.local_dir/Experiment.name/experiment*.json.
|
|
sort (list): Keys to sort by.
|
|
output (str): Name of file where output is saved.
|
|
filter_op (str): Filter operation in the format
|
|
"<column> <operator> <value>".
|
|
info_keys (list): Keys that are displayed.
|
|
limit (int): Number of rows to display.
|
|
desc (bool): Sort ascending vs. descending.
|
|
"""
|
|
_check_tabulate()
|
|
|
|
try:
|
|
checkpoints_df = Analysis(experiment_path).dataframe(
|
|
metric="episode_reward_mean", mode="max")
|
|
except TuneError:
|
|
raise click.ClickException("No trial data found!")
|
|
|
|
def key_filter(k):
|
|
return k in DEFAULT_CLI_KEYS or k.startswith(CONFIG_PREFIX)
|
|
|
|
col_keys = [k for k in checkpoints_df.columns if key_filter(k)]
|
|
|
|
if info_keys:
|
|
for k in info_keys:
|
|
if k not in checkpoints_df.columns:
|
|
raise click.ClickException("Provided key invalid: {}. "
|
|
"Available keys: {}.".format(
|
|
k, checkpoints_df.columns))
|
|
col_keys = [k for k in checkpoints_df.columns if k in info_keys]
|
|
|
|
if not col_keys:
|
|
raise click.ClickException("No columns to output.")
|
|
|
|
checkpoints_df = checkpoints_df[col_keys]
|
|
if "last_update_time" in checkpoints_df:
|
|
with pd.option_context("mode.use_inf_as_null", True):
|
|
datetime_series = checkpoints_df["last_update_time"].dropna()
|
|
|
|
datetime_series = datetime_series.apply(
|
|
lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT))
|
|
checkpoints_df["last_update_time"] = datetime_series
|
|
|
|
if "logdir" in checkpoints_df:
|
|
# logdir often too long to view in table, so drop experiment_path
|
|
checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
|
|
experiment_path, "")
|
|
|
|
if filter_op:
|
|
col, op, val = filter_op.split(" ")
|
|
col_type = checkpoints_df[col].dtype
|
|
if is_numeric_dtype(col_type):
|
|
val = float(val)
|
|
elif is_string_dtype(col_type):
|
|
val = str(val)
|
|
# TODO(Andrew): add support for datetime and boolean
|
|
else:
|
|
raise click.ClickException("Unsupported dtype for {}: {}".format(
|
|
val, col_type))
|
|
op = OPERATORS[op]
|
|
filtered_index = op(checkpoints_df[col], val)
|
|
checkpoints_df = checkpoints_df[filtered_index]
|
|
|
|
if sort:
|
|
for key in sort:
|
|
if key not in checkpoints_df:
|
|
raise click.ClickException("{} not in: {}".format(
|
|
key, list(checkpoints_df)))
|
|
ascending = not desc
|
|
checkpoints_df = checkpoints_df.sort_values(
|
|
by=sort, ascending=ascending)
|
|
|
|
if limit:
|
|
checkpoints_df = checkpoints_df[:limit]
|
|
|
|
print_format_output(checkpoints_df)
|
|
|
|
if output:
|
|
file_extension = os.path.splitext(output)[1].lower()
|
|
if file_extension in (".p", ".pkl", ".pickle"):
|
|
checkpoints_df.to_pickle(output)
|
|
elif file_extension == ".csv":
|
|
checkpoints_df.to_csv(output, index=False)
|
|
else:
|
|
raise click.ClickException(
|
|
"Unsupported filetype: {}".format(output))
|
|
click.secho("Output saved at {}".format(output), fg="green")
|
|
|
|
|
|
def list_experiments(project_path,
|
|
sort=None,
|
|
output=None,
|
|
filter_op=None,
|
|
info_keys=None,
|
|
limit=None,
|
|
desc=False):
|
|
"""Lists experiments in the directory subtree.
|
|
|
|
Args:
|
|
project_path (str): Directory where experiments are located.
|
|
Corresponds to Experiment.local_dir.
|
|
sort (list): Keys to sort by.
|
|
output (str): Name of file where output is saved.
|
|
filter_op (str): Filter operation in the format
|
|
"<column> <operator> <value>".
|
|
info_keys (list): Keys that are displayed.
|
|
limit (int): Number of rows to display.
|
|
desc (bool): Sort ascending vs. descending.
|
|
"""
|
|
_check_tabulate()
|
|
base, experiment_folders, _ = next(os.walk(project_path))
|
|
|
|
experiment_data_collection = []
|
|
|
|
for experiment_dir in experiment_folders:
|
|
num_trials = sum(
|
|
"result.json" in files
|
|
for _, _, files in os.walk(os.path.join(base, experiment_dir)))
|
|
|
|
experiment_data = {"name": experiment_dir, "total_trials": num_trials}
|
|
experiment_data_collection.append(experiment_data)
|
|
|
|
if not experiment_data_collection:
|
|
raise click.ClickException("No experiments found!")
|
|
|
|
info_df = pd.DataFrame(experiment_data_collection)
|
|
if not info_keys:
|
|
info_keys = DEFAULT_PROJECT_INFO_KEYS
|
|
col_keys = [k for k in list(info_keys) if k in info_df]
|
|
if not col_keys:
|
|
raise click.ClickException(
|
|
"None of keys {} in experiment data!".format(info_keys))
|
|
info_df = info_df[col_keys]
|
|
|
|
if filter_op:
|
|
col, op, val = filter_op.split(" ")
|
|
col_type = info_df[col].dtype
|
|
if is_numeric_dtype(col_type):
|
|
val = float(val)
|
|
elif is_string_dtype(col_type):
|
|
val = str(val)
|
|
# TODO(Andrew): add support for datetime and boolean
|
|
else:
|
|
raise click.ClickException("Unsupported dtype for {}: {}".format(
|
|
val, col_type))
|
|
op = OPERATORS[op]
|
|
filtered_index = op(info_df[col], val)
|
|
info_df = info_df[filtered_index]
|
|
|
|
if sort:
|
|
for key in sort:
|
|
if key not in info_df:
|
|
raise click.ClickException("{} not in: {}".format(
|
|
key, list(info_df)))
|
|
ascending = not desc
|
|
info_df = info_df.sort_values(by=sort, ascending=ascending)
|
|
|
|
if limit:
|
|
info_df = info_df[:limit]
|
|
|
|
print_format_output(info_df)
|
|
|
|
if output:
|
|
file_extension = os.path.splitext(output)[1].lower()
|
|
if file_extension in (".p", ".pkl", ".pickle"):
|
|
info_df.to_pickle(output)
|
|
elif file_extension == ".csv":
|
|
info_df.to_csv(output, index=False)
|
|
else:
|
|
raise click.ClickException(
|
|
"Unsupported filetype: {}".format(output))
|
|
click.secho("Output saved at {}".format(output), fg="green")
|
|
|
|
|
|
def add_note(path, filename="note.txt"):
|
|
"""Opens a txt file at the given path where user can add and save notes.
|
|
|
|
Args:
|
|
path (str): Directory where note will be saved.
|
|
filename (str): Name of note. Defaults to "note.txt"
|
|
"""
|
|
path = os.path.expanduser(path)
|
|
assert os.path.isdir(path), "{} is not a valid directory.".format(path)
|
|
|
|
filepath = os.path.join(path, filename)
|
|
exists = os.path.isfile(filepath)
|
|
|
|
try:
|
|
subprocess.call([EDITOR, filepath])
|
|
except Exception as exc:
|
|
click.secho("Editing note failed: {}".format(str(exc)), fg="red")
|
|
if exists:
|
|
print("Note updated at:", filepath)
|
|
else:
|
|
print("Note created at:", filepath)
|