mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 02:42:52 +08:00
[tune] Initial Commit for Tune CLI (#3983)
This introduces a light CLI for Tune.
This commit is contained in:
@@ -0,0 +1,232 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
from ray.tune.util import flatten_dict
|
||||
from ray.tune.result import TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS
|
||||
from ray.tune.trial import Trial
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
tabulate = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"
|
||||
|
||||
DEFAULT_EXPERIMENT_INFO_KEYS = (
|
||||
"trainable_name",
|
||||
"experiment_tag",
|
||||
"trial_id",
|
||||
"status",
|
||||
"last_update_time",
|
||||
)
|
||||
|
||||
DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS)
|
||||
|
||||
DEFAULT_PROJECT_INFO_KEYS = (
|
||||
"name",
|
||||
"total_trials",
|
||||
"running_trials",
|
||||
"terminated_trials",
|
||||
"error_trials",
|
||||
"last_updated",
|
||||
)
|
||||
|
||||
# Probe the terminal size once at import time; print_format_output uses
# TERM_WIDTH to decide how many columns fit on screen.
try:
    TERM_HEIGHT, TERM_WIDTH = subprocess.check_output(['stty', 'size']).split()
    TERM_HEIGHT, TERM_WIDTH = int(TERM_HEIGHT), int(TERM_WIDTH)
except (subprocess.CalledProcessError, OSError, ValueError):
    # CalledProcessError: stty exited non-zero (e.g. stdin is not a terminal).
    # OSError: the stty executable could not be launched at all.
    # ValueError: the output was not exactly two integers.
    # In every case fall back to a fixed size rather than failing the import.
    TERM_HEIGHT, TERM_WIDTH = 100, 100
|
||||
|
||||
|
||||
def _check_tabulate():
    """Raises ImportError when the optional tabulate import did not succeed.

    The module-level import leaves ``tabulate`` as None if the package is
    missing, so that sentinel is what gets checked here.
    """
    if tabulate is not None:
        return
    raise ImportError(
        "Tabulate not installed. Please run `pip install tabulate`.")
|
||||
|
||||
|
||||
def print_format_output(dataframe):
    """Prints output of given dataframe to fit into terminal.

    Columns are added left to right (so column order is display priority)
    until the rendered table would be wider than TERM_WIDTH. Columns whose
    values are all null are never displayed.

    Args:
        dataframe (pd.DataFrame): Data to display.

    Returns:
        table: The rendered table produced by tabulate (what was printed).
        dropped_cols (list): Columns dropped due to terminal size.
        empty_cols (list): Empty columns (dropped on default).
    """
    print_df = pd.DataFrame()
    dropped_cols = []
    empty_cols = []
    # column display priority is based on the info_keys passed in
    for i, col in enumerate(dataframe):
        if dataframe[col].isnull().all():
            # Don't add col to print_df if is fully empty
            empty_cols.append(col)
            continue

        print_df[col] = dataframe[col]
        # Render exactly like the final table (no index column); otherwise
        # the probe is wider than what is printed and columns get dropped
        # although they would have fit.
        test_table = tabulate(
            print_df, headers="keys", tablefmt="psql", showindex="never")
        # The length of the first rendered line is taken as the table width.
        # Splitting (instead of .index) also copes with one-line output.
        first_line = str(test_table).split("\n", 1)[0]
        if len(first_line) > TERM_WIDTH:
            # Drop all columns beyond terminal width
            print_df.drop(col, axis=1, inplace=True)
            dropped_cols += list(dataframe.columns)[i:]
            break

    table = tabulate(
        print_df, headers="keys", tablefmt="psql", showindex="never")

    print(table)
    if dropped_cols:
        print("Dropped columns:", dropped_cols)
        print("Please increase your terminal size to view remaining columns.")
    if empty_cols:
        print("Empty columns:", empty_cols)

    return table, dropped_cols, empty_cols
