[RLlib] Assert LongTensor in SAC Discrete PyTorch (#11245)

This commit is contained in:
Julius Frost
2020-10-12 16:47:21 -04:00
committed by GitHub
parent 580820a530
commit 7dcfd258cd
+2 -1
View File
@@ -192,7 +192,8 @@ def actor_critic_loss(
# Actually selected Q-values (from the actions batch).
one_hot = F.one_hot(
- train_batch[SampleBatch.ACTIONS], num_classes=q_t.size()[-1])
+ train_batch[SampleBatch.ACTIONS].long(),
+ num_classes=q_t.size()[-1])
q_t_selected = torch.sum(q_t * one_hot, dim=-1)
if policy.config["twin_q"]:
twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)