Fixed num_atoms>1 in pytorch (#10330)

This commit is contained in:
Olli Huotari
2020-08-26 09:10:20 +03:00
committed by GitHub
parent 8c0503ddd3
commit 0dae50b5eb
+1 -1
View File
@@ -144,7 +144,7 @@ class DQNTorchModel(TorchModelV2, nn.Module):
support_logits_per_action = torch.reshape(
action_scores, shape=(-1, self.action_space.n, self.num_atoms))
support_prob_per_action = nn.functional.softmax(
support_logits_per_action)
support_logits_per_action, dim=-1)
action_scores = torch.sum(z * support_prob_per_action, dim=-1)
logits = support_logits_per_action
probs = support_prob_per_action