diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index 808b7d6e4..44ddbff1f 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -12,7 +12,8 @@ import ray import ray.experimental.tf_utils from ray.rllib.agents.ddpg.ddpg_tf_policy import ComputeTDErrorMixin, \ TargetNetworkMixin -from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio +from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \ + PRIO_WEIGHTS from ray.rllib.agents.sac.sac_tf_model import SACTFModel from ray.rllib.agents.sac.sac_torch_model import SACTorchModel from ray.rllib.evaluation.episode import MultiAgentEpisode @@ -27,6 +28,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import get_variable, try_import_tf, \ try_import_tfp from ray.rllib.utils.spaces.simplex import Simplex +from ray.rllib.utils.tf_ops import huber_loss from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \ TensorType, TrainerConfigDict @@ -338,13 +340,11 @@ def sac_actor_critic_loss( td_error = base_td_error # Calculate one or two critic losses (2 in the twin_q case). - critic_loss = [ - 0.5 * tf.keras.losses.MSE( - y_true=q_t_selected_target, y_pred=q_t_selected) - ] + prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32) + critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))] if policy.config["twin_q"]: - critic_loss.append(0.5 * tf.keras.losses.MSE( - y_true=q_t_selected_target, y_pred=twin_q_t_selected)) + critic_loss.append( + tf.reduce_mean(prio_weights * huber_loss(twin_td_error))) # Alpha- and actor losses. # Note: In the papers, alpha is used directly, here we take the log. diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index 5b39137b7..b4b225865 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -23,6 +23,7 @@ from ray.rllib.models.torch.torch_action_dist import ( TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta) from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.spaces.simplex import Simplex +from ray.rllib.utils.torch_ops import huber_loss from ray.rllib.utils.typing import LocalOptimizer, TensorType, \ TrainerConfigDict @@ -267,11 +268,11 @@ def actor_critic_loss( td_error = base_td_error critic_loss = [ - 0.5 * torch.mean(torch.pow(q_t_selected_target - q_t_selected, 2.0)) + torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error)) ] if policy.config["twin_q"]: - critic_loss.append(0.5 * torch.mean( - torch.pow(q_t_selected_target - twin_q_t_selected, 2.0))) + critic_loss.append( + torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error))) # Alpha- and actor losses. # Note: In the papers, alpha is used directly, here we take the log. diff --git a/rllib/agents/sac/tests/test_sac.py b/rllib/agents/sac/tests/test_sac.py index 11f4c047d..6a84b19c7 100644 --- a/rllib/agents/sac/tests/test_sac.py +++ b/rllib/agents/sac/tests/test_sac.py @@ -14,7 +14,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchDirichlet from ray.rllib.execution.replay_buffer import LocalReplayBuffer from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.numpy import fc, relu +from ray.rllib.utils.numpy import fc, huber_loss, relu from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.test_utils import check, check_compute_single_action, \ framework_iterator @@ -524,7 +524,8 @@ class TestSAC(unittest.TestCase): base_td_error = np.abs(q_t_selected - q_t_selected_target) td_error = base_td_error critic_loss = [ - 0.5 * np.mean(np.power(q_t_selected_target - q_t_selected, 2.0)) + np.mean(train_batch["weights"] * + huber_loss(q_t_selected_target - q_t_selected)) ] target_entropy = -np.prod((1, )) alpha_loss = -np.mean(log_alpha * (log_pis_t + target_entropy))