From 2b893d1bb5be8f8db87f732bb73c3f7cab425395 Mon Sep 17 00:00:00 2001 From: mvindiola1 Date: Sun, 20 Sep 2020 23:01:51 -0400 Subject: [PATCH] fix incorrect critic loss in TD3 (#10775) Co-authored-by: Manny Vindiola --- rllib/agents/ddpg/ddpg_tf_policy.py | 1 - rllib/agents/ddpg/ddpg_torch_policy.py | 1 - rllib/agents/ddpg/tests/test_ddpg.py | 1 - 3 files changed, 3 deletions(-) diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py index 3a04bed51..6b092ed38 100644 --- a/rllib/agents/ddpg/ddpg_tf_policy.py +++ b/rllib/agents/ddpg/ddpg_tf_policy.py @@ -175,7 +175,6 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): if twin_q: td_error = q_t_selected - q_t_selected_target twin_td_error = twin_q_t_selected - q_t_selected_target - td_error = td_error + twin_td_error if use_huber: errors = huber_loss(td_error, huber_threshold) + \ huber_loss(twin_td_error, huber_threshold) diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py index 68d91a5c9..445564466 100644 --- a/rllib/agents/ddpg/ddpg_torch_policy.py +++ b/rllib/agents/ddpg/ddpg_torch_policy.py @@ -124,7 +124,6 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): if twin_q: td_error = q_t_selected - q_t_selected_target twin_td_error = twin_q_t_selected - q_t_selected_target - td_error = td_error + twin_td_error if use_huber: errors = huber_loss(td_error, huber_threshold) \ + huber_loss(twin_td_error, huber_threshold) diff --git a/rllib/agents/ddpg/tests/test_ddpg.py b/rllib/agents/ddpg/tests/test_ddpg.py index 91b499207..8ceaf3540 100644 --- a/rllib/agents/ddpg/tests/test_ddpg.py +++ b/rllib/agents/ddpg/tests/test_ddpg.py @@ -495,7 +495,6 @@ class TestDDPG(unittest.TestCase): td_error = q_t_selected - q_t_selected_target twin_td_error = twin_q_t_selected - q_t_selected_target - td_error = td_error + twin_td_error errors = huber_loss(td_error, huber_threshold) + \ huber_loss(twin_td_error, huber_threshold)