|
||||
|
||||
|
||||
def _get_experiment_state(experiment_path, exit_on_fail=False):
    """Loads the newest ``experiment_state*.json`` found in a directory.

    Args:
        experiment_path (str): Directory to search; ``~`` is expanded.
        exit_on_fail (bool): If True, print a message and exit the process
            when no state file exists instead of returning None.

    Returns:
        The decoded JSON content of the state file whose path compares
        greatest, or None when there is no such file and `exit_on_fail`
        is False.
    """
    pattern = os.path.join(
        os.path.expanduser(experiment_path), "experiment_state*.json")
    candidates = glob.glob(pattern)

    if candidates:
        # The greatest path is taken as the most recent checkpoint.
        newest = max(candidates)
        with open(newest) as state_file:
            return json.load(state_file)

    if not exit_on_fail:
        return None

    print("No experiment state found!")
    sys.exit(0)
|
||||
|
||||
|
||||
def list_trials(experiment_path,
                sort=None,
                info_keys=DEFAULT_EXPERIMENT_INFO_KEYS,
                result_keys=DEFAULT_RESULT_KEYS):
    """Lists trials in the directory subtree starting at the given path.

    Exits the process (via _get_experiment_state) when no experiment state
    file is found under `experiment_path`.

    Args:
        experiment_path (str): Directory where trials are located.
            Corresponds to Experiment.local_dir/Experiment.name.
        sort (str): Key to sort by.
        info_keys (list): Keys that are displayed.
        result_keys (list): Keys of last result that are displayed.

    Raises:
        ImportError: If tabulate is not installed.
        KeyError: If `sort` is not one of the displayed columns.
    """
    _check_tabulate()
    experiment_state = _get_experiment_state(
        experiment_path, exit_on_fail=True)

    checkpoint_dicts = experiment_state["checkpoints"]
    checkpoint_dicts = [flatten_dict(g) for g in checkpoint_dicts]
    checkpoints_df = pd.DataFrame(checkpoint_dicts)

    # flatten_dict joins nested keys with ":", so last-result metrics live
    # under "last_result:<key>".
    result_keys = ["last_result:{}".format(k) for k in result_keys]
    col_keys = [
        k for k in list(info_keys) + result_keys if k in checkpoints_df
    ]
    # Work on a copy: columns are reassigned below and must not be written
    # through a slice of the original frame.
    checkpoints_df = checkpoints_df[col_keys].copy()

    if "last_update_time" in checkpoints_df:
        # Treat +/-inf like missing values by replacing them explicitly;
        # this avoids depending on the pandas "use_inf_as_null" option.
        datetime_series = checkpoints_df["last_update_time"].replace(
            [float("inf"), float("-inf")], float("nan")).dropna()

        datetime_series = datetime_series.apply(
            lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT))
        # Rows dropped above have no entry here and end up as missing.
        checkpoints_df["last_update_time"] = datetime_series

    if "logdir" in checkpoints_df:
        # logdir often too verbose to view in table, so drop experiment_path.
        # regex=False: the path is literal text, not a pattern.
        checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
            experiment_path, '', regex=False)

    if sort:
        if sort not in checkpoints_df:
            raise KeyError("Sort Index '{}' not in: {}".format(
                sort, list(checkpoints_df)))
        checkpoints_df = checkpoints_df.sort_values(by=sort)

    print_format_output(checkpoints_df)
