[RLlib] Exploration API: merge deterministic flag with exploration classes (SoftQ and StochasticSampling). (#7155)

This commit is contained in:
Sven Mika
2020-02-19 21:18:45 +01:00
committed by GitHub
parent 399424c418
commit d537e9f0d8
56 changed files with 1218 additions and 461 deletions
+31 -10
View File
@@ -146,8 +146,13 @@ def merge_dicts(d1, d2):
return merged
def deep_update(original, new_dict, new_keys_allowed, whitelist):
def deep_update(original,
new_dict,
new_keys_allowed=False,
whitelist=None,
override_all_if_type_changes=None):
"""Updates original dict with values from new_dict recursively.
If new key is introduced in new_dict, then if new_keys_allowed is not
True, an error will be thrown. Further, for sub-dicts, if the key is
in the whitelist, then new subkeys can be introduced.
@@ -156,19 +161,35 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist):
original (dict): Dictionary with default values.
new_dict (dict): Dictionary with values to be updated
new_keys_allowed (bool): Whether new keys are allowed.
whitelist (list): List of keys that correspond to dict values
where new subkeys can be introduced. This is only at
the top level.
whitelist (Optional[List[str]]): List of keys that correspond to dict
values where new subkeys can be introduced. This is only at the top
level.
override_all_if_type_changes (Optional[List[str]]): List of top-level
keys with value=dict, for which we always simply override the
entire value (dict), iff the "type" key in that value dict changes.
"""
whitelist = whitelist or []
override_all_if_type_changes = override_all_if_type_changes or []
for k, value in new_dict.items():
if k not in original:
if not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))
if k not in original and not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))
# Both original value and new one are dicts.
if isinstance(original.get(k), dict) and isinstance(value, dict):
if k in whitelist:
deep_update(original[k], value, True, [])
# Check old type vs new one. If different, override entire value.
if k in override_all_if_type_changes and \
"type" in value and "type" in original[k] and \
value["type"] != original[k]["type"]:
original[k] = value
# Whitelisted key -> ok to add new subkeys.
elif k in whitelist:
deep_update(original[k], value, True)
# Non-whitelisted key.
else:
deep_update(original[k], value, new_keys_allowed, [])
deep_update(original[k], value, new_keys_allowed)
# Original value not a dict OR new value not a dict:
# Override entire value.
else:
original[k] = value
return original