From cabbd27c563cb03e998a8575f626ca82f948c74f Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Wed, 13 Dec 2017 14:39:01 -0800 Subject: [PATCH] [rllib] Support Nested Configuration Merging (#1268) --- python/ray/rllib/agent.py | 43 ++++++++++++++++--- python/ray/rllib/dqn/dqn.py | 1 + python/ray/rllib/ppo/ppo.py | 1 + .../tuned_examples/pong-a3c-pytorch.yaml | 17 ++++++++ 4 files changed, 55 insertions(+), 7 deletions(-) create mode 100644 python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py index 1770a1a29..a889e0677 100644 --- a/python/ray/rllib/agent.py +++ b/python/ray/rllib/agent.py @@ -25,6 +25,37 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +def _deep_update(original, new_dict, new_keys_allowed, whitelist): + """Updates original dict with values from new_dict recursively. + If new key is introduced in new_dict, then if new_keys_allowed is not + True, an error will be thrown. Further, for sub-dicts, if the key is + in the whitelist, then new subkeys can be introduced. + + Args: + original (dict): Dictionary with default values. + new_dict (dict): Dictionary with values to be updated + new_keys_allowed (bool): Whether new keys are allowed. + whitelist (list): List of keys that correspond to dict values + where new subkeys can be introduced. This is only at + the top level. + """ + for k, value in new_dict.items(): + if k not in original and k != "env": + if not new_keys_allowed: + raise Exception( + "Unknown config parameter `{}` ".format(k)) + else: + logger.warn("`{}` not in default configuration...".format(k)) + if type(original.get(k)) is dict: + if k in whitelist: + _deep_update(original[k], value, True, []) + else: + _deep_update(original[k], value, new_keys_allowed, []) + else: + original[k] = value + return original + + class Agent(Trainable): """All RLlib agents extend this base class. @@ -40,6 +71,7 @@ class Agent(Trainable): """ _allow_unknown_configs = False + _allow_unknown_subkeys = [] _default_logdir = "/tmp/ray" def __init__( @@ -67,13 +99,10 @@ class Agent(Trainable): self.env_creator = lambda: gym.make(env) self.config = self._default_config.copy() self.registry = registry - if not self._allow_unknown_configs: - for k in config.keys(): - if k not in self.config and k != "env": - raise Exception( - "Unknown agent config `{}`, " - "all agent configs: {}".format(k, self.config.keys())) - self.config.update(config) + + self.config = _deep_update(self.config, config, + self._allow_unknown_configs, + self._allow_unknown_subkeys) if logger_creator: self._result_logger = logger_creator(self.config) diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index 22901541a..7f1c165cd 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -111,6 +111,7 @@ DEFAULT_CONFIG = dict( class DQNAgent(Agent): _agent_name = "DQN" + _allow_unknown_subkeys = ["model", "optimizer", "tf_session_args"] _default_config = DEFAULT_CONFIG def _init(self): diff --git a/python/ray/rllib/ppo/ppo.py b/python/ray/rllib/ppo/ppo.py index 0dc4cc5b3..3a526272c 100644 --- a/python/ray/rllib/ppo/ppo.py +++ b/python/ray/rllib/ppo/ppo.py @@ -83,6 +83,7 @@ DEFAULT_CONFIG = { class PPOAgent(Agent): _agent_name = "PPO" + _allow_unknown_subkeys = ["model", "tf_session_args"] _default_config = DEFAULT_CONFIG def _init(self): diff --git a/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml b/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml new file mode 100644 index 000000000..05c6537cd --- /dev/null +++ b/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml @@ -0,0 +1,17 @@ +pong-a3c-pytorch-cnn: + env: PongDeterministic-v4 + run: A3C + resources: + cpu: 16 + driver_cpu_limit: 1 + config: + num_workers: 16 + num_batches_per_iteration: 1000 + batch_size: 20 + use_lstm: false + use_pytorch: true + model: + grayscale: true + zero_mean: false + dim: 80 + channel_major: true