mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 00:39:32 +08:00
[tune] Support Configuration Merging (#3584)
* merge configs * deep merge * lint * add resolve * test
This commit is contained in:
@@ -1,43 +1,10 @@
|
||||
import copy
|
||||
|
||||
from ray.rllib.utils.filter_manager import FilterManager
|
||||
from ray.rllib.utils.filter import Filter
|
||||
from ray.rllib.utils.policy_client import PolicyClient
|
||||
from ray.rllib.utils.policy_server import PolicyServer
|
||||
from ray.tune.util import merge_dicts, deep_update
|
||||
|
||||
__all__ = ["Filter", "FilterManager", "PolicyClient", "PolicyServer"]
|
||||
|
||||
|
||||
def merge_dicts(d1, d2):
|
||||
"""Returns a new dict that is d1 and d2 deep merged."""
|
||||
merged = copy.deepcopy(d1)
|
||||
deep_update(merged, d2, True, [])
|
||||
return merged
|
||||
|
||||
|
||||
def deep_update(original, new_dict, new_keys_allowed, whitelist):
|
||||
"""Updates original dict with values from new_dict recursively.
|
||||
If new key is introduced in new_dict, then if new_keys_allowed is not
|
||||
True, an error will be thrown. Further, for sub-dicts, if the key is
|
||||
in the whitelist, then new subkeys can be introduced.
|
||||
|
||||
Args:
|
||||
original (dict): Dictionary with default values.
|
||||
new_dict (dict): Dictionary with values to be updated
|
||||
new_keys_allowed (bool): Whether new keys are allowed.
|
||||
whitelist (list): List of keys that correspond to dict values
|
||||
where new subkeys can be introduced. This is only at
|
||||
the top level.
|
||||
"""
|
||||
for k, value in new_dict.items():
|
||||
if k not in original:
|
||||
if not new_keys_allowed:
|
||||
raise Exception("Unknown config parameter `{}` ".format(k))
|
||||
if type(original.get(k)) is dict:
|
||||
if k in whitelist:
|
||||
deep_update(original[k], value, True, [])
|
||||
else:
|
||||
deep_update(original[k], value, new_keys_allowed, [])
|
||||
else:
|
||||
original[k] = value
|
||||
return original
|
||||
__all__ = [
|
||||
"Filter", "FilterManager", "PolicyClient", "PolicyServer", "merge_dicts",
|
||||
"deep_update"
|
||||
]
|
||||
|
||||
@@ -17,7 +17,7 @@ def easy_objective(config, reporter):
|
||||
time.sleep(0.2)
|
||||
assert type(config["activation"]) == str, \
|
||||
"Config is incorrect: {}".format(type(config["activation"]))
|
||||
for i in range(100):
|
||||
for i in range(config["iterations"]):
|
||||
reporter(
|
||||
timesteps_total=i,
|
||||
neg_mean_loss=-(config["height"] - 14)**2 +
|
||||
@@ -47,6 +47,9 @@ if __name__ == '__main__':
|
||||
"my_exp": {
|
||||
"run": "exp",
|
||||
"num_samples": 10 if args.smoke_test else 1000,
|
||||
"config": {
|
||||
"iterations": 100,
|
||||
},
|
||||
"stop": {
|
||||
"timesteps_total": 100
|
||||
},
|
||||
|
||||
@@ -7,10 +7,11 @@ import copy
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.util import merge_dicts
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.variant_generator import format_vars
|
||||
from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict
|
||||
|
||||
|
||||
class SuggestionAlgorithm(SearchAlgorithm):
|
||||
@@ -33,9 +34,6 @@ class SuggestionAlgorithm(SearchAlgorithm):
|
||||
|
||||
def __init__(self):
|
||||
"""Constructs a generator given experiment specifications.
|
||||
|
||||
Arguments:
|
||||
experiments (Experiment | list | dict): Experiments to run.
|
||||
"""
|
||||
self._parser = make_parser()
|
||||
self._trial_generator = []
|
||||
@@ -91,10 +89,11 @@ class SuggestionAlgorithm(SearchAlgorithm):
|
||||
else:
|
||||
break
|
||||
spec = copy.deepcopy(experiment_spec)
|
||||
spec["config"] = suggested_config
|
||||
spec["config"] = merge_dicts(spec["config"], suggested_config)
|
||||
flattened_config = resolve_nested_dict(spec["config"])
|
||||
self._counter += 1
|
||||
tag = "{0}_{1}".format(
|
||||
str(self._counter), format_vars(spec["config"]))
|
||||
str(self._counter), format_vars(flattened_config))
|
||||
yield create_trial_from_spec(
|
||||
spec,
|
||||
output_path,
|
||||
|
||||
@@ -93,6 +93,21 @@ _STANDARD_IMPORTS = {
|
||||
_MAX_RESOLUTION_PASSES = 20
|
||||
|
||||
|
||||
def resolve_nested_dict(nested_dict):
|
||||
"""Flattens a nested dict by joining keys into tuple of paths.
|
||||
|
||||
Can then be passed into `format_vars`.
|
||||
"""
|
||||
res = {}
|
||||
for k, v in nested_dict.items():
|
||||
if isinstance(v, dict):
|
||||
for k_, v_ in resolve_nested_dict(v).items():
|
||||
res[(k, ) + k_] = v_
|
||||
else:
|
||||
res[(k, )] = v
|
||||
return res
|
||||
|
||||
|
||||
def format_vars(resolved_vars):
|
||||
out = []
|
||||
for path, value in sorted(resolved_vars.items()):
|
||||
|
||||
@@ -26,7 +26,8 @@ from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.suggest import grid_search, BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
|
||||
SuggestionAlgorithm)
|
||||
from ray.tune.suggest.variant_generator import RecursiveDependencyError
|
||||
from ray.tune.suggest.variant_generator import (RecursiveDependencyError,
|
||||
resolve_nested_dict)
|
||||
|
||||
if sys.version_info >= (3, 3):
|
||||
from unittest.mock import patch
|
||||
@@ -902,6 +903,20 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
self.assertEqual(trials[0].config, {"x": 100, "y": 1})
|
||||
self.assertEqual(trials[1].config, {"x": 200, "y": 1})
|
||||
|
||||
def test_resolve_dict(self):
|
||||
config = {
|
||||
"a": {
|
||||
"b": 1,
|
||||
"c": 2,
|
||||
},
|
||||
"b": {
|
||||
"a": 3
|
||||
}
|
||||
}
|
||||
resolved = resolve_nested_dict(config)
|
||||
for k, v in [(("a", "b"), 1), (("a", "c"), 2), (("b", "a"), 3)]:
|
||||
self.assertEqual(resolved.get(k), v)
|
||||
|
||||
def testRecursiveDep(self):
|
||||
try:
|
||||
list(
|
||||
@@ -1651,5 +1666,18 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertRaises(TuneError, runner.step)
|
||||
|
||||
|
||||
class SearchAlgorithmTest(unittest.TestCase):
|
||||
def testNestedSuggestion(self):
|
||||
class TestSuggestion(SuggestionAlgorithm):
|
||||
def _suggest(self, trial_id):
|
||||
return {"a": {"b": {"c": {"d": 4, "e": 5}}}}
|
||||
|
||||
alg = TestSuggestion()
|
||||
alg.add_configurations({"test": {"run": "__fake"}})
|
||||
trial = alg.next_trials()[0]
|
||||
self.assertTrue("e=5" in trial.experiment_tag)
|
||||
self.assertTrue("d=4" in trial.experiment_tag)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
@@ -35,6 +36,41 @@ def get_pinned_object(pinned_id):
|
||||
ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
|
||||
|
||||
|
||||
def merge_dicts(d1, d2):
|
||||
"""Returns a new dict that is d1 and d2 deep merged."""
|
||||
merged = copy.deepcopy(d1)
|
||||
deep_update(merged, d2, True, [])
|
||||
return merged
|
||||
|
||||
|
||||
def deep_update(original, new_dict, new_keys_allowed, whitelist):
|
||||
"""Updates original dict with values from new_dict recursively.
|
||||
If new key is introduced in new_dict, then if new_keys_allowed is not
|
||||
True, an error will be thrown. Further, for sub-dicts, if the key is
|
||||
in the whitelist, then new subkeys can be introduced.
|
||||
|
||||
Args:
|
||||
original (dict): Dictionary with default values.
|
||||
new_dict (dict): Dictionary with values to be updated
|
||||
new_keys_allowed (bool): Whether new keys are allowed.
|
||||
whitelist (list): List of keys that correspond to dict values
|
||||
where new subkeys can be introduced. This is only at
|
||||
the top level.
|
||||
"""
|
||||
for k, value in new_dict.items():
|
||||
if k not in original:
|
||||
if not new_keys_allowed:
|
||||
raise Exception("Unknown config parameter `{}` ".format(k))
|
||||
if type(original.get(k)) is dict:
|
||||
if k in whitelist:
|
||||
deep_update(original[k], value, True, [])
|
||||
else:
|
||||
deep_update(original[k], value, new_keys_allowed, [])
|
||||
else:
|
||||
original[k] = value
|
||||
return original
|
||||
|
||||
|
||||
def _to_pinnable(obj):
|
||||
"""Converts obj to a form that can be pinned in object store memory.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user