[RLlib] Issue 11591: SAC loss does not use prioritized-replay (PR) weights in critic loss term. (#12394)

* WIP.

* Fix and LINT.
This commit is contained in:
Sven Mika
2020-11-25 20:28:46 +01:00
committed by GitHub
parent 592c161032
commit b7dbbfbf41
3 changed files with 14 additions and 12 deletions
+7 -7
View File
@@ -12,7 +12,8 @@ import ray
import ray.experimental.tf_utils
from ray.rllib.agents.ddpg.ddpg_tf_policy import ComputeTDErrorMixin, \
TargetNetworkMixin
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
PRIO_WEIGHTS
from ray.rllib.agents.sac.sac_tf_model import SACTFModel
from ray.rllib.agents.sac.sac_torch_model import SACTorchModel
from ray.rllib.evaluation.episode import MultiAgentEpisode
@@ -27,6 +28,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import get_variable, try_import_tf, \
try_import_tfp
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.tf_ops import huber_loss
from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \
TensorType, TrainerConfigDict
@@ -338,13 +340,11 @@ def sac_actor_critic_loss(
td_error = base_td_error
# Calculate one or two critic losses (2 in the twin_q case).
critic_loss = [
0.5 * tf.keras.losses.MSE(
y_true=q_t_selected_target, y_pred=q_t_selected)
]
prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)
critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))]
if policy.config["twin_q"]:
critic_loss.append(0.5 * tf.keras.losses.MSE(
y_true=q_t_selected_target, y_pred=twin_q_t_selected))
critic_loss.append(
tf.reduce_mean(prio_weights * huber_loss(twin_td_error)))
# Alpha- and actor losses.
# Note: In the papers, alpha is used directly, here we take the log.
+4 -3
View File
@@ -23,6 +23,7 @@ from ray.rllib.models.torch.torch_action_dist import (
TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.torch_ops import huber_loss
from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
TrainerConfigDict
@@ -267,11 +268,11 @@ def actor_critic_loss(
td_error = base_td_error
critic_loss = [
0.5 * torch.mean(torch.pow(q_t_selected_target - q_t_selected, 2.0))
torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
]
if policy.config["twin_q"]:
critic_loss.append(0.5 * torch.mean(
torch.pow(q_t_selected_target - twin_q_t_selected, 2.0)))
critic_loss.append(
torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))
# Alpha- and actor losses.
# Note: In the papers, alpha is used directly, here we take the log.
+3 -2
View File
@@ -14,7 +14,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchDirichlet
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import fc, relu
from ray.rllib.utils.numpy import fc, huber_loss, relu
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
framework_iterator
@@ -524,7 +524,8 @@ class TestSAC(unittest.TestCase):
base_td_error = np.abs(q_t_selected - q_t_selected_target)
td_error = base_td_error
critic_loss = [
0.5 * np.mean(np.power(q_t_selected_target - q_t_selected, 2.0))
np.mean(train_batch["weights"] *
huber_loss(q_t_selected_target - q_t_selected))
]
target_entropy = -np.prod((1, ))
alpha_loss = -np.mean(log_alpha * (log_pis_t + target_entropy))