fix final_output tensor shape

This commit is contained in:
Kashif Rasul
2022-06-06 14:11:51 +02:00
parent 0b0ecc94f9
commit d11578a235
+1 -1
View File
@@ -75,7 +75,7 @@ class SwitchFeedForward(nn.Module):
]
# Initialize an empty tensor to store outputs
-final_output = x.new_zeros((batch_size, seq_len, self.dim_feedforward))
+final_output = x.new_zeros((batch_size * seq_len, self.dim_feedforward))
# Capacity of each expert.
# $$\mathrm{expert\;capacity} =