|
||||
|
||||
|
||||
def list_experiments(project_path,
                     sort=None,
                     info_keys=DEFAULT_PROJECT_INFO_KEYS):
    """Lists experiments in the directory subtree.

    Every immediate sub-directory of `project_path` that contains an
    experiment state file is summarised as one table row. Prints a message
    and exits the process when nothing can be displayed.

    Args:
        project_path (str): Directory where experiments are located.
            Corresponds to Experiment.local_dir.
        sort (str): Key to sort by.
        info_keys (list): Keys that are displayed.

    Raises:
        ImportError: If tabulate is not installed.
        KeyError: If `sort` is not one of the displayed columns.
    """
    _check_tabulate()
    project_path = os.path.expanduser(project_path)
    # os.walk yields nothing when the directory cannot be listed; fall back
    # to an empty listing so the "No experiments found!" branch handles it
    # instead of next() raising StopIteration.
    base, experiment_folders, _ = next(
        os.walk(project_path), (project_path, [], []))

    experiment_data_collection = []

    for experiment_dir in experiment_folders:
        experiment_state = _get_experiment_state(
            os.path.join(base, experiment_dir))
        if not experiment_state:
            logger.debug("No experiment state found in %s", experiment_dir)
            continue

        # Read statuses straight from the checkpoint dicts so an experiment
        # without any checkpoints yields zero counts rather than a KeyError.
        statuses = [
            checkpoint.get("status")
            for checkpoint in experiment_state["checkpoints"]
        ]
        runner_data = experiment_state["runner_data"]

        # Format time-based values.
        time_values = {
            "start_time": runner_data.get("_start_time"),
            "last_updated": experiment_state.get("timestamp"),
        }

        formatted_time_values = {
            key: datetime.fromtimestamp(val).strftime(TIMESTAMP_FORMAT)
            if val else None
            for key, val in time_values.items()
        }

        experiment_data = {
            "name": experiment_dir,
            "total_trials": len(statuses),
            "running_trials": statuses.count(Trial.RUNNING),
            "terminated_trials": statuses.count(Trial.TERMINATED),
            "error_trials": statuses.count(Trial.ERROR),
        }
        experiment_data.update(formatted_time_values)
        experiment_data_collection.append(experiment_data)

    if not experiment_data_collection:
        print("No experiments found!")
        sys.exit(0)

    info_df = pd.DataFrame(experiment_data_collection)
    col_keys = [k for k in list(info_keys) if k in info_df]

    if not col_keys:
        print("None of keys {} in experiment data!".format(info_keys))
        sys.exit(0)

    info_df = info_df[col_keys]

    if sort:
        if sort not in info_df:
            raise KeyError("Sort Index '{}' not in: {}".format(
                sort, list(info_df)))
        info_df = info_df.sort_values(by=sort)

    print_format_output(info_df)
|
||||
@@ -26,9 +26,9 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
'--epochs',
|
||||
type=int,
|
||||
default=10,
|
||||
default=1,
|
||||
metavar='N',
|
||||
help='number of epochs to train (default: 10)')
|
||||
help='number of epochs to train (default: 1)')
|
||||
parser.add_argument(
|
||||
'--lr',
|
||||
type=float,
|
||||
|
||||
@@ -29,9 +29,9 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
'--epochs',
|
||||
type=int,
|
||||
default=10,
|
||||
default=1,
|
||||
metavar='N',
|
||||
help='number of epochs to train (default: 10)')
|
||||
help='number of epochs to train (default: 1)')
|
||||
parser.add_argument(
|
||||
'--lr',
|
||||
type=float,
|
||||
|
||||
@@ -17,15 +17,8 @@ from ray.tune.result import NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, \
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
|
||||
distutils.version.LooseVersion("1.5.0"))
|
||||
except ImportError:
|
||||
tf = None
|
||||
use_tf150_api = True
|
||||
logger.warning("Couldn't import TensorFlow - "
|
||||
"disabling TensorBoard logging.")
|
||||
tf = None
|
||||
use_tf150_api = True
|
||||
|
||||
|
||||
class Logger(object):
|
||||
@@ -121,6 +114,15 @@ def to_tf_values(result, path):
|
||||
|
||||
class TFLogger(Logger):
|
||||
def _init(self):
|
||||
try:
|
||||
global tf, use_tf150_api
|
||||
import tensorflow
|
||||
tf = tensorflow
|
||||
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
|
||||
distutils.version.LooseVersion("1.5.0"))
|
||||
except ImportError:
|
||||
logger.warning("Couldn't import TensorFlow - "
|
||||
"disabling TensorBoard logging.")
|
||||
self._file_writer = tf.summary.FileWriter(self.logdir)
|
||||
|
||||
def on_result(self, result):
|
||||
|
||||
@@ -18,6 +18,15 @@ NODE_IP = "node_ip"
|
||||
# (Auto-filled) The pid of the training process.
|
||||
PID = "pid"
|
||||
|
||||
# (Optional) Mean reward for current training iteration
|
||||
EPISODE_REWARD_MEAN = "episode_reward_mean"
|
||||
|
||||
# (Optional) Mean loss for training iteration
|
||||
MEAN_LOSS = "mean_loss"
|
||||
|
||||
# (Optional) Mean accuracy for training iteration
|
||||
MEAN_ACCURACY = "mean_accuracy"
|
||||
|
||||
# Number of episodes in this iteration.
|
||||
EPISODES_THIS_ITER = "episodes_this_iter"
|
||||
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import click
|
||||
import ray.tune.commands as commands
|
||||
|
||||
|
||||
# Root click command group; the Tune sub-commands in this module are
# registered on it.
@click.group()
def cli():
    pass
