mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 02:47:10 +08:00
[rllib] Add support for LR schedule to DQN/APEX (#4473)
This commit is contained in:
@@ -103,6 +103,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# === Optimization ===
|
||||
# Learning rate for the Adam optimizer
|
||||
"lr": 5e-4,
|
||||
# Learning rate schedule
|
||||
"lr_schedule": None,
|
||||
# Adam epsilon hyperparameter
|
||||
"adam_epsilon": 1e-8,
|
||||
# If not None, clip gradients during optimization at this value
|
||||
|
||||
@@ -15,7 +15,8 @@ from ray.rllib.models import ModelCatalog, Categorical
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
LearningRateSchedule
|
||||
|
||||
Q_SCOPE = "q_func"
|
||||
Q_TARGET_SCOPE = "target_q_func"
|
||||
@@ -336,7 +337,7 @@ class QValuePolicy(object):
|
||||
self.action_prob = None
|
||||
|
||||
|
||||
class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph):
|
||||
class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config)
|
||||
if not isinstance(action_space, Discrete):
|
||||
@@ -447,6 +448,9 @@ class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph):
|
||||
(SampleBatch.DONES, self.done_mask),
|
||||
(PRIO_WEIGHTS, self.importance_weights),
|
||||
]
|
||||
|
||||
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||
self.config["lr_schedule"])
|
||||
TFPolicyGraph.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
@@ -461,11 +465,14 @@ class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph):
|
||||
update_ops=q_batchnorm_update_ops)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
self.stats_fetches = dict({
|
||||
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||
}, **self.loss.stats)
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def optimizer(self):
|
||||
return tf.train.AdamOptimizer(
|
||||
learning_rate=self.config["lr"],
|
||||
epsilon=self.config["adam_epsilon"])
|
||||
learning_rate=self.cur_lr, epsilon=self.config["adam_epsilon"])
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def gradients(self, optimizer, loss):
|
||||
@@ -492,7 +499,7 @@ class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph):
|
||||
def extra_compute_grad_fetches(self):
|
||||
return {
|
||||
"td_error": self.loss.td_error,
|
||||
LEARNER_STATS_KEY: self.loss.stats,
|
||||
LEARNER_STATS_KEY: self.stats_fetches,
|
||||
}
|
||||
|
||||
@override(PolicyGraph)
|
||||
|
||||
Reference in New Issue
Block a user