mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 09:12:56 +08:00
[rllib] Support Nested Configuration Merging (#1268)
This commit is contained in:
@@ -25,6 +25,37 @@ logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def _deep_update(original, new_dict, new_keys_allowed, whitelist):
|
||||
"""Updates original dict with values from new_dict recursively.
|
||||
If new key is introduced in new_dict, then if new_keys_allowed is not
|
||||
True, an error will be thrown. Further, for sub-dicts, if the key is
|
||||
in the whitelist, then new subkeys can be introduced.
|
||||
|
||||
Args:
|
||||
original (dict): Dictionary with default values.
|
||||
new_dict (dict): Dictionary with values to be updated
|
||||
new_keys_allowed (bool): Whether new keys are allowed.
|
||||
whitelist (list): List of keys that correspond to dict values
|
||||
where new subkeys can be introduced. This is only at
|
||||
the top level.
|
||||
"""
|
||||
for k, value in new_dict.items():
|
||||
if k not in original and k != "env":
|
||||
if not new_keys_allowed:
|
||||
raise Exception(
|
||||
"Unknown config parameter `{}` ".format(k))
|
||||
else:
|
||||
logger.warn("`{}` not in default configuration...".format(k))
|
||||
if type(original.get(k)) is dict:
|
||||
if k in whitelist:
|
||||
_deep_update(original[k], value, True, [])
|
||||
else:
|
||||
_deep_update(original[k], value, new_keys_allowed, [])
|
||||
else:
|
||||
original[k] = value
|
||||
return original
|
||||
|
||||
|
||||
class Agent(Trainable):
|
||||
"""All RLlib agents extend this base class.
|
||||
|
||||
@@ -40,6 +71,7 @@ class Agent(Trainable):
|
||||
"""
|
||||
|
||||
_allow_unknown_configs = False
|
||||
_allow_unknown_subkeys = []
|
||||
_default_logdir = "/tmp/ray"
|
||||
|
||||
def __init__(
|
||||
@@ -67,13 +99,10 @@ class Agent(Trainable):
|
||||
self.env_creator = lambda: gym.make(env)
|
||||
self.config = self._default_config.copy()
|
||||
self.registry = registry
|
||||
if not self._allow_unknown_configs:
|
||||
for k in config.keys():
|
||||
if k not in self.config and k != "env":
|
||||
raise Exception(
|
||||
"Unknown agent config `{}`, "
|
||||
"all agent configs: {}".format(k, self.config.keys()))
|
||||
self.config.update(config)
|
||||
|
||||
self.config = _deep_update(self.config, config,
|
||||
self._allow_unknown_configs,
|
||||
self._allow_unknown_subkeys)
|
||||
|
||||
if logger_creator:
|
||||
self._result_logger = logger_creator(self.config)
|
||||
|
||||
@@ -111,6 +111,7 @@ DEFAULT_CONFIG = dict(
|
||||
|
||||
class DQNAgent(Agent):
|
||||
_agent_name = "DQN"
|
||||
_allow_unknown_subkeys = ["model", "optimizer", "tf_session_args"]
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
def _init(self):
|
||||
|
||||
@@ -83,6 +83,7 @@ DEFAULT_CONFIG = {
|
||||
|
||||
class PPOAgent(Agent):
|
||||
_agent_name = "PPO"
|
||||
_allow_unknown_subkeys = ["model", "tf_session_args"]
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
def _init(self):
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
pong-a3c-pytorch-cnn:
|
||||
env: PongDeterministic-v4
|
||||
run: A3C
|
||||
resources:
|
||||
cpu: 16
|
||||
driver_cpu_limit: 1
|
||||
config:
|
||||
num_workers: 16
|
||||
num_batches_per_iteration: 1000
|
||||
batch_size: 20
|
||||
use_lstm: false
|
||||
use_pytorch: true
|
||||
model:
|
||||
grayscale: true
|
||||
zero_mean: false
|
||||
dim: 80
|
||||
channel_major: true
|
||||
Reference in New Issue
Block a user