|
||||
|
||||
|
||||
# Thin CLI wrapper: forwards the parsed arguments to the implementation in
# ray.tune.commands.
@cli.command()
@click.argument("experiment_path", required=True, type=str)
@click.option(
    '--sort', default=None, type=str, help='Select which column to sort on.')
def list_trials(experiment_path, sort):
    """Lists trials in the directory subtree starting at the given path."""
    commands.list_trials(experiment_path, sort=sort)
|
||||
|
||||
|
||||
# Thin CLI wrapper: forwards the parsed arguments to the implementation in
# ray.tune.commands.
@cli.command()
@click.argument("project_path", required=True, type=str)
@click.option(
    '--sort', default=None, type=str, help='Select which column to sort on.')
def list_experiments(project_path, sort):
    """Lists experiments in the directory subtree."""
    commands.list_experiments(project_path, sort=sort)
|
||||
|
||||
|
||||
# Register each command under a short alias and a descriptive long name.
cli.add_command(list_trials, name="ls")
cli.add_command(list_trials, name="list-trials")
cli.add_command(list_experiments, name="lsx")
cli.add_command(list_experiments, name="list-experiments")
|
||||
|
||||
|
||||
def main():
    """Invokes the Tune command group and returns its result."""
    return cli()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -4,13 +4,16 @@ from __future__ import print_function
|
||||
|
||||
import copy
|
||||
|
||||
try:
|
||||
import bayes_opt as byo
|
||||
except Exception:
|
||||
byo = None
|
||||
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
|
||||
byo = None
|
||||
|
||||
|
||||
def _import_bayesopt():
    """Imports bayes_opt on demand and publishes it as the module-level byo.

    Raises:
        ImportError: If the bayes_opt package cannot be imported.
    """
    global byo
    # With `global byo` in effect, the import binds the module-level name.
    import bayes_opt as byo
|
||||
|
||||
|
||||
class BayesOptSearch(SuggestionAlgorithm):
|
||||
"""A wrapper around BayesOpt to provide trial suggestions.
|
||||
@@ -56,6 +59,7 @@ class BayesOptSearch(SuggestionAlgorithm):
|
||||
random_state=1,
|
||||
verbose=0,
|
||||
**kwargs):
|
||||
_import_bayesopt()
|
||||
assert byo is not None, (
|
||||
"BayesOpt must be installed!. You can install BayesOpt with"
|
||||
" the command: `pip install bayesian-optimization`.")
|
||||
|
||||
@@ -6,17 +6,19 @@ import numpy as np
|
||||
import copy
|
||||
import logging
|
||||
|
||||
try:
|
||||
hyperopt_logger = logging.getLogger("hyperopt")
|
||||
hyperopt_logger.setLevel(logging.WARNING)
|
||||
import hyperopt as hpo
|
||||
from hyperopt.fmin import generate_trials_to_calculate
|
||||
except Exception:
|
||||
hpo = None
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
|
||||
hpo = None
|
||||
|
||||
|
||||
def _import_hyperopt():
    """Imports hyperopt on demand and publishes it as the module-level hpo.

    The "hyperopt" logger is raised to WARNING before the import runs.

    Raises:
        ImportError: If the hyperopt package cannot be imported.
    """
    global hpo
    logging.getLogger("hyperopt").setLevel(logging.WARNING)
    # With `global hpo` in effect, the import binds the module-level name.
    import hyperopt as hpo
|
||||
|
||||
|
||||
class HyperOptSearch(SuggestionAlgorithm):
|
||||
"""A wrapper around HyperOpt to provide trial suggestions.
|
||||
@@ -73,7 +75,9 @@ class HyperOptSearch(SuggestionAlgorithm):
|
||||
reward_attr="episode_reward_mean",
|
||||
points_to_evaluate=None,
|
||||
**kwargs):
|
||||
_import_hyperopt()
|
||||
assert hpo is not None, "HyperOpt must be installed!"
|
||||
from hyperopt.fmin import generate_trials_to_calculate
|
||||
assert type(max_concurrent) is int and max_concurrent > 0
|
||||
self._max_concurrent = max_concurrent
|
||||
self._reward_attr = reward_attr
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune import commands
|
||||
|
||||
|
||||
@pytest.fixture
def start_ray():
    """Starts Ray before a test and shuts it down after the test finishes."""
    ray.init()
    # NOTE(review): presumably (re-)registers RLlib's trainables such as the
    # "__fake" one used by the tests below - confirm against ray.rllib.
    _register_all()
    yield
    ray.shutdown()
