From 0dae50b5ebd09c1a7427a7edcb8ff60301b79fee Mon Sep 17 00:00:00 2001 From: Olli Huotari Date: Wed, 26 Aug 2020 09:10:20 +0300 Subject: [PATCH] Fixed num_atoms>1 in pytorch (#10330) --- rllib/agents/dqn/dqn_torch_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/agents/dqn/dqn_torch_model.py b/rllib/agents/dqn/dqn_torch_model.py index 6fcf6f077..2764f0bb7 100644 --- a/rllib/agents/dqn/dqn_torch_model.py +++ b/rllib/agents/dqn/dqn_torch_model.py @@ -144,7 +144,7 @@ class DQNTorchModel(TorchModelV2, nn.Module): support_logits_per_action = torch.reshape( action_scores, shape=(-1, self.action_space.n, self.num_atoms)) support_prob_per_action = nn.functional.softmax( - support_logits_per_action) + support_logits_per_action, dim=-1) action_scores = torch.sum(z * support_prob_per_action, dim=-1) logits = support_logits_per_action probs = support_prob_per_action