[rllib] Support Nested Configuration Merging (#1268)

This commit is contained in:
Richard Liaw
2017-12-13 14:39:01 -08:00
committed by GitHub
parent f75b51d178
commit cabbd27c56
4 changed files with 55 additions and 7 deletions
+36 -7
View File
@@ -25,6 +25,37 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def _deep_update(original, new_dict, new_keys_allowed, whitelist):
    """Recursively merge ``new_dict`` into ``original`` and return it.

    ``original`` is modified in place, including any nested dicts it
    holds. Callers that must keep their defaults pristine therefore need
    to pass in a copy whose nested dicts are copied as well.

    A key of ``new_dict`` that is missing from ``original`` raises an
    error unless ``new_keys_allowed`` is true, in which case a warning is
    logged and the key is added. The key ``"env"`` is always accepted.

    When both the existing value and the new value are dicts, they are
    merged recursively. In every other case the new value simply replaces
    the existing one (this includes replacing a dict default with a
    non-dict value such as ``None``).

    Args:
        original (dict): Dictionary with default values.
        new_dict (dict): Dictionary with values to be updated.
        new_keys_allowed (bool): Whether new keys are allowed.
        whitelist (list): List of keys that correspond to dict values
            where new subkeys can be introduced. This is only at
            the top level: recursive calls receive an empty whitelist.

    Returns:
        dict: The same ``original`` object, after the merge.

    Raises:
        Exception: If ``new_dict`` introduces a key that is not in
            ``original`` while new keys are not allowed at that level.
    """
    for k, value in new_dict.items():
        if k not in original and k != "env":
            if not new_keys_allowed:
                raise Exception(
                    "Unknown config parameter `{}` ".format(k))
            # `Logger.warn` is a deprecated alias; `warning` is the
            # supported spelling.
            logger.warning(
                "`{}` not in default configuration...".format(k))
        current = original.get(k)
        # Only recurse when both sides are dicts; otherwise `value` has no
        # items to merge and must overwrite the existing entry.
        if isinstance(current, dict) and isinstance(value, dict):
            # Whitelisted sub-dicts accept arbitrary new subkeys; all
            # other sub-dicts inherit the current level's policy.
            subkeys_allowed = k in whitelist or new_keys_allowed
            _deep_update(current, value, subkeys_allowed, [])
        else:
            original[k] = value
    return original
class Agent(Trainable):
"""All RLlib agents extend this base class.
@@ -40,6 +71,7 @@ class Agent(Trainable):
"""
_allow_unknown_configs = False
_allow_unknown_subkeys = []
_default_logdir = "/tmp/ray"
def __init__(
@@ -67,13 +99,10 @@ class Agent(Trainable):
self.env_creator = lambda: gym.make(env)
self.config = self._default_config.copy()
self.registry = registry
if not self._allow_unknown_configs:
for k in config.keys():
if k not in self.config and k != "env":
raise Exception(
"Unknown agent config `{}`, "
"all agent configs: {}".format(k, self.config.keys()))
self.config.update(config)
self.config = _deep_update(self.config, config,
self._allow_unknown_configs,
self._allow_unknown_subkeys)
if logger_creator:
self._result_logger = logger_creator(self.config)
+1
View File
@@ -111,6 +111,7 @@ DEFAULT_CONFIG = dict(
class DQNAgent(Agent):
_agent_name = "DQN"
_allow_unknown_subkeys = ["model", "optimizer", "tf_session_args"]
_default_config = DEFAULT_CONFIG
def _init(self):
+1
View File
@@ -83,6 +83,7 @@ DEFAULT_CONFIG = {
class PPOAgent(Agent):
_agent_name = "PPO"
_allow_unknown_subkeys = ["model", "tf_session_args"]
_default_config = DEFAULT_CONFIG
def _init(self):
@@ -0,0 +1,17 @@
pong-a3c-pytorch-cnn:
env: PongDeterministic-v4
run: A3C
resources:
cpu: 16
driver_cpu_limit: 1
config:
num_workers: 16
num_batches_per_iteration: 1000
batch_size: 20
use_lstm: false
use_pytorch: true
model:
grayscale: true
zero_mean: false
dim: 80
channel_major: true