This commit is contained in:
Kashif Rasul
2022-06-06 11:44:19 +02:00
parent 159254348b
commit c99572a6d9
+4 -4
View File
@@ -170,10 +170,10 @@ class SwitchTransformerEncoderLayer(nn.Module):
linear = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
self.linear1 = SwitchFeedForward(
capacity_factor,
drop_tokens,
is_scale_prob,
n_experts,
capacity_factor=capacity_factor,
drop_tokens=drop_tokens,
is_scale_prob=is_scale_prob,
n_experts=n_experts,
expert=linear,
d_model=d_model,
dim_feedforward=dim_feedforward,