mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:08:13 +08:00
[RLlib] Issue 11591: SAC loss does not use PR-weights in critic loss term. (#12394)
* WIP. * Fix and LINT.
This commit is contained in:
@@ -12,7 +12,8 @@ import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.agents.ddpg.ddpg_tf_policy import ComputeTDErrorMixin, \
|
||||
TargetNetworkMixin
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
|
||||
PRIO_WEIGHTS
|
||||
from ray.rllib.agents.sac.sac_tf_model import SACTFModel
|
||||
from ray.rllib.agents.sac.sac_torch_model import SACTorchModel
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
@@ -27,6 +28,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.framework import get_variable, try_import_tf, \
|
||||
try_import_tfp
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.tf_ops import huber_loss
|
||||
from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \
|
||||
TensorType, TrainerConfigDict
|
||||
|
||||
@@ -338,13 +340,11 @@ def sac_actor_critic_loss(
|
||||
td_error = base_td_error
|
||||
|
||||
# Calculate one or two critic losses (2 in the twin_q case).
|
||||
critic_loss = [
|
||||
0.5 * tf.keras.losses.MSE(
|
||||
y_true=q_t_selected_target, y_pred=q_t_selected)
|
||||
]
|
||||
prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)
|
||||
critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))]
|
||||
if policy.config["twin_q"]:
|
||||
critic_loss.append(0.5 * tf.keras.losses.MSE(
|
||||
y_true=q_t_selected_target, y_pred=twin_q_t_selected))
|
||||
critic_loss.append(
|
||||
tf.reduce_mean(prio_weights * huber_loss(twin_td_error)))
|
||||
|
||||
# Alpha- and actor losses.
|
||||
# Note: In the papers, alpha is used directly, here we take the log.
|
||||
|
||||
@@ -23,6 +23,7 @@ from ray.rllib.models.torch.torch_action_dist import (
|
||||
TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta)
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.torch_ops import huber_loss
|
||||
from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
|
||||
TrainerConfigDict
|
||||
|
||||
@@ -267,11 +268,11 @@ def actor_critic_loss(
|
||||
td_error = base_td_error
|
||||
|
||||
critic_loss = [
|
||||
0.5 * torch.mean(torch.pow(q_t_selected_target - q_t_selected, 2.0))
|
||||
torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
|
||||
]
|
||||
if policy.config["twin_q"]:
|
||||
critic_loss.append(0.5 * torch.mean(
|
||||
torch.pow(q_t_selected_target - twin_q_t_selected, 2.0)))
|
||||
critic_loss.append(
|
||||
torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))
|
||||
|
||||
# Alpha- and actor losses.
|
||||
# Note: In the papers, alpha is used directly, here we take the log.
|
||||
|
||||
@@ -14,7 +14,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchDirichlet
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.numpy import fc, relu
|
||||
from ray.rllib.utils.numpy import fc, huber_loss, relu
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
|
||||
framework_iterator
|
||||
@@ -524,7 +524,8 @@ class TestSAC(unittest.TestCase):
|
||||
base_td_error = np.abs(q_t_selected - q_t_selected_target)
|
||||
td_error = base_td_error
|
||||
critic_loss = [
|
||||
0.5 * np.mean(np.power(q_t_selected_target - q_t_selected, 2.0))
|
||||
np.mean(train_batch["weights"] *
|
||||
huber_loss(q_t_selected_target - q_t_selected))
|
||||
]
|
||||
target_entropy = -np.prod((1, ))
|
||||
alpha_loss = -np.mean(log_alpha * (log_pis_t + target_entropy))
|
||||
|
||||
Reference in New Issue
Block a user