diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index 29d1b2650..5f92eb43a 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -25,7 +25,7 @@ if [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "linux" ]]; then bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" pip install -q scipy tensorflow cython==0.29.0 gym opencv-python-headless pyyaml pandas==0.23.4 requests \ - feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock flaky networkx + feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock flaky networkx tabulate elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then sudo apt-get update sudo apt-get install -y python-dev python-numpy build-essential curl unzip tmux gdb @@ -34,7 +34,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" pip install -q scipy tensorflow cython==0.29.0 gym opencv-python-headless pyyaml pandas==0.23.4 requests \ - feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx + feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then # check that brew is installed which -s brew @@ -50,7 +50,7 @@ elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" pip install -q cython==0.29.0 tensorflow gym opencv-python-headless pyyaml pandas==0.23.4 requests \ - feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock flaky networkx + feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock flaky networkx tabulate elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then # check that brew is installed which -s brew @@ -66,7 +66,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" pip install -q cython==0.29.0 tensorflow gym opencv-python-headless pyyaml pandas==0.23.4 requests \ - feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx + feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate elif [[ "$LINT" == "1" ]]; then sudo apt-get update sudo apt-get install -y build-essential curl unzip diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index c69640a64..82ae5a33b 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -499,6 +499,50 @@ And stopping a trial (``PUT /trials/:id``): curl -X PUT http://
:/trials/ +Tune CLI (Experimental) +----------------------- + +``tune`` has an easy-to-use command line interface (CLI) to manage and monitor your experiments on Ray. To do this, verify that you have the ``tabulate`` library installed: + +.. code-block:: bash + + $ pip install tabulate + +Here are a few examples of command line calls. + +- ``tune list-trials``: List tabular information about trials within an experiment. Add the ``--sort`` flag to sort the output by specific columns. + +.. code-block:: bash + + $ tune list-trials [EXPERIMENT_DIR] + + +------------------+-----------------------+------------+ + | trainable_name | experiment_tag | trial_id | + |------------------+-----------------------+------------| + | MyTrainableClass | 0_height=40,width=37 | 87b54a1d | + | MyTrainableClass | 1_height=21,width=70 | 23b89036 | + | MyTrainableClass | 2_height=99,width=90 | 518dbe95 | + | MyTrainableClass | 3_height=54,width=21 | 7b99a28a | + | MyTrainableClass | 4_height=90,width=69 | ae4e02fb | + +------------------+-----------------------+------------+ + Dropped columns: ['status', 'last_update_time'] + +- ``tune list-experiments``: List tabular information about experiments within a project. Add the ``--sort`` flag to sort the output by specific columns. + +.. code-block:: bash + + $ tune list-experiments [PROJECT_DIR] + + +----------------------+----------------+------------------+---------------------+ + | name | total_trials | running_trials | terminated_trials | + |----------------------+----------------+------------------+---------------------| + | pbt_test | 10 | 0 | 0 | + | test | 1 | 0 | 0 | + | hyperband_test | 1 | 0 | 1 | + +----------------------+----------------+------------------+---------------------+ + Dropped columns: ['error_trials', 'last_updated'] + + Further Questions or Issues? ---------------------------- diff --git a/python/ray/tune/commands.py b/python/ray/tune/commands.py new file mode 100644 index 000000000..c9d458e43 --- /dev/null +++ b/python/ray/tune/commands.py @@ -0,0 +1,232 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import json +import logging +import os +import sys +import subprocess +from datetime import datetime + +import pandas as pd +from ray.tune.util import flatten_dict +from ray.tune.result import TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS +from ray.tune.trial import Trial +try: + from tabulate import tabulate +except ImportError: + tabulate = None + +logger = logging.getLogger(__name__) + +TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)" + +DEFAULT_EXPERIMENT_INFO_KEYS = ( + "trainable_name", + "experiment_tag", + "trial_id", + "status", + "last_update_time", +) + +DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS) + +DEFAULT_PROJECT_INFO_KEYS = ( + "name", + "total_trials", + "running_trials", + "terminated_trials", + "error_trials", + "last_updated", +) + +try: + TERM_HEIGHT, TERM_WIDTH = subprocess.check_output(['stty', 'size']).split() + TERM_HEIGHT, TERM_WIDTH = int(TERM_HEIGHT), int(TERM_WIDTH) +except subprocess.CalledProcessError: + TERM_HEIGHT, TERM_WIDTH = 100, 100 + + +def _check_tabulate(): + """Checks whether tabulate is installed.""" + if tabulate is None: + raise ImportError( + "Tabulate not installed. Please run `pip install tabulate`.") + + +def print_format_output(dataframe): + """Prints output of given dataframe to fit into terminal. + + Returns: + table (pd.DataFrame): Final outputted dataframe. + dropped_cols (list): Columns dropped due to terminal size. + empty_cols (list): Empty columns (dropped on default). + """ + print_df = pd.DataFrame() + dropped_cols = [] + empty_cols = [] + # column display priority is based on the info_keys passed in + for i, col in enumerate(dataframe): + if dataframe[col].isnull().all(): + # Don't add col to print_df if is fully empty + empty_cols += [col] + continue + + print_df[col] = dataframe[col] + test_table = tabulate(print_df, headers="keys", tablefmt="psql") + if str(test_table).index('\n') > TERM_WIDTH: + # Drop all columns beyond terminal width + print_df.drop(col, axis=1, inplace=True) + dropped_cols += list(dataframe.columns)[i:] + break + + table = tabulate( + print_df, headers="keys", tablefmt="psql", showindex="never") + + print(table) + if dropped_cols: + print("Dropped columns:", dropped_cols) + print("Please increase your terminal size to view remaining columns.") + if empty_cols: + print("Empty columns:", empty_cols) + + return table, dropped_cols, empty_cols + + +def _get_experiment_state(experiment_path, exit_on_fail=False): + experiment_path = os.path.expanduser(experiment_path) + experiment_state_paths = glob.glob( + os.path.join(experiment_path, "experiment_state*.json")) + if not experiment_state_paths: + if exit_on_fail: + print("No experiment state found!") + sys.exit(0) + else: + return + experiment_filename = max(list(experiment_state_paths)) + + with open(experiment_filename) as f: + experiment_state = json.load(f) + return experiment_state + + +def list_trials(experiment_path, + sort=None, + info_keys=DEFAULT_EXPERIMENT_INFO_KEYS, + result_keys=DEFAULT_RESULT_KEYS): + """Lists trials in the directory subtree starting at the given path. + + Args: + experiment_path (str): Directory where trials are located. + Corresponds to Experiment.local_dir/Experiment.name. + sort (str): Key to sort by. + info_keys (list): Keys that are displayed. + result_keys (list): Keys of last result that are displayed. + """ + _check_tabulate() + experiment_state = _get_experiment_state( + experiment_path, exit_on_fail=True) + + checkpoint_dicts = experiment_state["checkpoints"] + checkpoint_dicts = [flatten_dict(g) for g in checkpoint_dicts] + checkpoints_df = pd.DataFrame(checkpoint_dicts) + + result_keys = ["last_result:{}".format(k) for k in result_keys] + col_keys = [ + k for k in list(info_keys) + result_keys if k in checkpoints_df + ] + checkpoints_df = checkpoints_df[col_keys] + + if "last_update_time" in checkpoints_df: + with pd.option_context('mode.use_inf_as_null', True): + datetime_series = checkpoints_df["last_update_time"].dropna() + + datetime_series = datetime_series.apply( + lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT)) + checkpoints_df["last_update_time"] = datetime_series + + if "logdir" in checkpoints_df: + # logdir often too verbose to view in table, so drop experiment_path + checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace( + experiment_path, '') + + if sort: + if sort not in checkpoints_df: + raise KeyError("Sort Index '{}' not in: {}".format( + sort, list(checkpoints_df))) + checkpoints_df = checkpoints_df.sort_values(by=sort) + + print_format_output(checkpoints_df) + + +def list_experiments(project_path, + sort=None, + info_keys=DEFAULT_PROJECT_INFO_KEYS): + """Lists experiments in the directory subtree. + + Args: + project_path (str): Directory where experiments are located. + Corresponds to Experiment.local_dir. + sort (str): Key to sort by. + info_keys (list): Keys that are displayed. + """ + _check_tabulate() + base, experiment_folders, _ = next(os.walk(project_path)) + + experiment_data_collection = [] + + for experiment_dir in experiment_folders: + experiment_state = _get_experiment_state( + os.path.join(base, experiment_dir)) + if not experiment_state: + logger.debug("No experiment state found in %s", experiment_dir) + continue + + checkpoints = pd.DataFrame(experiment_state["checkpoints"]) + runner_data = experiment_state["runner_data"] + + # Format time-based values. + time_values = { + "start_time": runner_data.get("_start_time"), + "last_updated": experiment_state.get("timestamp"), + } + + formatted_time_values = { + key: datetime.fromtimestamp(val).strftime(TIMESTAMP_FORMAT) + if val else None + for key, val in time_values.items() + } + + experiment_data = { + "name": experiment_dir, + "total_trials": checkpoints.shape[0], + "running_trials": (checkpoints["status"] == Trial.RUNNING).sum(), + "terminated_trials": ( + checkpoints["status"] == Trial.TERMINATED).sum(), + "error_trials": (checkpoints["status"] == Trial.ERROR).sum(), + } + experiment_data.update(formatted_time_values) + experiment_data_collection.append(experiment_data) + + if not experiment_data_collection: + print("No experiments found!") + sys.exit(0) + + info_df = pd.DataFrame(experiment_data_collection) + col_keys = [k for k in list(info_keys) if k in info_df] + + if not col_keys: + print("None of keys {} in experiment data!".format(info_keys)) + sys.exit(0) + + info_df = info_df[col_keys] + + if sort: + if sort not in info_df: + raise KeyError("Sort Index '{}' not in: {}".format( + sort, list(info_df))) + info_df = info_df.sort_values(by=sort) + + print_format_output(info_df) diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py index df4072eaf..ee23297d5 100644 --- a/python/ray/tune/examples/mnist_pytorch.py +++ b/python/ray/tune/examples/mnist_pytorch.py @@ -26,9 +26,9 @@ parser.add_argument( parser.add_argument( '--epochs', type=int, - default=10, + default=1, metavar='N', - help='number of epochs to train (default: 10)') + help='number of epochs to train (default: 1)') parser.add_argument( '--lr', type=float, diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py index d22beddee..b4856c462 100644 --- a/python/ray/tune/examples/mnist_pytorch_trainable.py +++ b/python/ray/tune/examples/mnist_pytorch_trainable.py @@ -29,9 +29,9 @@ parser.add_argument( parser.add_argument( '--epochs', type=int, - default=10, + default=1, metavar='N', - help='number of epochs to train (default: 10)') + help='number of epochs to train (default: 1)') parser.add_argument( '--lr', type=float, diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 80da6698c..2aa30ceb5 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -17,15 +17,8 @@ from ray.tune.result import NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, \ logger = logging.getLogger(__name__) -try: - import tensorflow as tf - use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= - distutils.version.LooseVersion("1.5.0")) -except ImportError: - tf = None - use_tf150_api = True - logger.warning("Couldn't import TensorFlow - " - "disabling TensorBoard logging.") +tf = None +use_tf150_api = True class Logger(object): @@ -121,6 +114,15 @@ def to_tf_values(result, path): class TFLogger(Logger): def _init(self): + try: + global tf, use_tf150_api + import tensorflow + tf = tensorflow + use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= + distutils.version.LooseVersion("1.5.0")) + except ImportError: + logger.warning("Couldn't import TensorFlow - " + "disabling TensorBoard logging.") self._file_writer = tf.summary.FileWriter(self.logdir) def on_result(self, result): diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index 0d5aeb0d0..47b536186 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -18,6 +18,15 @@ NODE_IP = "node_ip" # (Auto-filled) The pid of the training process. PID = "pid" +# (Optional) Mean reward for current training iteration +EPISODE_REWARD_MEAN = "episode_reward_mean" + +# (Optional) Mean loss for training iteration +MEAN_LOSS = "mean_loss" + +# (Optional) Mean accuracy for training iteration +MEAN_ACCURACY = "mean_accuracy" + # Number of episodes in this iteration. EPISODES_THIS_ITER = "episodes_this_iter" diff --git a/python/ray/tune/scripts.py b/python/ray/tune/scripts.py new file mode 100644 index 000000000..f815d9ed8 --- /dev/null +++ b/python/ray/tune/scripts.py @@ -0,0 +1,43 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import click +import ray.tune.commands as commands + + +@click.group() +def cli(): + pass + + +@cli.command() +@click.argument("experiment_path", required=True, type=str) +@click.option( + '--sort', default=None, type=str, help='Select which column to sort on.') +def list_trials(experiment_path, sort): + """Lists trials in the directory subtree starting at the given path.""" + commands.list_trials(experiment_path, sort) + + +@cli.command() +@click.argument("project_path", required=True, type=str) +@click.option( + '--sort', default=None, type=str, help='Select which column to sort on.') +def list_experiments(project_path, sort): + """Lists experiments in the directory subtree.""" + commands.list_experiments(project_path, sort) + + +cli.add_command(list_trials, name="ls") +cli.add_command(list_trials, name="list-trials") +cli.add_command(list_experiments, name="lsx") +cli.add_command(list_experiments, name="list-experiments") + + +def main(): + return cli() + + +if __name__ == "__main__": + main() diff --git a/python/ray/tune/suggest/bayesopt.py b/python/ray/tune/suggest/bayesopt.py index 48f6406aa..089f1b26d 100644 --- a/python/ray/tune/suggest/bayesopt.py +++ b/python/ray/tune/suggest/bayesopt.py @@ -4,13 +4,16 @@ from __future__ import print_function import copy -try: - import bayes_opt as byo -except Exception: - byo = None - from ray.tune.suggest.suggestion import SuggestionAlgorithm +byo = None + + +def _import_bayesopt(): + global byo + import bayes_opt + byo = bayes_opt + class BayesOptSearch(SuggestionAlgorithm): """A wrapper around BayesOpt to provide trial suggestions. @@ -56,6 +59,7 @@ class BayesOptSearch(SuggestionAlgorithm): random_state=1, verbose=0, **kwargs): + _import_bayesopt() assert byo is not None, ( "BayesOpt must be installed!. You can install BayesOpt with" " the command: `pip install bayesian-optimization`.") diff --git a/python/ray/tune/suggest/hyperopt.py b/python/ray/tune/suggest/hyperopt.py index 2c3256250..62795cc6c 100644 --- a/python/ray/tune/suggest/hyperopt.py +++ b/python/ray/tune/suggest/hyperopt.py @@ -6,17 +6,19 @@ import numpy as np import copy import logging -try: - hyperopt_logger = logging.getLogger("hyperopt") - hyperopt_logger.setLevel(logging.WARNING) - import hyperopt as hpo - from hyperopt.fmin import generate_trials_to_calculate -except Exception: - hpo = None - from ray.tune.error import TuneError from ray.tune.suggest.suggestion import SuggestionAlgorithm +hpo = None + + +def _import_hyperopt(): + global hpo + hyperopt_logger = logging.getLogger("hyperopt") + hyperopt_logger.setLevel(logging.WARNING) + import hyperopt + hpo = hyperopt + class HyperOptSearch(SuggestionAlgorithm): """A wrapper around HyperOpt to provide trial suggestions. @@ -73,7 +75,9 @@ class HyperOptSearch(SuggestionAlgorithm): reward_attr="episode_reward_mean", points_to_evaluate=None, **kwargs): + _import_hyperopt() assert hpo is not None, "HyperOpt must be installed!" + from hyperopt.fmin import generate_trials_to_calculate assert type(max_concurrent) is int and max_concurrent > 0 self._max_concurrent = max_concurrent self._reward_attr = reward_attr diff --git a/python/ray/tune/tests/test_commands.py b/python/ray/tune/tests/test_commands.py new file mode 100644 index 000000000..174f356d5 --- /dev/null +++ b/python/ray/tune/tests/test_commands.py @@ -0,0 +1,66 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pytest + +import ray +from ray import tune +from ray.rllib import _register_all +from ray.tune import commands + + +@pytest.fixture +def start_ray(): + ray.init() + _register_all() + yield + ray.shutdown() + + +def test_ls(start_ray, capsys, tmpdir): + """This test captures output of list_trials.""" + experiment_name = "test_ls" + experiment_path = os.path.join(str(tmpdir), experiment_name) + num_samples = 2 + with capsys.disabled(): + tune.run_experiments({ + experiment_name: { + "run": "__fake", + "stop": { + "training_iteration": 1 + }, + "num_samples": num_samples, + "local_dir": str(tmpdir) + } + }) + + commands.list_trials(experiment_path, info_keys=("status", )) + captured = capsys.readouterr().out.strip() + lines = captured.split("\n") + assert sum("TERMINATED" in line for line in lines) == num_samples + + +def test_lsx(start_ray, capsys, tmpdir): + """This test captures output of list_experiments.""" + project_path = str(tmpdir) + num_experiments = 3 + for i in range(num_experiments): + experiment_name = "test_lsx{}".format(i) + with capsys.disabled(): + tune.run_experiments({ + experiment_name: { + "run": "__fake", + "stop": { + "training_iteration": 1 + }, + "num_samples": 1, + "local_dir": project_path + } + }) + + commands.list_experiments(project_path, info_keys=("total_trials", )) + captured = capsys.readouterr().out.strip() + lines = captured.split("\n") + assert sum("1" in line for line in lines) >= 3 diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index ee5a98a87..0341dd487 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -18,6 +18,7 @@ class RayTrialExecutorTest(unittest.TestCase): def setUp(self): self.trial_executor = RayTrialExecutor(queue_trials=False) ray.init() + _register_all() # Needed for flaky tests def tearDown(self): ray.shutdown() diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index b5426bb3d..0f68e39f9 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -583,7 +583,7 @@ class _MockTrial(Trial): self.logger_running = False self.restored_checkpoint = None self.resources = Resources(1, 0) - self.trial_name = None + self.custom_trial_name = None class PopulationBasedTestingSuite(unittest.TestCase): diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index be85eeb77..fb324fcb7 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -25,7 +25,8 @@ from ray.tune.logger import pretty_print, UnifiedLogger # have been defined yet. See https://github.com/ray-project/ray/issues/1716. import ray.tune.registry from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID, - TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL) + TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL, + EPISODE_REWARD_MEAN, MEAN_LOSS, MEAN_ACCURACY) from ray.utils import _random_string, binary_to_hex, hex_to_binary DEBUG_PRINT_INTERVAL = 5 @@ -299,6 +300,8 @@ class Trial(object): self.error_file = None self.num_failures = 0 + self.custom_trial_name = None + # AutoML fields self.results = None self.best_result = None @@ -316,10 +319,8 @@ class Trial(object): "param_config", "extra_arg", ] - - self.trial_name = None if trial_name_creator: - self.trial_name = trial_name_creator(self) + self.custom_trial_name = trial_name_creator(self) @classmethod def _registration_check(cls, trainable_name): @@ -447,17 +448,17 @@ class Trial(object): if self.last_result.get(TIMESTEPS_TOTAL) is not None: pieces.append('{} ts'.format(self.last_result[TIMESTEPS_TOTAL])) - if self.last_result.get("episode_reward_mean") is not None: + if self.last_result.get(EPISODE_REWARD_MEAN) is not None: pieces.append('{} rew'.format( - format(self.last_result["episode_reward_mean"], '.3g'))) + format(self.last_result[EPISODE_REWARD_MEAN], '.3g'))) - if self.last_result.get("mean_loss") is not None: + if self.last_result.get(MEAN_LOSS) is not None: pieces.append('{} loss'.format( - format(self.last_result["mean_loss"], '.3g'))) + format(self.last_result[MEAN_LOSS], '.3g'))) - if self.last_result.get("mean_accuracy") is not None: + if self.last_result.get(MEAN_ACCURACY) is not None: pieces.append('{} acc'.format( - format(self.last_result["mean_accuracy"], '.3g'))) + format(self.last_result[MEAN_ACCURACY], '.3g'))) return ', '.join(pieces) @@ -514,8 +515,8 @@ class Trial(object): Can be overriden with a custom string creator. """ - if self.trial_name: - return self.trial_name + if self.custom_trial_name: + return self.custom_trial_name if "env" in self.config: env = self.config["env"] @@ -544,8 +545,6 @@ class Trial(object): state["runner"] = None state["result_logger"] = None - if self.status == Trial.RUNNING: - state["status"] = Trial.PENDING if self.result_logger: self.result_logger.flush() state["__logger_started__"] = True @@ -556,6 +555,8 @@ class Trial(object): def __setstate__(self, state): logger_started = state.pop("__logger_started__") state["resources"] = json_to_resources(state["resources"]) + if state["status"] == Trial.RUNNING: + state["status"] = Trial.PENDING for key in self._nonjson_fields: state[key] = cloudpickle.loads(hex_to_binary(state[key])) diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index afce6863a..96dfaa5de 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -112,7 +112,9 @@ class TrialRunner(object): self._stop_queue = [] self._metadata_checkpoint_dir = metadata_checkpoint_dir - self._session = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") + self._start_time = time.time() + self._session_str = datetime.fromtimestamp( + self._start_time).strftime("%Y-%m-%d_%H-%M-%S") @classmethod def checkpoint_exists(cls, directory): @@ -136,7 +138,8 @@ class TrialRunner(object): runner_state = { "checkpoints": list( self.trial_executor.get_checkpoints().values()), - "runner_data": self.__getstate__() + "runner_data": self.__getstate__(), + "timestamp": time.time() } tmp_file_name = os.path.join(metadata_checkpoint_dir, ".tmp_checkpoint") @@ -146,7 +149,7 @@ class TrialRunner(object): os.rename( tmp_file_name, os.path.join(metadata_checkpoint_dir, - TrialRunner.CKPT_FILE_TMPL.format(self._session))) + TrialRunner.CKPT_FILE_TMPL.format(self._session_str))) return metadata_checkpoint_dir @classmethod @@ -558,8 +561,12 @@ class TrialRunner(object): """ state = self.__dict__.copy() for k in [ - "_trials", "_stop_queue", "_server", "_search_alg", - "_scheduler_alg", "trial_executor", "_session" + "_trials", + "_stop_queue", + "_server", + "_search_alg", + "_scheduler_alg", + "trial_executor", ]: del state[k] state["launch_web_server"] = bool(self._server) @@ -567,6 +574,14 @@ class TrialRunner(object): def __setstate__(self, state): launch_web_server = state.pop("launch_web_server") + + # Use session_str from previous checkpoint if does not exist + session_str = state.pop("_session_str") + self.__dict__.setdefault("_session_str", session_str) + # Use start_time from previous checkpoint if does not exist + start_time = state.pop("_start_time") + self.__dict__.setdefault("_start_time", start_time) + self.__dict__.update(state) if launch_web_server: self._server = TuneServer(self, self._server_port) diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py index ce4047f2e..75ac57ef1 100644 --- a/python/ray/tune/util.py +++ b/python/ray/tune/util.py @@ -61,7 +61,7 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist): if k not in original: if not new_keys_allowed: raise Exception("Unknown config parameter `{}` ".format(k)) - if type(original.get(k)) is dict: + if isinstance(original.get(k), dict): if k in whitelist: deep_update(original[k], value, True, []) else: @@ -71,6 +71,21 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist): return original +def flatten_dict(dt): + while any(isinstance(v, dict) for v in dt.values()): + remove = [] + add = {} + for key, value in dt.items(): + if isinstance(value, dict): + for subkey, v in value.items(): + add[":".join([key, subkey])] = v + remove.append(key) + dt.update(add) + for k in remove: + del dt[k] + return dt + + def _to_pinnable(obj): """Converts obj to a form that can be pinned in object store memory. diff --git a/python/ray/tune/visual_utils.py b/python/ray/tune/visual_utils.py index 9273a9154..4a68bcec9 100644 --- a/python/ray/tune/visual_utils.py +++ b/python/ray/tune/visual_utils.py @@ -10,24 +10,11 @@ import os.path as osp import numpy as np import json +from ray.tune.util import flatten_dict + logger = logging.getLogger(__name__) -def _flatten_dict(dt): - while any(type(v) is dict for v in dt.values()): - remove = [] - add = {} - for key, value in dt.items(): - if type(value) is dict: - for subkey, v in value.items(): - add[":".join([key, subkey])] = v - remove.append(key) - dt.update(add) - for k in remove: - del dt[k] - return dt - - def _parse_results(res_path): res_dict = {} try: @@ -35,7 +22,7 @@ def _parse_results(res_path): # Get last line in file for line in f: pass - res_dict = _flatten_dict(json.loads(line.strip())) + res_dict = flatten_dict(json.loads(line.strip())) except Exception: logger.exception("Importing %s failed...Perhaps empty?" % res_path) return res_dict @@ -44,7 +31,7 @@ def _parse_results(res_path): def _parse_configs(cfg_path): try: with open(cfg_path) as f: - cfg_dict = _flatten_dict(json.load(f)) + cfg_dict = flatten_dict(json.load(f)) except Exception: logger.exception("Config parsing failed.") return cfg_dict diff --git a/python/setup.py b/python/setup.py index 921517207..07b9b3b92 100644 --- a/python/setup.py +++ b/python/setup.py @@ -175,7 +175,7 @@ setup( entry_points={ "console_scripts": [ "ray=ray.scripts.scripts:main", - "rllib=ray.rllib.scripts:cli [rllib]" + "rllib=ray.rllib.scripts:cli [rllib]", "tune=ray.tune.scripts:cli" ] }, include_package_data=True,