mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:44:00 +08:00
fix final_output tensor shape
This commit is contained in:
+1
-1
@@ -75,7 +75,7 @@ class SwitchFeedForward(nn.Module):
|
||||
]
|
||||
|
||||
# Initialize an empty tensor to store outputs
|
||||
- final_output = x.new_zeros((batch_size, seq_len, self.dim_feedforward))
|
||||
+ final_output = x.new_zeros((batch_size * seq_len, self.dim_feedforward))
|
||||
|
||||
# Capacity of each expert.
|
||||
# $$\mathrm{expert\;capacity} =
|
||||
|
||||
Reference in New Issue
Block a user