mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 05:34:49 +08:00
[RLlib] Assert LongTensor in SAC Discrete PyTorch (#11245)
This commit is contained in:
@@ -192,7 +192,8 @@ def actor_critic_loss(
|
||||
|
||||
# Actually selected Q-values (from the actions batch).
|
||||
one_hot = F.one_hot(
|
||||
-train_batch[SampleBatch.ACTIONS], num_classes=q_t.size()[-1])
|
||||
+train_batch[SampleBatch.ACTIONS].long(),
|
||||
+num_classes=q_t.size()[-1])
|
||||
q_t_selected = torch.sum(q_t * one_hot, dim=-1)
|
||||
if policy.config["twin_q"]:
|
||||
twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user