diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh
index 29d1b2650..5f92eb43a 100755
--- a/ci/travis/install-dependencies.sh
+++ b/ci/travis/install-dependencies.sh
@@ -25,7 +25,7 @@ if [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "linux" ]]; then
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
pip install -q scipy tensorflow cython==0.29.0 gym opencv-python-headless pyyaml pandas==0.23.4 requests \
- feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock flaky networkx
+ feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock flaky networkx tabulate
elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then
sudo apt-get update
sudo apt-get install -y python-dev python-numpy build-essential curl unzip tmux gdb
@@ -34,7 +34,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
pip install -q scipy tensorflow cython==0.29.0 gym opencv-python-headless pyyaml pandas==0.23.4 requests \
- feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx
+ feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate
elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then
# check that brew is installed
which -s brew
@@ -50,7 +50,7 @@ elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
pip install -q cython==0.29.0 tensorflow gym opencv-python-headless pyyaml pandas==0.23.4 requests \
- feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock flaky networkx
+ feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock flaky networkx tabulate
elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then
# check that brew is installed
which -s brew
@@ -66,7 +66,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
pip install -q cython==0.29.0 tensorflow gym opencv-python-headless pyyaml pandas==0.23.4 requests \
- feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx
+ feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate
elif [[ "$LINT" == "1" ]]; then
sudo apt-get update
sudo apt-get install -y build-essential curl unzip
diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst
index c69640a64..82ae5a33b 100644
--- a/doc/source/tune-usage.rst
+++ b/doc/source/tune-usage.rst
@@ -499,6 +499,50 @@ And stopping a trial (``PUT /trials/:id``):
curl -X PUT http://
:/trials/
+Tune CLI (Experimental)
+-----------------------
+
+``tune`` has an easy-to-use command line interface (CLI) to manage and monitor your experiments on Ray. To use it, make sure that the ``tabulate`` library is installed:
+
+.. code-block:: bash
+
+ $ pip install tabulate
+
+Here are a few examples of command line calls.
+
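+- ``tune list-trials``: List tabular information about trials within an experiment. Pass ``--sort <column>`` to sort the output by the given column. Columns that do not fit within the terminal width are left out of the table and listed below it as dropped columns.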
+
+.. code-block:: bash
+
+    $ tune list-trials [EXPERIMENT_DIR]
+
+    +------------------+-----------------------+------------+
+    | trainable_name   | experiment_tag        | trial_id   |
+    |------------------+-----------------------+------------|
+    | MyTrainableClass | 0_height=40,width=37  | 87b54a1d   |
+    | MyTrainableClass | 1_height=21,width=70  | 23b89036   |
+    | MyTrainableClass | 2_height=99,width=90  | 518dbe95   |
+    | MyTrainableClass | 3_height=54,width=21  | 7b99a28a   |
+    | MyTrainableClass | 4_height=90,width=69  | ae4e02fb   |
+    +------------------+-----------------------+------------+
+    Dropped columns: ['status', 'last_update_time']
+
+- ``tune list-experiments``: List tabular information about experiments within a project. Pass ``--sort <column>`` to sort the output by the given column.
+
+.. code-block:: bash
+
+    $ tune list-experiments [PROJECT_DIR]
+
+    +----------------------+----------------+------------------+---------------------+
+    | name                 |   total_trials |   running_trials |   terminated_trials |
+    |----------------------+----------------+------------------+---------------------|
+    | pbt_test             |             10 |                0 |                   0 |
+    | test                 |              1 |                0 |                   0 |
+    | hyperband_test       |              1 |                0 |                   1 |
+    +----------------------+----------------+------------------+---------------------+
+    Dropped columns: ['error_trials', 'last_updated']
+
+
Further Questions or Issues?
----------------------------
diff --git a/python/ray/tune/commands.py b/python/ray/tune/commands.py
new file mode 100644
index 000000000..c9d458e43
--- /dev/null
+++ b/python/ray/tune/commands.py
@@ -0,0 +1,232 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import json
+import logging
+import os
+import sys
+import subprocess
+from datetime import datetime
+
+import pandas as pd
+from ray.tune.util import flatten_dict
+from ray.tune.result import TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS
+from ray.tune.trial import Trial
+try:
+ from tabulate import tabulate
+except ImportError:
+ tabulate = None
+
+logger = logging.getLogger(__name__)
+
+# strftime pattern used for every timestamp shown in the CLI tables.
+TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"
+
+# Trial checkpoint fields shown by `list_trials`, in display priority order.
+DEFAULT_EXPERIMENT_INFO_KEYS = (
+    "trainable_name",
+    "experiment_tag",
+    "trial_id",
+    "status",
+    "last_update_time",
+)
+
+# Fields of each trial's last result shown by `list_trials`.
+DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS)
+
+# Per-experiment summary fields shown by `list_experiments`, in display
+# priority order.
+DEFAULT_PROJECT_INFO_KEYS = (
+    "name",
+    "total_trials",
+    "running_trials",
+    "terminated_trials",
+    "error_trials",
+    "last_updated",
+)
+
+# Terminal size is probed once at import time through `stty size`. Fall back
+# to 100x100 when stty exits with an error (CalledProcessError), cannot be
+# launched at all (OSError, e.g. the binary does not exist), or prints
+# something that is not two integers (ValueError).
+try:
+    TERM_HEIGHT, TERM_WIDTH = subprocess.check_output(['stty', 'size']).split()
+    TERM_HEIGHT, TERM_WIDTH = int(TERM_HEIGHT), int(TERM_WIDTH)
+except (subprocess.CalledProcessError, OSError, ValueError):
+    TERM_HEIGHT, TERM_WIDTH = 100, 100
+
+
+def _check_tabulate():
+    """Raises ImportError with an install hint if tabulate is missing."""
+    if tabulate is not None:
+        return
+    raise ImportError(
+        "Tabulate not installed. Please run `pip install tabulate`.")
+
+
+def print_format_output(dataframe):
+    """Prints output of given dataframe to fit into terminal.
+
+    Columns are added left to right until the rendered table would be wider
+    than TERM_WIDTH; that column and every column after it are dropped.
+    Columns that contain only null values are never printed.
+
+    Args:
+        dataframe (pd.DataFrame): Table to print. Column order determines
+            display priority.
+
+    Returns:
+        table (str): Rendered table that was printed.
+        dropped_cols (list): Columns dropped due to terminal size.
+        empty_cols (list): Empty columns (dropped on default).
+
+    Raises:
+        ImportError: If tabulate is not installed.
+    """
+    _check_tabulate()
+    print_df = pd.DataFrame()
+    dropped_cols = []
+    empty_cols = []
+    # column display priority is based on the info_keys passed in
+    for i, col in enumerate(dataframe):
+        if dataframe[col].isnull().all():
+            # Don't add col to print_df if is fully empty
+            empty_cols.append(col)
+            continue
+
+        print_df[col] = dataframe[col]
+        # Render the probe exactly like the final table (without the index
+        # column); otherwise the index inflates the measured width and
+        # columns that would fit get dropped.
+        test_table = tabulate(
+            print_df, headers="keys", tablefmt="psql", showindex="never")
+        first_line = str(test_table).split('\n', 1)[0]
+        if len(first_line) > TERM_WIDTH:
+            # Drop all columns beyond terminal width
+            print_df.drop(col, axis=1, inplace=True)
+            dropped_cols += list(dataframe.columns)[i:]
+            break
+
+    table = tabulate(
+        print_df, headers="keys", tablefmt="psql", showindex="never")
+
+    print(table)
+    if dropped_cols:
+        print("Dropped columns:", dropped_cols)
+        print("Please increase your terminal size to view remaining columns.")
+    if empty_cols:
+        print("Empty columns:", empty_cols)
+
+    return table, dropped_cols, empty_cols
+
+
+def _get_experiment_state(experiment_path, exit_on_fail=False):
+    """Loads the most recent experiment state file found in a directory.
+
+    Args:
+        experiment_path (str): Experiment directory; `~` is expanded.
+        exit_on_fail (bool): If True, print a message and exit the process
+            (status 0) when no state file exists instead of returning None.
+
+    Returns:
+        Parsed JSON content of the `experiment_state*.json` file whose name
+        sorts last, or None if there is no such file.
+    """
+    state_pattern = os.path.join(
+        os.path.expanduser(experiment_path), "experiment_state*.json")
+    candidates = glob.glob(state_pattern)
+    if candidates:
+        # NOTE(review): presumably file names embed a sortable session
+        # timestamp, so the greatest name is the newest -- confirm against
+        # TrialRunner.CKPT_FILE_TMPL.
+        with open(max(candidates)) as state_file:
+            return json.load(state_file)
+
+    if not exit_on_fail:
+        return
+    print("No experiment state found!")
+    sys.exit(0)
+
+
+def list_trials(experiment_path,
+                sort=None,
+                info_keys=DEFAULT_EXPERIMENT_INFO_KEYS,
+                result_keys=DEFAULT_RESULT_KEYS):
+    """Lists trials in the directory subtree starting at the given path.
+
+    Prints a table with one row per trial checkpoint. If no experiment
+    state file is found, a message is printed and the process exits.
+
+    Args:
+        experiment_path (str): Directory where trials are located.
+            Corresponds to Experiment.local_dir/Experiment.name.
+        sort (str): Key to sort by.
+        info_keys (list): Keys that are displayed.
+        result_keys (list): Keys of last result that are displayed.
+
+    Raises:
+        ImportError: If tabulate is not installed.
+        KeyError: If `sort` is not one of the displayed columns.
+    """
+    _check_tabulate()
+    # Expand `~` once up front so that the logdir prefix stripping below
+    # uses the same path the state file is loaded from.
+    experiment_path = os.path.expanduser(experiment_path)
+    experiment_state = _get_experiment_state(
+        experiment_path, exit_on_fail=True)
+
+    checkpoint_dicts = experiment_state["checkpoints"]
+    checkpoint_dicts = [flatten_dict(g) for g in checkpoint_dicts]
+    checkpoints_df = pd.DataFrame(checkpoint_dicts)
+
+    # flatten_dict joins nested keys with ":", so last-result fields live
+    # under "last_result:<key>".
+    result_keys = ["last_result:{}".format(k) for k in result_keys]
+    col_keys = [
+        k for k in list(info_keys) + result_keys if k in checkpoints_df
+    ]
+    # Copy so that the column assignments below operate on an independent
+    # frame rather than on a slice of the original one.
+    checkpoints_df = checkpoints_df[col_keys].copy()
+
+    if "last_update_time" in checkpoints_df:
+        # Treat +/-inf like missing values so they are not formatted.
+        with pd.option_context('mode.use_inf_as_null', True):
+            datetime_series = checkpoints_df["last_update_time"].dropna()
+
+        datetime_series = datetime_series.apply(
+            lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT))
+        checkpoints_df["last_update_time"] = datetime_series
+
+    if "logdir" in checkpoints_df:
+        # logdir often too verbose to view in table, so drop experiment_path.
+        # The path is matched literally; as a regex, characters such as
+        # "." or "+" in the path would change what gets removed.
+        checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
+            experiment_path, '', regex=False)
+
+    if sort:
+        if sort not in checkpoints_df:
+            raise KeyError("Sort Index '{}' not in: {}".format(
+                sort, list(checkpoints_df)))
+        checkpoints_df = checkpoints_df.sort_values(by=sort)
+
+    print_format_output(checkpoints_df)
+
+
+def list_experiments(project_path,
+                     sort=None,
+                     info_keys=DEFAULT_PROJECT_INFO_KEYS):
+    """Lists experiments in the directory subtree.
+
+    Prints a table with one row per direct subdirectory of `project_path`
+    that contains an experiment state file. If no experiment is found, or
+    none of `info_keys` is available, a message is printed and the process
+    exits with status 0.
+
+    Args:
+        project_path (str): Directory where experiments are located.
+            Corresponds to Experiment.local_dir.
+        sort (str): Key to sort by.
+        info_keys (list): Keys that are displayed.
+
+    Raises:
+        ImportError: If tabulate is not installed.
+        KeyError: If `sort` is not one of the displayed columns.
+    """
+    _check_tabulate()
+    project_path = os.path.expanduser(project_path)
+    # os.walk yields nothing for a missing or unreadable directory; treat
+    # that like a project without experiments instead of letting
+    # StopIteration escape.
+    base, experiment_folders, _ = next(
+        os.walk(project_path), (project_path, [], []))
+
+    experiment_data_collection = []
+
+    for experiment_dir in experiment_folders:
+        experiment_state = _get_experiment_state(
+            os.path.join(base, experiment_dir))
+        if not experiment_state:
+            logger.debug("No experiment state found in %s", experiment_dir)
+            continue
+
+        # Count statuses in plain Python: a DataFrame built from an empty
+        # checkpoint list has no "status" column to index.
+        statuses = [
+            checkpoint.get("status")
+            for checkpoint in experiment_state["checkpoints"]
+        ]
+        runner_data = experiment_state["runner_data"]
+
+        # Format time-based values.
+        time_values = {
+            "start_time": runner_data.get("_start_time"),
+            "last_updated": experiment_state.get("timestamp"),
+        }
+
+        formatted_time_values = {
+            key: datetime.fromtimestamp(val).strftime(TIMESTAMP_FORMAT)
+            if val else None
+            for key, val in time_values.items()
+        }
+
+        experiment_data = {
+            "name": experiment_dir,
+            "total_trials": len(statuses),
+            "running_trials": statuses.count(Trial.RUNNING),
+            "terminated_trials": statuses.count(Trial.TERMINATED),
+            "error_trials": statuses.count(Trial.ERROR),
+        }
+        experiment_data.update(formatted_time_values)
+        experiment_data_collection.append(experiment_data)
+
+    if not experiment_data_collection:
+        print("No experiments found!")
+        sys.exit(0)
+
+    info_df = pd.DataFrame(experiment_data_collection)
+    col_keys = [k for k in list(info_keys) if k in info_df]
+
+    if not col_keys:
+        print("None of keys {} in experiment data!".format(info_keys))
+        sys.exit(0)
+
+    info_df = info_df[col_keys]
+
+    if sort:
+        if sort not in info_df:
+            raise KeyError("Sort Index '{}' not in: {}".format(
+                sort, list(info_df)))
+        info_df = info_df.sort_values(by=sort)
+
+    print_format_output(info_df)
diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py
index df4072eaf..ee23297d5 100644
--- a/python/ray/tune/examples/mnist_pytorch.py
+++ b/python/ray/tune/examples/mnist_pytorch.py
@@ -26,9 +26,9 @@ parser.add_argument(
parser.add_argument(
'--epochs',
type=int,
- default=10,
+ default=1,
metavar='N',
- help='number of epochs to train (default: 10)')
+ help='number of epochs to train (default: 1)')
parser.add_argument(
'--lr',
type=float,
diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py
index d22beddee..b4856c462 100644
--- a/python/ray/tune/examples/mnist_pytorch_trainable.py
+++ b/python/ray/tune/examples/mnist_pytorch_trainable.py
@@ -29,9 +29,9 @@ parser.add_argument(
parser.add_argument(
'--epochs',
type=int,
- default=10,
+ default=1,
metavar='N',
- help='number of epochs to train (default: 10)')
+ help='number of epochs to train (default: 1)')
parser.add_argument(
'--lr',
type=float,
diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py
index 80da6698c..2aa30ceb5 100644
--- a/python/ray/tune/logger.py
+++ b/python/ray/tune/logger.py
@@ -17,15 +17,8 @@ from ray.tune.result import NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, \
logger = logging.getLogger(__name__)
-try:
- import tensorflow as tf
- use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
- distutils.version.LooseVersion("1.5.0"))
-except ImportError:
- tf = None
- use_tf150_api = True
- logger.warning("Couldn't import TensorFlow - "
- "disabling TensorBoard logging.")
+tf = None
+use_tf150_api = True
class Logger(object):
@@ -121,6 +114,15 @@ def to_tf_values(result, path):
class TFLogger(Logger):
def _init(self):
+ try:
+ global tf, use_tf150_api
+ import tensorflow
+ tf = tensorflow
+ use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
+ distutils.version.LooseVersion("1.5.0"))
+ except ImportError:
+ logger.warning("Couldn't import TensorFlow - "
+ "disabling TensorBoard logging.")
self._file_writer = tf.summary.FileWriter(self.logdir)
def on_result(self, result):
diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py
index 0d5aeb0d0..47b536186 100644
--- a/python/ray/tune/result.py
+++ b/python/ray/tune/result.py
@@ -18,6 +18,15 @@ NODE_IP = "node_ip"
# (Auto-filled) The pid of the training process.
PID = "pid"
+# (Optional) Mean reward for current training iteration
+EPISODE_REWARD_MEAN = "episode_reward_mean"
+
+# (Optional) Mean loss for training iteration
+MEAN_LOSS = "mean_loss"
+
+# (Optional) Mean accuracy for training iteration
+MEAN_ACCURACY = "mean_accuracy"
+
# Number of episodes in this iteration.
EPISODES_THIS_ITER = "episodes_this_iter"
diff --git a/python/ray/tune/scripts.py b/python/ray/tune/scripts.py
new file mode 100644
index 000000000..f815d9ed8
--- /dev/null
+++ b/python/ray/tune/scripts.py
@@ -0,0 +1,43 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import click
+import ray.tune.commands as commands
+
+
+# Root command group; the subcommands below are attached to it.
+# NOTE: no docstring on purpose -- click would show it as the help text.
+@click.group()
+def cli():
+    pass
+
+
+# The docstring of each command doubles as its `--help` text.
+@cli.command()
+@click.argument("experiment_path", required=True, type=str)
+@click.option(
+    '--sort', default=None, type=str, help='Select which column to sort on.')
+def list_trials(experiment_path, sort):
+    """Lists trials in the directory subtree starting at the given path."""
+    commands.list_trials(experiment_path, sort)
+
+
+@cli.command()
+@click.argument("project_path", required=True, type=str)
+@click.option(
+    '--sort', default=None, type=str, help='Select which column to sort on.')
+def list_experiments(project_path, sort):
+    """Lists experiments in the directory subtree."""
+    commands.list_experiments(project_path, sort)
+
+
+# Register every command under a short alias and an explicit long name.
+cli.add_command(list_trials, name="ls")
+cli.add_command(list_trials, name="list-trials")
+cli.add_command(list_experiments, name="lsx")
+cli.add_command(list_experiments, name="list-experiments")
+
+
+def main():
+    # Entry point used when this module is executed as a script.
+    return cli()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/ray/tune/suggest/bayesopt.py b/python/ray/tune/suggest/bayesopt.py
index 48f6406aa..089f1b26d 100644
--- a/python/ray/tune/suggest/bayesopt.py
+++ b/python/ray/tune/suggest/bayesopt.py
@@ -4,13 +4,16 @@ from __future__ import print_function
import copy
-try:
- import bayes_opt as byo
-except Exception:
- byo = None
-
from ray.tune.suggest.suggestion import SuggestionAlgorithm
+byo = None
+
+
+def _import_bayesopt():
+ global byo
+ import bayes_opt
+ byo = bayes_opt
+
class BayesOptSearch(SuggestionAlgorithm):
"""A wrapper around BayesOpt to provide trial suggestions.
@@ -56,6 +59,7 @@ class BayesOptSearch(SuggestionAlgorithm):
random_state=1,
verbose=0,
**kwargs):
+ _import_bayesopt()
assert byo is not None, (
"BayesOpt must be installed!. You can install BayesOpt with"
" the command: `pip install bayesian-optimization`.")
diff --git a/python/ray/tune/suggest/hyperopt.py b/python/ray/tune/suggest/hyperopt.py
index 2c3256250..62795cc6c 100644
--- a/python/ray/tune/suggest/hyperopt.py
+++ b/python/ray/tune/suggest/hyperopt.py
@@ -6,17 +6,19 @@ import numpy as np
import copy
import logging
-try:
- hyperopt_logger = logging.getLogger("hyperopt")
- hyperopt_logger.setLevel(logging.WARNING)
- import hyperopt as hpo
- from hyperopt.fmin import generate_trials_to_calculate
-except Exception:
- hpo = None
-
from ray.tune.error import TuneError
from ray.tune.suggest.suggestion import SuggestionAlgorithm
+hpo = None
+
+
+def _import_hyperopt():
+ global hpo
+ hyperopt_logger = logging.getLogger("hyperopt")
+ hyperopt_logger.setLevel(logging.WARNING)
+ import hyperopt
+ hpo = hyperopt
+
class HyperOptSearch(SuggestionAlgorithm):
"""A wrapper around HyperOpt to provide trial suggestions.
@@ -73,7 +75,9 @@ class HyperOptSearch(SuggestionAlgorithm):
reward_attr="episode_reward_mean",
points_to_evaluate=None,
**kwargs):
+ _import_hyperopt()
assert hpo is not None, "HyperOpt must be installed!"
+ from hyperopt.fmin import generate_trials_to_calculate
assert type(max_concurrent) is int and max_concurrent > 0
self._max_concurrent = max_concurrent
self._reward_attr = reward_attr
diff --git a/python/ray/tune/tests/test_commands.py b/python/ray/tune/tests/test_commands.py
new file mode 100644
index 000000000..174f356d5
--- /dev/null
+++ b/python/ray/tune/tests/test_commands.py
@@ -0,0 +1,66 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import pytest
+
+import ray
+from ray import tune
+from ray.rllib import _register_all
+from ray.tune import commands
+
+
+@pytest.fixture
+def start_ray():
+    """Starts Ray before a test and shuts it down afterwards."""
+    ray.init()
+    # Called so that the "__fake" trainable used by the tests below can be
+    # resolved -- presumably it is registered here; confirm in ray.rllib.
+    _register_all()
+    yield
+    ray.shutdown()
+
+
+def test_ls(start_ray, capsys, tmpdir):
+    """Checks that list_trials prints one TERMINATED row per trial."""
+    experiment_name = "test_ls"
+    local_dir = str(tmpdir)
+    num_samples = 2
+    experiment_spec = {
+        "run": "__fake",
+        "stop": {
+            "training_iteration": 1
+        },
+        "num_samples": num_samples,
+        "local_dir": local_dir
+    }
+    # Keep Tune's own console output out of the captured stream.
+    with capsys.disabled():
+        tune.run_experiments({experiment_name: experiment_spec})
+
+    experiment_path = os.path.join(local_dir, experiment_name)
+    commands.list_trials(experiment_path, info_keys=("status", ))
+    output_lines = capsys.readouterr().out.strip().split("\n")
+    terminated_rows = [line for line in output_lines if "TERMINATED" in line]
+    assert len(terminated_rows) == num_samples
+
+
+def test_lsx(start_ray, capsys, tmpdir):
+    """Checks that list_experiments prints a row for every experiment."""
+    project_path = str(tmpdir)
+    num_experiments = 3
+    for index in range(num_experiments):
+        experiment_spec = {
+            "run": "__fake",
+            "stop": {
+                "training_iteration": 1
+            },
+            "num_samples": 1,
+            "local_dir": project_path
+        }
+        # Keep Tune's own console output out of the captured stream.
+        with capsys.disabled():
+            tune.run_experiments({
+                "test_lsx{}".format(index): experiment_spec
+            })
+
+    commands.list_experiments(project_path, info_keys=("total_trials", ))
+    output_lines = capsys.readouterr().out.strip().split("\n")
+    # Every experiment ran a single trial, so its row contains a "1".
+    rows_with_one = [line for line in output_lines if "1" in line]
+    assert len(rows_with_one) >= 3
diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py
index ee5a98a87..0341dd487 100644
--- a/python/ray/tune/tests/test_ray_trial_executor.py
+++ b/python/ray/tune/tests/test_ray_trial_executor.py
@@ -18,6 +18,7 @@ class RayTrialExecutorTest(unittest.TestCase):
def setUp(self):
self.trial_executor = RayTrialExecutor(queue_trials=False)
ray.init()
+ _register_all() # Needed for flaky tests
def tearDown(self):
ray.shutdown()
diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py
index b5426bb3d..0f68e39f9 100644
--- a/python/ray/tune/tests/test_trial_scheduler.py
+++ b/python/ray/tune/tests/test_trial_scheduler.py
@@ -583,7 +583,7 @@ class _MockTrial(Trial):
self.logger_running = False
self.restored_checkpoint = None
self.resources = Resources(1, 0)
- self.trial_name = None
+ self.custom_trial_name = None
class PopulationBasedTestingSuite(unittest.TestCase):
diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py
index be85eeb77..fb324fcb7 100644
--- a/python/ray/tune/trial.py
+++ b/python/ray/tune/trial.py
@@ -25,7 +25,8 @@ from ray.tune.logger import pretty_print, UnifiedLogger
# have been defined yet. See https://github.com/ray-project/ray/issues/1716.
import ray.tune.registry
from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID,
- TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL)
+ TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL,
+ EPISODE_REWARD_MEAN, MEAN_LOSS, MEAN_ACCURACY)
from ray.utils import _random_string, binary_to_hex, hex_to_binary
DEBUG_PRINT_INTERVAL = 5
@@ -299,6 +300,8 @@ class Trial(object):
self.error_file = None
self.num_failures = 0
+ self.custom_trial_name = None
+
# AutoML fields
self.results = None
self.best_result = None
@@ -316,10 +319,8 @@ class Trial(object):
"param_config",
"extra_arg",
]
-
- self.trial_name = None
if trial_name_creator:
- self.trial_name = trial_name_creator(self)
+ self.custom_trial_name = trial_name_creator(self)
@classmethod
def _registration_check(cls, trainable_name):
@@ -447,17 +448,17 @@ class Trial(object):
if self.last_result.get(TIMESTEPS_TOTAL) is not None:
pieces.append('{} ts'.format(self.last_result[TIMESTEPS_TOTAL]))
- if self.last_result.get("episode_reward_mean") is not None:
+ if self.last_result.get(EPISODE_REWARD_MEAN) is not None:
pieces.append('{} rew'.format(
- format(self.last_result["episode_reward_mean"], '.3g')))
+ format(self.last_result[EPISODE_REWARD_MEAN], '.3g')))
- if self.last_result.get("mean_loss") is not None:
+ if self.last_result.get(MEAN_LOSS) is not None:
pieces.append('{} loss'.format(
- format(self.last_result["mean_loss"], '.3g')))
+ format(self.last_result[MEAN_LOSS], '.3g')))
- if self.last_result.get("mean_accuracy") is not None:
+ if self.last_result.get(MEAN_ACCURACY) is not None:
pieces.append('{} acc'.format(
- format(self.last_result["mean_accuracy"], '.3g')))
+ format(self.last_result[MEAN_ACCURACY], '.3g')))
return ', '.join(pieces)
@@ -514,8 +515,8 @@ class Trial(object):
Can be overriden with a custom string creator.
"""
- if self.trial_name:
- return self.trial_name
+ if self.custom_trial_name:
+ return self.custom_trial_name
if "env" in self.config:
env = self.config["env"]
@@ -544,8 +545,6 @@ class Trial(object):
state["runner"] = None
state["result_logger"] = None
- if self.status == Trial.RUNNING:
- state["status"] = Trial.PENDING
if self.result_logger:
self.result_logger.flush()
state["__logger_started__"] = True
@@ -556,6 +555,8 @@ class Trial(object):
def __setstate__(self, state):
logger_started = state.pop("__logger_started__")
state["resources"] = json_to_resources(state["resources"])
+ if state["status"] == Trial.RUNNING:
+ state["status"] = Trial.PENDING
for key in self._nonjson_fields:
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py
index afce6863a..96dfaa5de 100644
--- a/python/ray/tune/trial_runner.py
+++ b/python/ray/tune/trial_runner.py
@@ -112,7 +112,9 @@ class TrialRunner(object):
self._stop_queue = []
self._metadata_checkpoint_dir = metadata_checkpoint_dir
- self._session = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
+ self._start_time = time.time()
+ self._session_str = datetime.fromtimestamp(
+ self._start_time).strftime("%Y-%m-%d_%H-%M-%S")
@classmethod
def checkpoint_exists(cls, directory):
@@ -136,7 +138,8 @@ class TrialRunner(object):
runner_state = {
"checkpoints": list(
self.trial_executor.get_checkpoints().values()),
- "runner_data": self.__getstate__()
+ "runner_data": self.__getstate__(),
+ "timestamp": time.time()
}
tmp_file_name = os.path.join(metadata_checkpoint_dir,
".tmp_checkpoint")
@@ -146,7 +149,7 @@ class TrialRunner(object):
os.rename(
tmp_file_name,
os.path.join(metadata_checkpoint_dir,
- TrialRunner.CKPT_FILE_TMPL.format(self._session)))
+ TrialRunner.CKPT_FILE_TMPL.format(self._session_str)))
return metadata_checkpoint_dir
@classmethod
@@ -558,8 +561,12 @@ class TrialRunner(object):
"""
state = self.__dict__.copy()
for k in [
- "_trials", "_stop_queue", "_server", "_search_alg",
- "_scheduler_alg", "trial_executor", "_session"
+ "_trials",
+ "_stop_queue",
+ "_server",
+ "_search_alg",
+ "_scheduler_alg",
+ "trial_executor",
]:
del state[k]
state["launch_web_server"] = bool(self._server)
@@ -567,6 +574,14 @@ class TrialRunner(object):
def __setstate__(self, state):
launch_web_server = state.pop("launch_web_server")
+
+ # Use session_str from previous checkpoint if does not exist
+ session_str = state.pop("_session_str")
+ self.__dict__.setdefault("_session_str", session_str)
+ # Use start_time from previous checkpoint if does not exist
+ start_time = state.pop("_start_time")
+ self.__dict__.setdefault("_start_time", start_time)
+
self.__dict__.update(state)
if launch_web_server:
self._server = TuneServer(self, self._server_port)
diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py
index ce4047f2e..75ac57ef1 100644
--- a/python/ray/tune/util.py
+++ b/python/ray/tune/util.py
@@ -61,7 +61,7 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist):
if k not in original:
if not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))
- if type(original.get(k)) is dict:
+ if isinstance(original.get(k), dict):
if k in whitelist:
deep_update(original[k], value, True, [])
else:
@@ -71,6 +71,21 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist):
return original
+def flatten_dict(dt):
+    """Flattens nested dicts in place, joining nested keys with ":".
+
+    Each pass replaces every dict-valued entry `key` by entries named
+    `key:subkey`; passes repeat until no value is a dict.
+
+    Args:
+        dt (dict): Dictionary to flatten. It is modified in place.
+
+    Returns:
+        The same `dt` object, now without dict values.
+    """
+    while True:
+        nested_keys = [k for k, v in dt.items() if isinstance(v, dict)]
+        if not nested_keys:
+            return dt
+        # Collect the replacements first; dt must not change size while
+        # it is being read.
+        lifted = {}
+        for key in nested_keys:
+            for subkey, v in dt[key].items():
+                lifted[":".join([key, subkey])] = v
+        dt.update(lifted)
+        for key in nested_keys:
+            del dt[key]
+
+
def _to_pinnable(obj):
"""Converts obj to a form that can be pinned in object store memory.
diff --git a/python/ray/tune/visual_utils.py b/python/ray/tune/visual_utils.py
index 9273a9154..4a68bcec9 100644
--- a/python/ray/tune/visual_utils.py
+++ b/python/ray/tune/visual_utils.py
@@ -10,24 +10,11 @@ import os.path as osp
import numpy as np
import json
+from ray.tune.util import flatten_dict
+
logger = logging.getLogger(__name__)
-def _flatten_dict(dt):
- while any(type(v) is dict for v in dt.values()):
- remove = []
- add = {}
- for key, value in dt.items():
- if type(value) is dict:
- for subkey, v in value.items():
- add[":".join([key, subkey])] = v
- remove.append(key)
- dt.update(add)
- for k in remove:
- del dt[k]
- return dt
-
-
def _parse_results(res_path):
res_dict = {}
try:
@@ -35,7 +22,7 @@ def _parse_results(res_path):
# Get last line in file
for line in f:
pass
- res_dict = _flatten_dict(json.loads(line.strip()))
+ res_dict = flatten_dict(json.loads(line.strip()))
except Exception:
logger.exception("Importing %s failed...Perhaps empty?" % res_path)
return res_dict
@@ -44,7 +31,7 @@ def _parse_results(res_path):
def _parse_configs(cfg_path):
try:
with open(cfg_path) as f:
- cfg_dict = _flatten_dict(json.load(f))
+ cfg_dict = flatten_dict(json.load(f))
except Exception:
logger.exception("Config parsing failed.")
return cfg_dict
diff --git a/python/setup.py b/python/setup.py
index 921517207..07b9b3b92 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -175,7 +175,7 @@ setup(
entry_points={
"console_scripts": [
"ray=ray.scripts.scripts:main",
- "rllib=ray.rllib.scripts:cli [rllib]"
+ "rllib=ray.rllib.scripts:cli [rllib]", "tune=ray.tune.scripts:cli"
]
},
include_package_data=True,