[tune] Allow nested values in trial runner (#5346)

This commit is contained in:
Richard Liaw
2019-08-06 14:36:17 -07:00
committed by GitHub
parent e8d9cfc1f1
commit 094ec7adbc
10 changed files with 163 additions and 106 deletions
@@ -2,7 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
import logging
import os
@@ -13,30 +12,10 @@ except ImportError:
pd = None
from ray.tune.error import TuneError
from ray.tune.result import EXPR_PROGRESS_FILE, EXPR_PARAM_FILE
from ray.tune.util import flatten_dict
from ray.tune.result import EXPR_PROGRESS_FILE, EXPR_PARAM_FILE, CONFIG_PREFIX
logger = logging.getLogger(__name__)
UNNEST_KEYS = ("config", "last_result")
def unnest_checkpoints(checkpoints):
checkpoint_dicts = []
for g in checkpoints:
checkpoint = copy.deepcopy(g)
for key in UNNEST_KEYS:
if key not in checkpoint:
continue
try:
unnest_dict = flatten_dict(checkpoint.pop(key))
checkpoint.update(unnest_dict)
except Exception:
logger.debug("Failed to flatten dict.")
checkpoint = flatten_dict(checkpoint)
checkpoint_dicts.append(checkpoint)
return checkpoint_dicts
class Analysis(object):
"""Analyze all results from a directory of experiments."""
@@ -130,7 +109,7 @@ class Analysis(object):
config = json.load(f)
if prefix:
for k in list(config):
config["config/" + k] = config.pop(k)
config[CONFIG_PREFIX + k] = config.pop(k)
self._configs[path] = config
except Exception:
fail_count += 1
+39 -47
View File
@@ -2,10 +2,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import click
import logging
import glob
import os
import sys
import subprocess
import operator
from datetime import datetime
@@ -13,7 +12,7 @@ from datetime import datetime
import pandas as pd
from pandas.api.types import is_string_dtype, is_numeric_dtype
from ray.tune.result import (TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS,
TIME_TOTAL_S, TRIAL_ID)
TIME_TOTAL_S, TRIAL_ID, CONFIG_PREFIX)
from ray.tune.analysis import Analysis
from ray.tune import TuneError
try:
@@ -34,9 +33,6 @@ DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", "experiment_tag",
DEFAULT_PROJECT_INFO_KEYS = (
"name",
"total_trials",
"running_trials",
"terminated_trials",
"error_trials",
"last_updated",
)
@@ -63,20 +59,6 @@ def _check_tabulate():
"Tabulate not installed. Please run `pip install tabulate`.")
def get_most_recent_state(experiment_path):
experiment_path = os.path.expanduser(experiment_path)
if not os.path.isdir(experiment_path):
raise TuneError("{} is not a valid directory.".format(experiment_path))
experiment_state_paths = glob.glob(
os.path.join(experiment_path, "experiment_state*.json"))
if not experiment_state_paths:
raise TuneError(
"No experiment state found in {}!".format(experiment_path))
experiment_filename = max(
list(experiment_state_paths)) # if more than one, pick latest
return experiment_filename
def print_format_output(dataframe):
"""Prints output of given dataframe to fit into terminal.
@@ -108,10 +90,11 @@ def print_format_output(dataframe):
print(table)
if dropped_cols:
print("Dropped columns:", dropped_cols)
print("Please increase your terminal size to view remaining columns.")
click.secho("Dropped columns: {}".format(dropped_cols), fg="yellow")
click.secho("Please increase your terminal size "
"to view remaining columns.")
if empty_cols:
print("Empty columns:", empty_cols)
click.secho("Empty columns: {}".format(empty_cols), fg="yellow")
return table, dropped_cols, empty_cols
@@ -141,17 +124,25 @@ def list_trials(experiment_path,
try:
checkpoints_df = Analysis(experiment_path).dataframe()
except TuneError:
print("No experiment state found!")
sys.exit(1)
raise click.ClickException("No trial data found!")
def key_filter(k):
return k in DEFAULT_EXPERIMENT_INFO_KEYS or k.startswith(CONFIG_PREFIX)
col_keys = [k for k in checkpoints_df.columns if key_filter(k)]
if info_keys:
for k in info_keys:
if k not in checkpoints_df.columns:
raise click.ClickException("Provided key invalid: {}. "
"Available keys: {}.".format(
k, checkpoints_df.columns))
col_keys = [k for k in checkpoints_df.columns if k in info_keys]
if not col_keys:
raise click.ClickException("No columns to output.")
if not info_keys:
info_keys = DEFAULT_EXPERIMENT_INFO_KEYS
col_keys = [
k for k in checkpoints_df.columns
if k in info_keys or k.startswith("config/")
]
checkpoints_df = checkpoints_df[col_keys]
if "last_update_time" in checkpoints_df:
with pd.option_context("mode.use_inf_as_null", True):
datetime_series = checkpoints_df["last_update_time"].dropna()
@@ -174,7 +165,7 @@ def list_trials(experiment_path,
val = str(val)
# TODO(Andrew): add support for datetime and boolean
else:
raise ValueError("Unsupported dtype for {}: {}".format(
raise click.ClickException("Unsupported dtype for {}: {}".format(
val, col_type))
op = OPERATORS[op]
filtered_index = op(checkpoints_df[col], val)
@@ -183,8 +174,8 @@ def list_trials(experiment_path,
if sort:
for key in sort:
if key not in checkpoints_df:
raise KeyError("{} not in: {}".format(key,
list(checkpoints_df)))
raise click.ClickException("{} not in: {}".format(
key, list(checkpoints_df)))
ascending = not desc
checkpoints_df = checkpoints_df.sort_values(
by=sort, ascending=ascending)
@@ -201,8 +192,9 @@ def list_trials(experiment_path,
elif file_extension == ".csv":
checkpoints_df.to_csv(output, index=False)
else:
raise ValueError("Unsupported filetype: {}".format(output))
print("Output saved at:", output)
raise click.ClickException(
"Unsupported filetype: {}".format(output))
click.secho("Output saved at {}".format(output), fg="green")
def list_experiments(project_path,
@@ -239,16 +231,15 @@ def list_experiments(project_path,
experiment_data_collection.append(experiment_data)
if not experiment_data_collection:
print("No experiments found!")
sys.exit(0)
raise click.ClickException("No experiments found!")
info_df = pd.DataFrame(experiment_data_collection)
if not info_keys:
info_keys = DEFAULT_PROJECT_INFO_KEYS
col_keys = [k for k in list(info_keys) if k in info_df]
if not col_keys:
print("None of keys {} in experiment data!".format(info_keys))
sys.exit(0)
raise click.ClickException(
"None of keys {} in experiment data!".format(info_keys))
info_df = info_df[col_keys]
if filter_op:
@@ -260,7 +251,7 @@ def list_experiments(project_path,
val = str(val)
# TODO(Andrew): add support for datetime and boolean
else:
raise ValueError("Unsupported dtype for {}: {}".format(
raise click.ClickException("Unsupported dtype for {}: {}".format(
val, col_type))
op = OPERATORS[op]
filtered_index = op(info_df[col], val)
@@ -269,7 +260,8 @@ def list_experiments(project_path,
if sort:
for key in sort:
if key not in info_df:
raise KeyError("{} not in: {}".format(key, list(info_df)))
raise click.ClickException("{} not in: {}".format(
key, list(info_df)))
ascending = not desc
info_df = info_df.sort_values(by=sort, ascending=ascending)
@@ -285,8 +277,9 @@ def list_experiments(project_path,
elif file_extension == ".csv":
info_df.to_csv(output, index=False)
else:
raise ValueError("Unsupported filetype: {}".format(output))
print("Output saved at:", output)
raise click.ClickException(
"Unsupported filetype: {}".format(output))
click.secho("Output saved at {}".format(output), fg="green")
def add_note(path, filename="note.txt"):
@@ -305,8 +298,7 @@ def add_note(path, filename="note.txt"):
try:
subprocess.call([EDITOR, filepath])
except Exception as exc:
logger.error("Editing note failed!")
raise exc
click.secho("Editing note failed: {}".format(str(exc)), fg="red")
if exists:
print("Note updated at:", filepath)
else:
+1 -1
View File
@@ -61,7 +61,7 @@ class RayTrialExecutor(TrialExecutor):
logger.info("Initializing Ray automatically."
"For cluster usage or custom Ray initialization, "
"call `ray.init(...)` before `tune.run`.")
ray.init()
ray.init(object_store_memory=int(1e8))
if ray.is_initialized():
self._update_avail_resources()
+3
View File
@@ -84,3 +84,6 @@ EXPR_PROGRESS_FILE = "progress.csv"
# File that stores results of the trial.
EXPR_RESULT_FILE = "result.json"
# Config prefix when using Analysis.
CONFIG_PREFIX = "config/"
+2
View File
@@ -134,6 +134,7 @@ class _MockSuggestionAlgorithm(SuggestionAlgorithm):
self.live_trials = {}
self.counter = {"result": 0, "complete": 0}
self.stall = False
self.results = []
super(_MockSuggestionAlgorithm, self).__init__(**kwargs)
def _suggest(self, trial_id):
@@ -144,6 +145,7 @@ class _MockSuggestionAlgorithm(SuggestionAlgorithm):
def on_trial_result(self, trial_id, result):
self.counter["result"] += 1
self.results += [result]
def on_trial_complete(self, trial_id, **kwargs):
self.counter["complete"] += 1
+34 -1
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import click
import os
import pytest
import subprocess
@@ -16,6 +17,7 @@ import ray
from ray import tune
from ray.rllib import _register_all
from ray.tune import commands
from ray.tune.result import CONFIG_PREFIX
class Capturing():
@@ -33,7 +35,7 @@ class Capturing():
@pytest.fixture
def start_ray():
ray.init(log_to_driver=False)
ray.init(log_to_driver=False, local_mode=True)
_register_all()
yield
ray.shutdown()
@@ -80,6 +82,7 @@ def test_ls(start_ray, tmpdir):
with Capturing() as output:
commands.list_trials(experiment_path, info_keys=columns, limit=limit)
lines = output.captured
assert all(col in lines[1] for col in columns)
assert lines[1].count("|") == len(columns) + 1
assert len(lines) == 3 + limit + 1
@@ -93,6 +96,36 @@ def test_ls(start_ray, tmpdir):
lines = output.captured
assert len(lines) == 3 + num_samples + 1
with pytest.raises(click.ClickException):
commands.list_trials(
experiment_path,
sort=["trial_id"],
info_keys=("training_iteration", ))
with pytest.raises(click.ClickException):
commands.list_trials(experiment_path, info_keys=("asdf", ))
def test_ls_with_cfg(start_ray, tmpdir):
experiment_name = "test_ls_with_cfg"
experiment_path = os.path.join(str(tmpdir), experiment_name)
tune.run(
"__fake",
name=experiment_name,
stop={"training_iteration": 1},
config={"test_variable": tune.grid_search(list(range(5)))},
local_dir=str(tmpdir),
global_checkpoint_period=0)
columns = [CONFIG_PREFIX + "test_variable", "trial_id"]
limit = 4
with Capturing() as output:
commands.list_trials(experiment_path, info_keys=columns, limit=limit)
lines = output.captured
assert all(col in lines[1] for col in columns)
assert lines[1].count("|") == len(columns) + 1
assert len(lines) == 3 + limit + 1
def test_lsx(start_ray, tmpdir):
"""This test captures output of list_experiments."""
@@ -48,6 +48,10 @@ class ExperimentAnalysisSuite(unittest.TestCase):
self.assertTrue(isinstance(df, pd.DataFrame))
self.assertEquals(df.shape[0], self.num_samples)
def testStats(self):
assert self.ea.stats()
assert self.ea.runner_data()
def testTrialDataframe(self):
checkpoints = self.ea._checkpoints
idx = random.randint(0, len(checkpoints) - 1)
+58 -12
View File
@@ -25,7 +25,7 @@ from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
TRAINING_ITERATION, TIMESTEPS_THIS_ITER,
TIME_THIS_ITER_S, TIME_TOTAL_S, TRIAL_ID)
from ray.tune.logger import Logger
from ray.tune.util import pin_in_object_store, get_pinned_object
from ray.tune.util import pin_in_object_store, get_pinned_object, flatten_dict
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial, ExportFormat
from ray.tune.trial_runner import TrialRunner
@@ -44,7 +44,7 @@ else:
class TrainableFunctionApiTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=4, num_gpus=0)
ray.init(num_cpus=4, num_gpus=0, object_store_memory=int(1e8))
def tearDown(self):
ray.shutdown()
@@ -433,14 +433,16 @@ class TrainableFunctionApiTest(unittest.TestCase):
for i in range(10):
reporter(test={"test1": {"test2": i}})
[trial] = tune.run(
train, stop={
"test": {
"test1": {
"test2": 6
with self.assertRaises(TuneError):
[trial] = tune.run(
train, stop={
"test": {
"test1": {
"test2": 6
}
}
}
}).trials
}).trials
[trial] = tune.run(train, stop={"test/test1/test2": 6}).trials
self.assertEqual(trial.last_result["training_iteration"], 7)
def testEarlyReturn(self):
@@ -514,6 +516,53 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result["mean_accuracy"], float("inf"))
def testNestedResults(self):
def create_result(i):
return {"test": {"1": {"2": {"3": i, "4": False}}}}
flattened_keys = list(flatten_dict(create_result(0)))
class _MockScheduler(FIFOScheduler):
results = []
def on_trial_result(self, trial_runner, trial, result):
self.results += [result]
return TrialScheduler.CONTINUE
def on_trial_complete(self, trial_runner, trial, result):
self.complete_result = result
def train(config, reporter):
for i in range(100):
reporter(**create_result(i))
algo = _MockSuggestionAlgorithm()
scheduler = _MockScheduler()
[trial] = tune.run(
train,
scheduler=scheduler,
search_alg=algo,
stop={
"test/1/2/3": 20
}).trials
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result["test"]["1"]["2"]["3"], 20)
self.assertEqual(trial.last_result["test"]["1"]["2"]["4"], False)
self.assertEqual(trial.last_result[TRAINING_ITERATION], 21)
self.assertEqual(len(scheduler.results), 20)
self.assertTrue(
all(
set(result) >= set(flattened_keys)
for result in scheduler.results))
self.assertTrue(set(scheduler.complete_result) >= set(flattened_keys))
self.assertEqual(len(algo.results), 20)
self.assertTrue(
all(set(result) >= set(flattened_keys) for result in algo.results))
with self.assertRaises(TuneError):
[trial] = tune.run(train, stop={"1/2/3": 20})
with self.assertRaises(TuneError):
[trial] = tune.run(train, stop={"test": 1}).trials
def testReportTimeStep(self):
# Test that no timestep count are logged if never the Trainable never
# returns any.
@@ -704,9 +753,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
class RunExperimentTest(unittest.TestCase):
def setUp(self):
ray.init()
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
+12 -16
View File
@@ -37,21 +37,6 @@ def has_trainable(trainable_name):
ray.tune.registry.TRAINABLE_CLASS, trainable_name)
def recursive_criteria_check(result, criteria):
for criteria, stop_value in criteria.items():
if criteria not in result:
raise TuneError(
"Stopping criteria {} not provided in result {}.".format(
criteria, result))
elif isinstance(result[criteria], dict) and isinstance(
stop_value, dict):
if recursive_criteria_check(result[criteria], stop_value):
return True
elif result[criteria] >= stop_value:
return True
return False
class Checkpoint(object):
"""Describes a checkpoint of trial state.
@@ -292,7 +277,18 @@ class Trial(object):
if result.get(DONE):
return True
return recursive_criteria_check(result, self.stopping_criterion)
for criteria, stop_value in self.stopping_criterion.items():
if criteria not in result:
raise TuneError(
"Stopping criteria {} not provided in result {}.".format(
criteria, result))
elif isinstance(criteria, dict):
raise ValueError(
"Stopping criteria is now flattened by default. "
"Use forward slashes to nest values `key1/key2/key3`.")
elif result[criteria] >= stop_value:
return True
return False
def should_checkpoint(self):
"""Whether this trial is due for checkpointing."""
+8 -6
View File
@@ -22,7 +22,7 @@ from ray.tune.trial import Trial, Checkpoint
from ray.tune.sample import function
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.util import warn_if_slow
from ray.tune.util import warn_if_slow, flatten_dict
from ray.utils import binary_to_hex, hex_to_binary
from ray.tune.web_server import TuneServer
@@ -508,18 +508,20 @@ class TrialRunner(object):
self._total_time += result[TIME_THIS_ITER_S]
if trial.should_stop(result):
flat_result = flatten_dict(result)
if trial.should_stop(flat_result):
# Hook into scheduler
self._scheduler_alg.on_trial_complete(self, trial, result)
self._scheduler_alg.on_trial_complete(self, trial, flat_result)
self._search_alg.on_trial_complete(
trial.trial_id, result=result)
trial.trial_id, result=flat_result)
decision = TrialScheduler.STOP
else:
with warn_if_slow("scheduler.on_trial_result"):
decision = self._scheduler_alg.on_trial_result(
self, trial, result)
self, trial, flat_result)
with warn_if_slow("search_alg.on_trial_result"):
self._search_alg.on_trial_result(trial.trial_id, result)
self._search_alg.on_trial_result(trial.trial_id,
flat_result)
if decision == TrialScheduler.STOP:
with warn_if_slow("search_alg.on_trial_complete"):
self._search_alg.on_trial_complete(