diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index f7859b517..62c56cbc5 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -103,6 +103,8 @@ DEFAULT_CONFIG = with_common_config({ # === Optimization === # Learning rate for adam optimizer "lr": 5e-4, + # Learning rate schedule + "lr_schedule": None, # Adam epsilon hyper parameter "adam_epsilon": 1e-8, # If not None, clip gradients during optimization at this value diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index 56af1e04b..1b5bb4624 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -15,7 +15,8 @@ from ray.rllib.models import ModelCatalog, Categorical from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ + LearningRateSchedule Q_SCOPE = "q_func" Q_TARGET_SCOPE = "target_q_func" @@ -336,7 +337,7 @@ class QValuePolicy(object): self.action_prob = None -class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph): +class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph): def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config) if not isinstance(action_space, Discrete): @@ -447,6 +448,9 @@ class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph): (SampleBatch.DONES, self.done_mask), (PRIO_WEIGHTS, self.importance_weights), ] + + LearningRateSchedule.__init__(self, self.config["lr"], + self.config["lr_schedule"]) TFPolicyGraph.__init__( self, observation_space, @@ -461,11 +465,14 @@ class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph): update_ops=q_batchnorm_update_ops) self.sess.run(tf.global_variables_initializer()) + self.stats_fetches = dict({ + "cur_lr": tf.cast(self.cur_lr, tf.float64), + }, **self.loss.stats) + @override(TFPolicyGraph) def optimizer(self): return tf.train.AdamOptimizer( - learning_rate=self.config["lr"], - epsilon=self.config["adam_epsilon"]) + learning_rate=self.cur_lr, epsilon=self.config["adam_epsilon"]) @override(TFPolicyGraph) def gradients(self, optimizer, loss): @@ -492,7 +499,7 @@ class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph): def extra_compute_grad_fetches(self): return { "td_error": self.loss.td_error, - LEARNER_STATS_KEY: self.loss.stats, + LEARNER_STATS_KEY: self.stats_fetches, } @override(PolicyGraph)