|
||||
|
||||
|
||||
def test_ls(start_ray, capsys, tmpdir):
    """Checks that list_trials prints one TERMINATED row per trial."""
    local_dir = str(tmpdir)
    experiment_name = "test_ls"
    num_samples = 2
    spec = {
        "run": "__fake",
        "stop": {
            "training_iteration": 1
        },
        "num_samples": num_samples,
        "local_dir": local_dir
    }

    # Run the experiment with capturing disabled so that only the output of
    # list_trials ends up in the captured stream.
    with capsys.disabled():
        tune.run_experiments({experiment_name: spec})

    commands.list_trials(
        os.path.join(local_dir, experiment_name), info_keys=("status", ))
    output_lines = capsys.readouterr().out.strip().split("\n")
    terminated = [line for line in output_lines if "TERMINATED" in line]
    assert len(terminated) == num_samples
|
||||
|
||||
|
||||
def test_lsx(start_ray, capsys, tmpdir):
    """Checks that list_experiments prints a row for each experiment run."""
    project_path = str(tmpdir)
    num_experiments = 3
    spec = {
        "run": "__fake",
        "stop": {
            "training_iteration": 1
        },
        "num_samples": 1,
        "local_dir": project_path
    }

    for index in range(num_experiments):
        # Capturing is disabled while the experiment runs so that only the
        # output of list_experiments is inspected below.
        with capsys.disabled():
            tune.run_experiments({"test_lsx{}".format(index): dict(spec)})

    commands.list_experiments(project_path, info_keys=("total_trials", ))
    output_lines = capsys.readouterr().out.strip().split("\n")
    # Each experiment ran a single trial, so its row contains a "1".
    rows_with_one = [line for line in output_lines if "1" in line]
    assert len(rows_with_one) >= 3
