diff --git a/rllib/agents/dqn/dqn_torch_model.py b/rllib/agents/dqn/dqn_torch_model.py index 6fcf6f077..2764f0bb7 100644 --- a/rllib/agents/dqn/dqn_torch_model.py +++ b/rllib/agents/dqn/dqn_torch_model.py @@ -144,7 +144,7 @@ class DQNTorchModel(TorchModelV2, nn.Module): support_logits_per_action = torch.reshape( action_scores, shape=(-1, self.action_space.n, self.num_atoms)) support_prob_per_action = nn.functional.softmax( - support_logits_per_action) + support_logits_per_action, dim=-1) action_scores = torch.sum(z * support_prob_per_action, dim=-1) logits = support_logits_per_action probs = support_prob_per_action