diff --git a/python/ray/rllib/utils/__init__.py b/python/ray/rllib/utils/__init__.py index a738e7419..c25cefeb4 100644 --- a/python/ray/rllib/utils/__init__.py +++ b/python/ray/rllib/utils/__init__.py @@ -1,43 +1,10 @@ -import copy - from ray.rllib.utils.filter_manager import FilterManager from ray.rllib.utils.filter import Filter from ray.rllib.utils.policy_client import PolicyClient from ray.rllib.utils.policy_server import PolicyServer +from ray.tune.util import merge_dicts, deep_update -__all__ = ["Filter", "FilterManager", "PolicyClient", "PolicyServer"] - - -def merge_dicts(d1, d2): - """Returns a new dict that is d1 and d2 deep merged.""" - merged = copy.deepcopy(d1) - deep_update(merged, d2, True, []) - return merged - - -def deep_update(original, new_dict, new_keys_allowed, whitelist): - """Updates original dict with values from new_dict recursively. - If new key is introduced in new_dict, then if new_keys_allowed is not - True, an error will be thrown. Further, for sub-dicts, if the key is - in the whitelist, then new subkeys can be introduced. - - Args: - original (dict): Dictionary with default values. - new_dict (dict): Dictionary with values to be updated - new_keys_allowed (bool): Whether new keys are allowed. - whitelist (list): List of keys that correspond to dict values - where new subkeys can be introduced. This is only at - the top level. 
- """ - for k, value in new_dict.items(): - if k not in original: - if not new_keys_allowed: - raise Exception("Unknown config parameter `{}` ".format(k)) - if type(original.get(k)) is dict: - if k in whitelist: - deep_update(original[k], value, True, []) - else: - deep_update(original[k], value, new_keys_allowed, []) - else: - original[k] = value - return original +__all__ = [ + "Filter", "FilterManager", "PolicyClient", "PolicyServer", "merge_dicts", + "deep_update" +] diff --git a/python/ray/tune/examples/hyperopt_example.py b/python/ray/tune/examples/hyperopt_example.py index 2898bf26d..d70d16b94 100644 --- a/python/ray/tune/examples/hyperopt_example.py +++ b/python/ray/tune/examples/hyperopt_example.py @@ -17,7 +17,7 @@ def easy_objective(config, reporter): time.sleep(0.2) assert type(config["activation"]) == str, \ "Config is incorrect: {}".format(type(config["activation"])) - for i in range(100): + for i in range(config["iterations"]): reporter( timesteps_total=i, neg_mean_loss=-(config["height"] - 14)**2 + @@ -47,6 +47,9 @@ if __name__ == '__main__': "my_exp": { "run": "exp", "num_samples": 10 if args.smoke_test else 1000, + "config": { + "iterations": 100, + }, "stop": { "timesteps_total": 100 }, diff --git a/python/ray/tune/suggest/suggestion.py b/python/ray/tune/suggest/suggestion.py index aa6cbe717..f6e7d532a 100644 --- a/python/ray/tune/suggest/suggestion.py +++ b/python/ray/tune/suggest/suggestion.py @@ -7,10 +7,11 @@ import copy from ray.tune.error import TuneError from ray.tune.trial import Trial +from ray.tune.util import merge_dicts from ray.tune.experiment import convert_to_experiment_list from ray.tune.config_parser import make_parser, create_trial_from_spec from ray.tune.suggest.search import SearchAlgorithm -from ray.tune.suggest.variant_generator import format_vars +from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict class SuggestionAlgorithm(SearchAlgorithm): @@ -33,9 +34,6 @@ class 
def resolve_nested_dict(nested_dict):
    """Flattens a nested dict into a single-level dict keyed by path tuples.

    Every non-dict value is stored under the tuple of keys leading to it,
    e.g. ``{"a": {"b": 1}, "c": 2}`` becomes ``{("a", "b"): 1, ("c", ): 2}``.
    Empty sub-dicts contribute no entries. The result can then be passed
    into `format_vars`.

    Args:
        nested_dict (dict): Possibly nested dictionary to flatten.

    Returns:
        dict: Mapping from key-path tuples to leaf values.
    """

    def _leaves(node, prefix):
        # Depth-first walk that yields (path, leaf) pairs in insertion order.
        for key, value in node.items():
            path = prefix + (key, )
            if isinstance(value, dict):
                for leaf in _leaves(value, path):
                    yield leaf
            else:
                yield path, value

    return dict(_leaves(nested_dict, ()))
def merge_dicts(d1, d2):
    """Returns a new dict that is d1 and d2 deep merged.

    d1 is deep-copied first, then the values of d2 are written over the
    copy with `deep_update` (new keys are allowed). Values taken from d2
    are stored by reference, not copied.

    Args:
        d1 (dict): Base dictionary; left unmodified.
        d2 (dict): Dictionary whose values take precedence over d1.

    Returns:
        dict: The merged dictionary.
    """
    merged = copy.deepcopy(d1)
    deep_update(merged, d2, True, [])
    return merged


def deep_update(original, new_dict, new_keys_allowed, whitelist):
    """Updates original dict with values from new_dict recursively.

    If new key is introduced in new_dict, then if new_keys_allowed is not
    True, an error will be thrown. Further, for sub-dicts, if the key is
    in the whitelist, then new subkeys can be introduced.

    Args:
        original (dict): Dictionary with default values. Updated in place.
        new_dict (dict): Dictionary with values to be updated
        new_keys_allowed (bool): Whether new keys are allowed.
        whitelist (list): List of keys that correspond to dict values
            where new subkeys can be introduced. This is only at
            the top level.

    Returns:
        dict: The same `original` object, after the update.

    Raises:
        Exception: If new_dict contains a key missing from original and
            new keys are not allowed at that level.
    """
    for k, value in new_dict.items():
        if k not in original and not new_keys_allowed:
            raise Exception("Unknown config parameter `{}` ".format(k))
        # Only recurse when BOTH sides are dicts. If the existing value is
        # a dict but the new value is not (e.g. None or a scalar), it is a
        # plain replacement; recursing would call `.items()` on a non-dict.
        if isinstance(original.get(k), dict) and isinstance(value, dict):
            # The whitelist applies to this level only, so the recursive
            # call always receives an empty whitelist.
            deep_update(original[k], value, k in whitelist
                        or new_keys_allowed, [])
        else:
            original[k] = value
    return original