|
||||
@@ -18,6 +18,7 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.trial_executor = RayTrialExecutor(queue_trials=False)
|
||||
ray.init()
|
||||
_register_all() # Needed for flaky tests
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
@@ -583,7 +583,7 @@ class _MockTrial(Trial):
|
||||
self.logger_running = False
|
||||
self.restored_checkpoint = None
|
||||
self.resources = Resources(1, 0)
|
||||
self.trial_name = None
|
||||
self.custom_trial_name = None
|
||||
|
||||
|
||||
class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
|
||||
+15
-14
@@ -25,7 +25,8 @@ from ray.tune.logger import pretty_print, UnifiedLogger
|
||||
# have been defined yet. See https://github.com/ray-project/ray/issues/1716.
|
||||
import ray.tune.registry
|
||||
from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID,
|
||||
TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL)
|
||||
TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL,
|
||||
EPISODE_REWARD_MEAN, MEAN_LOSS, MEAN_ACCURACY)
|
||||
from ray.utils import _random_string, binary_to_hex, hex_to_binary
|
||||
|
||||
DEBUG_PRINT_INTERVAL = 5
|
||||
@@ -299,6 +300,8 @@ class Trial(object):
|
||||
self.error_file = None
|
||||
self.num_failures = 0
|
||||
|
||||
self.custom_trial_name = None
|
||||
|
||||
# AutoML fields
|
||||
self.results = None
|
||||
self.best_result = None
|
||||
@@ -316,10 +319,8 @@ class Trial(object):
|
||||
"param_config",
|
||||
"extra_arg",
|
||||
]
|
||||
|
||||
self.trial_name = None
|
||||
if trial_name_creator:
|
||||
self.trial_name = trial_name_creator(self)
|
||||
self.custom_trial_name = trial_name_creator(self)
|
||||
|
||||
@classmethod
|
||||
def _registration_check(cls, trainable_name):
|
||||
@@ -447,17 +448,17 @@ class Trial(object):
|
||||
if self.last_result.get(TIMESTEPS_TOTAL) is not None:
|
||||
pieces.append('{} ts'.format(self.last_result[TIMESTEPS_TOTAL]))
|
||||
|
||||
if self.last_result.get("episode_reward_mean") is not None:
|
||||
if self.last_result.get(EPISODE_REWARD_MEAN) is not None:
|
||||
pieces.append('{} rew'.format(
|
||||
format(self.last_result["episode_reward_mean"], '.3g')))
|
||||
format(self.last_result[EPISODE_REWARD_MEAN], '.3g')))
|
||||
|
||||
if self.last_result.get("mean_loss") is not None:
|
||||
if self.last_result.get(MEAN_LOSS) is not None:
|
||||
pieces.append('{} loss'.format(
|
||||
format(self.last_result["mean_loss"], '.3g')))
|
||||
format(self.last_result[MEAN_LOSS], '.3g')))
|
||||
|
||||
if self.last_result.get("mean_accuracy") is not None:
|
||||
if self.last_result.get(MEAN_ACCURACY) is not None:
|
||||
pieces.append('{} acc'.format(
|
||||
format(self.last_result["mean_accuracy"], '.3g')))
|
||||
format(self.last_result[MEAN_ACCURACY], '.3g')))
|
||||
|
||||
return ', '.join(pieces)
|
||||
|
||||
@@ -514,8 +515,8 @@ class Trial(object):
|
||||
|
||||
Can be overriden with a custom string creator.
|
||||
"""
|
||||
if self.trial_name:
|
||||
return self.trial_name
|
||||
if self.custom_trial_name:
|
||||
return self.custom_trial_name
|
||||
|
||||
if "env" in self.config:
|
||||
env = self.config["env"]
|
||||
@@ -544,8 +545,6 @@ class Trial(object):
|
||||
|
||||
state["runner"] = None
|
||||
state["result_logger"] = None
|
||||
if self.status == Trial.RUNNING:
|
||||
state["status"] = Trial.PENDING
|
||||
if self.result_logger:
|
||||
self.result_logger.flush()
|
||||
state["__logger_started__"] = True
|
||||
@@ -556,6 +555,8 @@ class Trial(object):
|
||||
def __setstate__(self, state):
|
||||
logger_started = state.pop("__logger_started__")
|
||||
state["resources"] = json_to_resources(state["resources"])
|
||||
if state["status"] == Trial.RUNNING:
|
||||
state["status"] = Trial.PENDING
|
||||
for key in self._nonjson_fields:
|
||||
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
|
||||
|
||||
|
||||
@@ -112,7 +112,9 @@ class TrialRunner(object):
|
||||
self._stop_queue = []
|
||||
self._metadata_checkpoint_dir = metadata_checkpoint_dir
|
||||
|
||||
self._session = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
self._start_time = time.time()
|
||||
self._session_str = datetime.fromtimestamp(
|
||||
self._start_time).strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
@classmethod
|
||||
def checkpoint_exists(cls, directory):
|
||||
@@ -136,7 +138,8 @@ class TrialRunner(object):
|
||||
runner_state = {
|
||||
"checkpoints": list(
|
||||
self.trial_executor.get_checkpoints().values()),
|
||||
"runner_data": self.__getstate__()
|
||||
"runner_data": self.__getstate__(),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
tmp_file_name = os.path.join(metadata_checkpoint_dir,
|
||||
".tmp_checkpoint")
|
||||
@@ -146,7 +149,7 @@ class TrialRunner(object):
|
||||
os.rename(
|
||||
tmp_file_name,
|
||||
os.path.join(metadata_checkpoint_dir,
|
||||
TrialRunner.CKPT_FILE_TMPL.format(self._session)))
|
||||
TrialRunner.CKPT_FILE_TMPL.format(self._session_str)))
|
||||
return metadata_checkpoint_dir
|
||||
|
||||
@classmethod
|
||||
@@ -558,8 +561,12 @@ class TrialRunner(object):
|
||||
"""
|
||||
state = self.__dict__.copy()
|
||||
for k in [
|
||||
"_trials", "_stop_queue", "_server", "_search_alg",
|
||||
"_scheduler_alg", "trial_executor", "_session"
|
||||
"_trials",
|
||||
"_stop_queue",
|
||||
"_server",
|
||||
"_search_alg",
|
||||
"_scheduler_alg",
|
||||
"trial_executor",
|
||||
]:
|
||||
del state[k]
|
||||
state["launch_web_server"] = bool(self._server)
|
||||
@@ -567,6 +574,14 @@ class TrialRunner(object):
|
||||
|
||||
def __setstate__(self, state):
|
||||
launch_web_server = state.pop("launch_web_server")
|
||||
|
||||
# Use session_str from previous checkpoint if does not exist
|
||||
session_str = state.pop("_session_str")
|
||||
self.__dict__.setdefault("_session_str", session_str)
|
||||
# Use start_time from previous checkpoint if does not exist
|
||||
start_time = state.pop("_start_time")
|
||||
self.__dict__.setdefault("_start_time", start_time)
|
||||
|
||||
self.__dict__.update(state)
|
||||
if launch_web_server:
|
||||
self._server = TuneServer(self, self._server_port)
|
||||
|
||||
+16
-1
@@ -61,7 +61,7 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist):
|
||||
if k not in original:
|
||||
if not new_keys_allowed:
|
||||
raise Exception("Unknown config parameter `{}` ".format(k))
|
||||
if type(original.get(k)) is dict:
|
||||
if isinstance(original.get(k), dict):
|
||||
if k in whitelist:
|
||||
deep_update(original[k], value, True, [])
|
||||
else:
|
||||
@@ -71,6 +71,21 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist):
|
||||
return original
|
||||
|
||||
|
||||
def flatten_dict(dt):
    """Flattens nested dicts in place, joining key paths with ":".

    Each pass lifts the children of every dict-valued entry one level up
    under the key "<parent>:<child>" and removes the parent entry; passes
    repeat until no value is a dict. An empty nested dict therefore simply
    disappears.

    Args:
        dt (dict): Dictionary to flatten. It is modified in place.

    Returns:
        The same `dt` object, now without dict values.
    """
    while True:
        nested = [(k, v) for k, v in dt.items() if isinstance(v, dict)]
        if not nested:
            break
        # Build all lifted entries first, then apply them, then remove the
        # parents - the dict is never resized while being iterated.
        lifted = {
            ":".join([parent, child]): leaf
            for parent, children in nested
            for child, leaf in children.items()
        }
        dt.update(lifted)
        for parent, _ in nested:
            del dt[parent]
    return dt
