[tune] Support Configuration Merging (#3584)

* merge configs

* deep merge

* lint

* add resolve

* test
This commit is contained in:
Richard Liaw
2018-12-26 03:07:11 -08:00
committed by Eric Liang
parent 4ce3818be5
commit 6e2d7a9ba1
6 changed files with 94 additions and 46 deletions
+5 -38
View File
@@ -1,43 +1,10 @@
import copy
from ray.rllib.utils.filter_manager import FilterManager
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.policy_client import PolicyClient
from ray.rllib.utils.policy_server import PolicyServer
from ray.tune.util import merge_dicts, deep_update
__all__ = ["Filter", "FilterManager", "PolicyClient", "PolicyServer"]
def merge_dicts(d1, d2):
"""Returns a new dict that is d1 and d2 deep merged."""
merged = copy.deepcopy(d1)
deep_update(merged, d2, True, [])
return merged
def deep_update(original, new_dict, new_keys_allowed, whitelist):
"""Updates original dict with values from new_dict recursively.
If new key is introduced in new_dict, then if new_keys_allowed is not
True, an error will be thrown. Further, for sub-dicts, if the key is
in the whitelist, then new subkeys can be introduced.
Args:
original (dict): Dictionary with default values.
new_dict (dict): Dictionary with values to be updated
new_keys_allowed (bool): Whether new keys are allowed.
whitelist (list): List of keys that correspond to dict values
where new subkeys can be introduced. This is only at
the top level.
"""
for k, value in new_dict.items():
if k not in original:
if not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))
if type(original.get(k)) is dict:
if k in whitelist:
deep_update(original[k], value, True, [])
else:
deep_update(original[k], value, new_keys_allowed, [])
else:
original[k] = value
return original
__all__ = [
"Filter", "FilterManager", "PolicyClient", "PolicyServer", "merge_dicts",
"deep_update"
]
+4 -1
View File
@@ -17,7 +17,7 @@ def easy_objective(config, reporter):
time.sleep(0.2)
assert type(config["activation"]) == str, \
"Config is incorrect: {}".format(type(config["activation"]))
for i in range(100):
for i in range(config["iterations"]):
reporter(
timesteps_total=i,
neg_mean_loss=-(config["height"] - 14)**2 +
@@ -47,6 +47,9 @@ if __name__ == '__main__':
"my_exp": {
"run": "exp",
"num_samples": 10 if args.smoke_test else 1000,
"config": {
"iterations": 100,
},
"stop": {
"timesteps_total": 100
},
+5 -6
View File
@@ -7,10 +7,11 @@ import copy
from ray.tune.error import TuneError
from ray.tune.trial import Trial
from ray.tune.util import merge_dicts
from ray.tune.experiment import convert_to_experiment_list
from ray.tune.config_parser import make_parser, create_trial_from_spec
from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.variant_generator import format_vars
from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict
class SuggestionAlgorithm(SearchAlgorithm):
@@ -33,9 +34,6 @@ class SuggestionAlgorithm(SearchAlgorithm):
def __init__(self):
"""Constructs a generator given experiment specifications.
Arguments:
experiments (Experiment | list | dict): Experiments to run.
"""
self._parser = make_parser()
self._trial_generator = []
@@ -91,10 +89,11 @@ class SuggestionAlgorithm(SearchAlgorithm):
else:
break
spec = copy.deepcopy(experiment_spec)
spec["config"] = suggested_config
spec["config"] = merge_dicts(spec["config"], suggested_config)
flattened_config = resolve_nested_dict(spec["config"])
self._counter += 1
tag = "{0}_{1}".format(
str(self._counter), format_vars(spec["config"]))
str(self._counter), format_vars(flattened_config))
yield create_trial_from_spec(
spec,
output_path,
@@ -93,6 +93,21 @@ _STANDARD_IMPORTS = {
_MAX_RESOLUTION_PASSES = 20
def resolve_nested_dict(nested_dict):
"""Flattens a nested dict by joining keys into tuple of paths.
Can then be passed into `format_vars`.
"""
res = {}
for k, v in nested_dict.items():
if isinstance(v, dict):
for k_, v_ in resolve_nested_dict(v).items():
res[(k, ) + k_] = v_
else:
res[(k, )] = v
return res
def format_vars(resolved_vars):
out = []
for path, value in sorted(resolved_vars.items()):
+29 -1
View File
@@ -26,7 +26,8 @@ from ray.tune.trial_runner import TrialRunner
from ray.tune.suggest import grid_search, BasicVariantGenerator
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
SuggestionAlgorithm)
from ray.tune.suggest.variant_generator import RecursiveDependencyError
from ray.tune.suggest.variant_generator import (RecursiveDependencyError,
resolve_nested_dict)
if sys.version_info >= (3, 3):
from unittest.mock import patch
@@ -902,6 +903,20 @@ class VariantGeneratorTest(unittest.TestCase):
self.assertEqual(trials[0].config, {"x": 100, "y": 1})
self.assertEqual(trials[1].config, {"x": 200, "y": 1})
def test_resolve_dict(self):
config = {
"a": {
"b": 1,
"c": 2,
},
"b": {
"a": 3
}
}
resolved = resolve_nested_dict(config)
for k, v in [(("a", "b"), 1), (("a", "c"), 2), (("b", "a"), 3)]:
self.assertEqual(resolved.get(k), v)
def testRecursiveDep(self):
try:
list(
@@ -1651,5 +1666,18 @@ class TrialRunnerTest(unittest.TestCase):
self.assertRaises(TuneError, runner.step)
class SearchAlgorithmTest(unittest.TestCase):
def testNestedSuggestion(self):
class TestSuggestion(SuggestionAlgorithm):
def _suggest(self, trial_id):
return {"a": {"b": {"c": {"d": 4, "e": 5}}}}
alg = TestSuggestion()
alg.add_configurations({"test": {"run": "__fake"}})
trial = alg.next_trials()[0]
self.assertTrue("e=5" in trial.experiment_tag)
self.assertTrue("d=4" in trial.experiment_tag)
if __name__ == "__main__":
unittest.main(verbosity=2)
+36
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import base64
import copy
import numpy as np
import ray
@@ -35,6 +36,41 @@ def get_pinned_object(pinned_id):
ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
def merge_dicts(d1, d2):
"""Returns a new dict that is d1 and d2 deep merged."""
merged = copy.deepcopy(d1)
deep_update(merged, d2, True, [])
return merged
def deep_update(original, new_dict, new_keys_allowed, whitelist):
"""Updates original dict with values from new_dict recursively.
If new key is introduced in new_dict, then if new_keys_allowed is not
True, an error will be thrown. Further, for sub-dicts, if the key is
in the whitelist, then new subkeys can be introduced.
Args:
original (dict): Dictionary with default values.
new_dict (dict): Dictionary with values to be updated
new_keys_allowed (bool): Whether new keys are allowed.
whitelist (list): List of keys that correspond to dict values
where new subkeys can be introduced. This is only at
the top level.
"""
for k, value in new_dict.items():
if k not in original:
if not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))
if type(original.get(k)) is dict:
if k in whitelist:
deep_update(original[k], value, True, [])
else:
deep_update(original[k], value, new_keys_allowed, [])
else:
original[k] = value
return original
def _to_pinnable(obj):
"""Converts obj to a form that can be pinned in object store memory.