diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index 27199e057..42a796a2e 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -192,7 +192,8 @@ def actor_critic_loss( # Actually selected Q-values (from the actions batch). one_hot = F.one_hot( - train_batch[SampleBatch.ACTIONS], num_classes=q_t.size()[-1]) + train_batch[SampleBatch.ACTIONS].long(), + num_classes=q_t.size()[-1]) q_t_selected = torch.sum(q_t * one_hot, dim=-1) if policy.config["twin_q"]: twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)