[tune] Initial Commit for Tune CLI (#3983)

This introduces a light CLI for Tune.
This commit is contained in:
Richard Liaw
2019-03-08 16:46:05 -08:00
committed by GitHub
parent 3064fad96b
commit 6630a35353
18 changed files with 492 additions and 69 deletions
+232
View File
@@ -0,0 +1,232 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import json
import logging
import os
import sys
import subprocess
from datetime import datetime
import pandas as pd
from ray.tune.util import flatten_dict
from ray.tune.result import TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS
from ray.tune.trial import Trial
try:
from tabulate import tabulate
except ImportError:
tabulate = None
logger = logging.getLogger(__name__)
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"
DEFAULT_EXPERIMENT_INFO_KEYS = (
"trainable_name",
"experiment_tag",
"trial_id",
"status",
"last_update_time",
)
DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS)
DEFAULT_PROJECT_INFO_KEYS = (
"name",
"total_trials",
"running_trials",
"terminated_trials",
"error_trials",
"last_updated",
)
try:
TERM_HEIGHT, TERM_WIDTH = subprocess.check_output(['stty', 'size']).split()
TERM_HEIGHT, TERM_WIDTH = int(TERM_HEIGHT), int(TERM_WIDTH)
except subprocess.CalledProcessError:
TERM_HEIGHT, TERM_WIDTH = 100, 100
def _check_tabulate():
"""Checks whether tabulate is installed."""
if tabulate is None:
raise ImportError(
"Tabulate not installed. Please run `pip install tabulate`.")
def print_format_output(dataframe):
"""Prints output of given dataframe to fit into terminal.
Returns:
table (pd.DataFrame): Final outputted dataframe.
dropped_cols (list): Columns dropped due to terminal size.
empty_cols (list): Empty columns (dropped on default).
"""
print_df = pd.DataFrame()
dropped_cols = []
empty_cols = []
# column display priority is based on the info_keys passed in
for i, col in enumerate(dataframe):
if dataframe[col].isnull().all():
# Don't add col to print_df if is fully empty
empty_cols += [col]
continue
print_df[col] = dataframe[col]
test_table = tabulate(print_df, headers="keys", tablefmt="psql")
if str(test_table).index('\n') > TERM_WIDTH:
# Drop all columns beyond terminal width
print_df.drop(col, axis=1, inplace=True)
dropped_cols += list(dataframe.columns)[i:]
break
table = tabulate(
print_df, headers="keys", tablefmt="psql", showindex="never")
print(table)
if dropped_cols:
print("Dropped columns:", dropped_cols)
print("Please increase your terminal size to view remaining columns.")
if empty_cols:
print("Empty columns:", empty_cols)
return table, dropped_cols, empty_cols
def _get_experiment_state(experiment_path, exit_on_fail=False):
experiment_path = os.path.expanduser(experiment_path)
experiment_state_paths = glob.glob(
os.path.join(experiment_path, "experiment_state*.json"))
if not experiment_state_paths:
if exit_on_fail:
print("No experiment state found!")
sys.exit(0)
else:
return
experiment_filename = max(list(experiment_state_paths))
with open(experiment_filename) as f:
experiment_state = json.load(f)
return experiment_state
def list_trials(experiment_path,
sort=None,
info_keys=DEFAULT_EXPERIMENT_INFO_KEYS,
result_keys=DEFAULT_RESULT_KEYS):
"""Lists trials in the directory subtree starting at the given path.
Args:
experiment_path (str): Directory where trials are located.
Corresponds to Experiment.local_dir/Experiment.name.
sort (str): Key to sort by.
info_keys (list): Keys that are displayed.
result_keys (list): Keys of last result that are displayed.
"""
_check_tabulate()
experiment_state = _get_experiment_state(
experiment_path, exit_on_fail=True)
checkpoint_dicts = experiment_state["checkpoints"]
checkpoint_dicts = [flatten_dict(g) for g in checkpoint_dicts]
checkpoints_df = pd.DataFrame(checkpoint_dicts)
result_keys = ["last_result:{}".format(k) for k in result_keys]
col_keys = [
k for k in list(info_keys) + result_keys if k in checkpoints_df
]
checkpoints_df = checkpoints_df[col_keys]
if "last_update_time" in checkpoints_df:
with pd.option_context('mode.use_inf_as_null', True):
datetime_series = checkpoints_df["last_update_time"].dropna()
datetime_series = datetime_series.apply(
lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT))
checkpoints_df["last_update_time"] = datetime_series
if "logdir" in checkpoints_df:
# logdir often too verbose to view in table, so drop experiment_path
checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
experiment_path, '')
if sort:
if sort not in checkpoints_df:
raise KeyError("Sort Index '{}' not in: {}".format(
sort, list(checkpoints_df)))
checkpoints_df = checkpoints_df.sort_values(by=sort)
print_format_output(checkpoints_df)
def list_experiments(project_path,
sort=None,
info_keys=DEFAULT_PROJECT_INFO_KEYS):
"""Lists experiments in the directory subtree.
Args:
project_path (str): Directory where experiments are located.
Corresponds to Experiment.local_dir.
sort (str): Key to sort by.
info_keys (list): Keys that are displayed.
"""
_check_tabulate()
base, experiment_folders, _ = next(os.walk(project_path))
experiment_data_collection = []
for experiment_dir in experiment_folders:
experiment_state = _get_experiment_state(
os.path.join(base, experiment_dir))
if not experiment_state:
logger.debug("No experiment state found in %s", experiment_dir)
continue
checkpoints = pd.DataFrame(experiment_state["checkpoints"])
runner_data = experiment_state["runner_data"]
# Format time-based values.
time_values = {
"start_time": runner_data.get("_start_time"),
"last_updated": experiment_state.get("timestamp"),
}
formatted_time_values = {
key: datetime.fromtimestamp(val).strftime(TIMESTAMP_FORMAT)
if val else None
for key, val in time_values.items()
}
experiment_data = {
"name": experiment_dir,
"total_trials": checkpoints.shape[0],
"running_trials": (checkpoints["status"] == Trial.RUNNING).sum(),
"terminated_trials": (
checkpoints["status"] == Trial.TERMINATED).sum(),
"error_trials": (checkpoints["status"] == Trial.ERROR).sum(),
}
experiment_data.update(formatted_time_values)
experiment_data_collection.append(experiment_data)
if not experiment_data_collection:
print("No experiments found!")
sys.exit(0)
info_df = pd.DataFrame(experiment_data_collection)
col_keys = [k for k in list(info_keys) if k in info_df]
if not col_keys:
print("None of keys {} in experiment data!".format(info_keys))
sys.exit(0)
info_df = info_df[col_keys]
if sort:
if sort not in info_df:
raise KeyError("Sort Index '{}' not in: {}".format(
sort, list(info_df)))
info_df = info_df.sort_values(by=sort)
print_format_output(info_df)
+2 -2
View File
@@ -26,9 +26,9 @@ parser.add_argument(
parser.add_argument(
'--epochs',
type=int,
default=10,
default=1,
metavar='N',
help='number of epochs to train (default: 10)')
help='number of epochs to train (default: 1)')
parser.add_argument(
'--lr',
type=float,
@@ -29,9 +29,9 @@ parser.add_argument(
parser.add_argument(
'--epochs',
type=int,
default=10,
default=1,
metavar='N',
help='number of epochs to train (default: 10)')
help='number of epochs to train (default: 1)')
parser.add_argument(
'--lr',
type=float,
+11 -9
View File
@@ -17,15 +17,8 @@ from ray.tune.result import NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, \
logger = logging.getLogger(__name__)
try:
import tensorflow as tf
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
distutils.version.LooseVersion("1.5.0"))
except ImportError:
tf = None
use_tf150_api = True
logger.warning("Couldn't import TensorFlow - "
"disabling TensorBoard logging.")
tf = None
use_tf150_api = True
class Logger(object):
@@ -121,6 +114,15 @@ def to_tf_values(result, path):
class TFLogger(Logger):
def _init(self):
try:
global tf, use_tf150_api
import tensorflow
tf = tensorflow
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
distutils.version.LooseVersion("1.5.0"))
except ImportError:
logger.warning("Couldn't import TensorFlow - "
"disabling TensorBoard logging.")
self._file_writer = tf.summary.FileWriter(self.logdir)
def on_result(self, result):
+9
View File
@@ -18,6 +18,15 @@ NODE_IP = "node_ip"
# (Auto-filled) The pid of the training process.
PID = "pid"
# (Optional) Mean reward for current training iteration
EPISODE_REWARD_MEAN = "episode_reward_mean"
# (Optional) Mean loss for training iteration
MEAN_LOSS = "mean_loss"
# (Optional) Mean accuracy for training iteration
MEAN_ACCURACY = "mean_accuracy"
# Number of episodes in this iteration.
EPISODES_THIS_ITER = "episodes_this_iter"
+43
View File
@@ -0,0 +1,43 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import click
import ray.tune.commands as commands
@click.group()
def cli():
pass
@cli.command()
@click.argument("experiment_path", required=True, type=str)
@click.option(
'--sort', default=None, type=str, help='Select which column to sort on.')
def list_trials(experiment_path, sort):
"""Lists trials in the directory subtree starting at the given path."""
commands.list_trials(experiment_path, sort)
@cli.command()
@click.argument("project_path", required=True, type=str)
@click.option(
'--sort', default=None, type=str, help='Select which column to sort on.')
def list_experiments(project_path, sort):
"""Lists experiments in the directory subtree."""
commands.list_experiments(project_path, sort)
cli.add_command(list_trials, name="ls")
cli.add_command(list_trials, name="list-trials")
cli.add_command(list_experiments, name="lsx")
cli.add_command(list_experiments, name="list-experiments")
def main():
return cli()
if __name__ == "__main__":
main()
+9 -5
View File
@@ -4,13 +4,16 @@ from __future__ import print_function
import copy
try:
import bayes_opt as byo
except Exception:
byo = None
from ray.tune.suggest.suggestion import SuggestionAlgorithm
byo = None
def _import_bayesopt():
global byo
import bayes_opt
byo = bayes_opt
class BayesOptSearch(SuggestionAlgorithm):
"""A wrapper around BayesOpt to provide trial suggestions.
@@ -56,6 +59,7 @@ class BayesOptSearch(SuggestionAlgorithm):
random_state=1,
verbose=0,
**kwargs):
_import_bayesopt()
assert byo is not None, (
"BayesOpt must be installed!. You can install BayesOpt with"
" the command: `pip install bayesian-optimization`.")
+12 -8
View File
@@ -6,17 +6,19 @@ import numpy as np
import copy
import logging
try:
hyperopt_logger = logging.getLogger("hyperopt")
hyperopt_logger.setLevel(logging.WARNING)
import hyperopt as hpo
from hyperopt.fmin import generate_trials_to_calculate
except Exception:
hpo = None
from ray.tune.error import TuneError
from ray.tune.suggest.suggestion import SuggestionAlgorithm
hpo = None
def _import_hyperopt():
global hpo
hyperopt_logger = logging.getLogger("hyperopt")
hyperopt_logger.setLevel(logging.WARNING)
import hyperopt
hpo = hyperopt
class HyperOptSearch(SuggestionAlgorithm):
"""A wrapper around HyperOpt to provide trial suggestions.
@@ -73,7 +75,9 @@ class HyperOptSearch(SuggestionAlgorithm):
reward_attr="episode_reward_mean",
points_to_evaluate=None,
**kwargs):
_import_hyperopt()
assert hpo is not None, "HyperOpt must be installed!"
from hyperopt.fmin import generate_trials_to_calculate
assert type(max_concurrent) is int and max_concurrent > 0
self._max_concurrent = max_concurrent
self._reward_attr = reward_attr
+66
View File
@@ -0,0 +1,66 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pytest
import ray
from ray import tune
from ray.rllib import _register_all
from ray.tune import commands
@pytest.fixture
def start_ray():
ray.init()
_register_all()
yield
ray.shutdown()
def test_ls(start_ray, capsys, tmpdir):
"""This test captures output of list_trials."""
experiment_name = "test_ls"
experiment_path = os.path.join(str(tmpdir), experiment_name)
num_samples = 2
with capsys.disabled():
tune.run_experiments({
experiment_name: {
"run": "__fake",
"stop": {
"training_iteration": 1
},
"num_samples": num_samples,
"local_dir": str(tmpdir)
}
})
commands.list_trials(experiment_path, info_keys=("status", ))
captured = capsys.readouterr().out.strip()
lines = captured.split("\n")
assert sum("TERMINATED" in line for line in lines) == num_samples
def test_lsx(start_ray, capsys, tmpdir):
"""This test captures output of list_experiments."""
project_path = str(tmpdir)
num_experiments = 3
for i in range(num_experiments):
experiment_name = "test_lsx{}".format(i)
with capsys.disabled():
tune.run_experiments({
experiment_name: {
"run": "__fake",
"stop": {
"training_iteration": 1
},
"num_samples": 1,
"local_dir": project_path
}
})
commands.list_experiments(project_path, info_keys=("total_trials", ))
captured = capsys.readouterr().out.strip()
lines = captured.split("\n")
assert sum("1" in line for line in lines) >= 3
@@ -18,6 +18,7 @@ class RayTrialExecutorTest(unittest.TestCase):
def setUp(self):
self.trial_executor = RayTrialExecutor(queue_trials=False)
ray.init()
_register_all() # Needed for flaky tests
def tearDown(self):
ray.shutdown()
@@ -583,7 +583,7 @@ class _MockTrial(Trial):
self.logger_running = False
self.restored_checkpoint = None
self.resources = Resources(1, 0)
self.trial_name = None
self.custom_trial_name = None
class PopulationBasedTestingSuite(unittest.TestCase):
+15 -14
View File
@@ -25,7 +25,8 @@ from ray.tune.logger import pretty_print, UnifiedLogger
# have been defined yet. See https://github.com/ray-project/ray/issues/1716.
import ray.tune.registry
from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID,
TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL)
TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL,
EPISODE_REWARD_MEAN, MEAN_LOSS, MEAN_ACCURACY)
from ray.utils import _random_string, binary_to_hex, hex_to_binary
DEBUG_PRINT_INTERVAL = 5
@@ -299,6 +300,8 @@ class Trial(object):
self.error_file = None
self.num_failures = 0
self.custom_trial_name = None
# AutoML fields
self.results = None
self.best_result = None
@@ -316,10 +319,8 @@ class Trial(object):
"param_config",
"extra_arg",
]
self.trial_name = None
if trial_name_creator:
self.trial_name = trial_name_creator(self)
self.custom_trial_name = trial_name_creator(self)
@classmethod
def _registration_check(cls, trainable_name):
@@ -447,17 +448,17 @@ class Trial(object):
if self.last_result.get(TIMESTEPS_TOTAL) is not None:
pieces.append('{} ts'.format(self.last_result[TIMESTEPS_TOTAL]))
if self.last_result.get("episode_reward_mean") is not None:
if self.last_result.get(EPISODE_REWARD_MEAN) is not None:
pieces.append('{} rew'.format(
format(self.last_result["episode_reward_mean"], '.3g')))
format(self.last_result[EPISODE_REWARD_MEAN], '.3g')))
if self.last_result.get("mean_loss") is not None:
if self.last_result.get(MEAN_LOSS) is not None:
pieces.append('{} loss'.format(
format(self.last_result["mean_loss"], '.3g')))
format(self.last_result[MEAN_LOSS], '.3g')))
if self.last_result.get("mean_accuracy") is not None:
if self.last_result.get(MEAN_ACCURACY) is not None:
pieces.append('{} acc'.format(
format(self.last_result["mean_accuracy"], '.3g')))
format(self.last_result[MEAN_ACCURACY], '.3g')))
return ', '.join(pieces)
@@ -514,8 +515,8 @@ class Trial(object):
Can be overriden with a custom string creator.
"""
if self.trial_name:
return self.trial_name
if self.custom_trial_name:
return self.custom_trial_name
if "env" in self.config:
env = self.config["env"]
@@ -544,8 +545,6 @@ class Trial(object):
state["runner"] = None
state["result_logger"] = None
if self.status == Trial.RUNNING:
state["status"] = Trial.PENDING
if self.result_logger:
self.result_logger.flush()
state["__logger_started__"] = True
@@ -556,6 +555,8 @@ class Trial(object):
def __setstate__(self, state):
logger_started = state.pop("__logger_started__")
state["resources"] = json_to_resources(state["resources"])
if state["status"] == Trial.RUNNING:
state["status"] = Trial.PENDING
for key in self._nonjson_fields:
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
+20 -5
View File
@@ -112,7 +112,9 @@ class TrialRunner(object):
self._stop_queue = []
self._metadata_checkpoint_dir = metadata_checkpoint_dir
self._session = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
self._start_time = time.time()
self._session_str = datetime.fromtimestamp(
self._start_time).strftime("%Y-%m-%d_%H-%M-%S")
@classmethod
def checkpoint_exists(cls, directory):
@@ -136,7 +138,8 @@ class TrialRunner(object):
runner_state = {
"checkpoints": list(
self.trial_executor.get_checkpoints().values()),
"runner_data": self.__getstate__()
"runner_data": self.__getstate__(),
"timestamp": time.time()
}
tmp_file_name = os.path.join(metadata_checkpoint_dir,
".tmp_checkpoint")
@@ -146,7 +149,7 @@ class TrialRunner(object):
os.rename(
tmp_file_name,
os.path.join(metadata_checkpoint_dir,
TrialRunner.CKPT_FILE_TMPL.format(self._session)))
TrialRunner.CKPT_FILE_TMPL.format(self._session_str)))
return metadata_checkpoint_dir
@classmethod
@@ -558,8 +561,12 @@ class TrialRunner(object):
"""
state = self.__dict__.copy()
for k in [
"_trials", "_stop_queue", "_server", "_search_alg",
"_scheduler_alg", "trial_executor", "_session"
"_trials",
"_stop_queue",
"_server",
"_search_alg",
"_scheduler_alg",
"trial_executor",
]:
del state[k]
state["launch_web_server"] = bool(self._server)
@@ -567,6 +574,14 @@ class TrialRunner(object):
def __setstate__(self, state):
launch_web_server = state.pop("launch_web_server")
# Use session_str from previous checkpoint if does not exist
session_str = state.pop("_session_str")
self.__dict__.setdefault("_session_str", session_str)
# Use start_time from previous checkpoint if does not exist
start_time = state.pop("_start_time")
self.__dict__.setdefault("_start_time", start_time)
self.__dict__.update(state)
if launch_web_server:
self._server = TuneServer(self, self._server_port)
+16 -1
View File
@@ -61,7 +61,7 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist):
if k not in original:
if not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))
if type(original.get(k)) is dict:
if isinstance(original.get(k), dict):
if k in whitelist:
deep_update(original[k], value, True, [])
else:
@@ -71,6 +71,21 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist):
return original
def flatten_dict(dt):
while any(isinstance(v, dict) for v in dt.values()):
remove = []
add = {}
for key, value in dt.items():
if isinstance(value, dict):
for subkey, v in value.items():
add[":".join([key, subkey])] = v
remove.append(key)
dt.update(add)
for k in remove:
del dt[k]
return dt
def _to_pinnable(obj):
"""Converts obj to a form that can be pinned in object store memory.
+4 -17
View File
@@ -10,24 +10,11 @@ import os.path as osp
import numpy as np
import json
from ray.tune.util import flatten_dict
logger = logging.getLogger(__name__)
def _flatten_dict(dt):
while any(type(v) is dict for v in dt.values()):
remove = []
add = {}
for key, value in dt.items():
if type(value) is dict:
for subkey, v in value.items():
add[":".join([key, subkey])] = v
remove.append(key)
dt.update(add)
for k in remove:
del dt[k]
return dt
def _parse_results(res_path):
res_dict = {}
try:
@@ -35,7 +22,7 @@ def _parse_results(res_path):
# Get last line in file
for line in f:
pass
res_dict = _flatten_dict(json.loads(line.strip()))
res_dict = flatten_dict(json.loads(line.strip()))
except Exception:
logger.exception("Importing %s failed...Perhaps empty?" % res_path)
return res_dict
@@ -44,7 +31,7 @@ def _parse_results(res_path):
def _parse_configs(cfg_path):
try:
with open(cfg_path) as f:
cfg_dict = _flatten_dict(json.load(f))
cfg_dict = flatten_dict(json.load(f))
except Exception:
logger.exception("Config parsing failed.")
return cfg_dict
+1 -1
View File
@@ -175,7 +175,7 @@ setup(
entry_points={
"console_scripts": [
"ray=ray.scripts.scripts:main",
"rllib=ray.rllib.scripts:cli [rllib]"
"rllib=ray.rllib.scripts:cli [rllib]", "tune=ray.tune.scripts:cli"
]
},
include_package_data=True,