mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 08:07:54 +08:00
Remove redundant scaler of l2 reg (#5172)
* remove redundant scaler of l2 reg
* lint formatted
* Update ddpg_policy.py
This commit is contained in:
@@ -231,17 +231,15 @@ class DDPGTFPolicy(DDPGPostprocessing, TFPolicy):
         if config["l2_reg"] is not None:
             for var in self.policy_vars:
                 if "bias" not in var.name:
-                    self.actor_loss += (
-                        config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
+                    self.actor_loss += (config["l2_reg"] * tf.nn.l2_loss(var))
             for var in self.q_func_vars:
                 if "bias" not in var.name:
-                    self.critic_loss += (
-                        config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
+                    self.critic_loss += (config["l2_reg"] * tf.nn.l2_loss(var))
             if self.config["twin_q"]:
                 for var in self.twin_q_func_vars:
                     if "bias" not in var.name:
                         self.critic_loss += (
-                            config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
+                            config["l2_reg"] * tf.nn.l2_loss(var))
 
         # update_target_fn will be called periodically to copy Q network to
         # target Q network
Reference in New Issue
Block a user