From 7dcfd258cd3e8760d71b76819d8d90a88eb03319 Mon Sep 17 00:00:00 2001 From: Julius Frost <33183774+juliusfrost@users.noreply.github.com> Date: Mon, 12 Oct 2020 16:47:21 -0400 Subject: [PATCH] [RLlib] Assert LongTensor in SAC Discrete PyTorch (#11245) --- rllib/agents/sac/sac_torch_policy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index 27199e057..42a796a2e 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -192,7 +192,8 @@ def actor_critic_loss( # Actually selected Q-values (from the actions batch). one_hot = F.one_hot( - train_batch[SampleBatch.ACTIONS], num_classes=q_t.size()[-1]) + train_batch[SampleBatch.ACTIONS].long(), + num_classes=q_t.size()[-1]) q_t_selected = torch.sum(q_t * one_hot, dim=-1) if policy.config["twin_q"]: twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)