mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 20:17:17 +08:00
fix incorrect critic loss in TD3 (#10775)
Co-authored-by: Manny Vindiola <manuel.m.vindiola.civ@mail.mil>
This commit is contained in:
@@ -175,7 +175,6 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
|
||||
if twin_q:
|
||||
td_error = q_t_selected - q_t_selected_target
|
||||
twin_td_error = twin_q_t_selected - q_t_selected_target
|
||||
td_error = td_error + twin_td_error
|
||||
if use_huber:
|
||||
errors = huber_loss(td_error, huber_threshold) + \
|
||||
huber_loss(twin_td_error, huber_threshold)
|
||||
|
||||
@@ -124,7 +124,6 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
|
||||
if twin_q:
|
||||
td_error = q_t_selected - q_t_selected_target
|
||||
twin_td_error = twin_q_t_selected - q_t_selected_target
|
||||
td_error = td_error + twin_td_error
|
||||
if use_huber:
|
||||
errors = huber_loss(td_error, huber_threshold) \
|
||||
+ huber_loss(twin_td_error, huber_threshold)
|
||||
|
||||
@@ -495,7 +495,6 @@ class TestDDPG(unittest.TestCase):
|
||||
|
||||
td_error = q_t_selected - q_t_selected_target
|
||||
twin_td_error = twin_q_t_selected - q_t_selected_target
|
||||
td_error = td_error + twin_td_error
|
||||
errors = huber_loss(td_error, huber_threshold) + \
|
||||
huber_loss(twin_td_error, huber_threshold)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user