[rllib] Add support for LR schedule to DQN/APEX (#4473)

This commit is contained in:
opherlieber
2019-04-01 21:35:34 +03:00
committed by Eric Liang
parent 0d94f3eeef
commit 60b230b8ad
2 changed files with 14 additions and 5 deletions
+2
View File
@@ -103,6 +103,8 @@ DEFAULT_CONFIG = with_common_config({
# === Optimization ===
# Learning rate for adam optimizer
"lr": 5e-4,
# Learning rate schedule
"lr_schedule": None,
# Adam epsilon hyper parameter
"adam_epsilon": 1e-8,
# If not None, clip gradients during optimization at this value
@@ -15,7 +15,8 @@ from ray.rllib.models import ModelCatalog, Categorical
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
Q_SCOPE = "q_func"
Q_TARGET_SCOPE = "target_q_func"
@@ -336,7 +337,7 @@ class QValuePolicy(object):
self.action_prob = None
class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph):
class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config)
if not isinstance(action_space, Discrete):
@@ -447,6 +448,9 @@ class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph):
(SampleBatch.DONES, self.done_mask),
(PRIO_WEIGHTS, self.importance_weights),
]
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicyGraph.__init__(
self,
observation_space,
@@ -461,11 +465,14 @@ class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph):
update_ops=q_batchnorm_update_ops)
self.sess.run(tf.global_variables_initializer())
self.stats_fetches = dict({
"cur_lr": tf.cast(self.cur_lr, tf.float64),
}, **self.loss.stats)
@override(TFPolicyGraph)
def optimizer(self):
return tf.train.AdamOptimizer(
learning_rate=self.config["lr"],
epsilon=self.config["adam_epsilon"])
learning_rate=self.cur_lr, epsilon=self.config["adam_epsilon"])
@override(TFPolicyGraph)
def gradients(self, optimizer, loss):
@@ -492,7 +499,7 @@ class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph):
def extra_compute_grad_fetches(self):
return {
"td_error": self.loss.td_error,
LEARNER_STATS_KEY: self.loss.stats,
LEARNER_STATS_KEY: self.stats_fetches,
}
@override(PolicyGraph)