mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 17:55:15 +08:00
Fixed num_atoms>1 in pytorch (#10330)
This commit is contained in:
@@ -144,7 +144,7 @@ class DQNTorchModel(TorchModelV2, nn.Module):
|
||||
support_logits_per_action = torch.reshape(
|
||||
action_scores, shape=(-1, self.action_space.n, self.num_atoms))
|
||||
support_prob_per_action = nn.functional.softmax(
|
||||
- support_logits_per_action)
|
||||
+ support_logits_per_action, dim=-1)
|
||||
action_scores = torch.sum(z * support_prob_per_action, dim=-1)
|
||||
logits = support_logits_per_action
|
||||
probs = support_prob_per_action
|
||||
|
||||
Reference in New Issue
Block a user