|
||||
|
||||
|
||||
def _to_pinnable(obj):
|
||||
"""Converts obj to a form that can be pinned in object store memory.
|
||||
|
||||
|
||||
@@ -10,24 +10,11 @@ import os.path as osp
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
from ray.tune.util import flatten_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _flatten_dict(dt):
|
||||
while any(type(v) is dict for v in dt.values()):
|
||||
remove = []
|
||||
add = {}
|
||||
for key, value in dt.items():
|
||||
if type(value) is dict:
|
||||
for subkey, v in value.items():
|
||||
add[":".join([key, subkey])] = v
|
||||
remove.append(key)
|
||||
dt.update(add)
|
||||
for k in remove:
|
||||
del dt[k]
|
||||
return dt
|
||||
|
||||
|
||||
def _parse_results(res_path):
|
||||
res_dict = {}
|
||||
try:
|
||||
@@ -35,7 +22,7 @@ def _parse_results(res_path):
|
||||
# Get last line in file
|
||||
for line in f:
|
||||
pass
|
||||
res_dict = _flatten_dict(json.loads(line.strip()))
|
||||
res_dict = flatten_dict(json.loads(line.strip()))
|
||||
except Exception:
|
||||
logger.exception("Importing %s failed...Perhaps empty?" % res_path)
|
||||
return res_dict
|
||||
@@ -44,7 +31,7 @@ def _parse_results(res_path):
|
||||
def _parse_configs(cfg_path):
|
||||
try:
|
||||
with open(cfg_path) as f:
|
||||
cfg_dict = _flatten_dict(json.load(f))
|
||||
cfg_dict = flatten_dict(json.load(f))
|
||||
except Exception:
|
||||
logger.exception("Config parsing failed.")
|
||||
return cfg_dict
|
||||
|
||||
+1
-1
@@ -175,7 +175,7 @@ setup(
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"ray=ray.scripts.scripts:main",
|
||||
"rllib=ray.rllib.scripts:cli [rllib]"
|
||||
"rllib=ray.rllib.scripts:cli [rllib]", "tune=ray.tune.scripts:cli"
|
||||
]
|
||||
},
|
||||
include_package_data=True,
|
||||
|
||||
Reference in New Issue
Block a user