mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:46:28 +08:00
fix args
This commit is contained in:
+4
-4
@@ -170,10 +170,10 @@ class SwitchTransformerEncoderLayer(nn.Module):
|
||||
linear = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
|
||||
|
||||
self.linear1 = SwitchFeedForward(
|
||||
capacity_factor,
|
||||
drop_tokens,
|
||||
is_scale_prob,
|
||||
n_experts,
|
||||
capacity_factor=capacity_factor,
|
||||
drop_tokens=drop_tokens,
|
||||
is_scale_prob=is_scale_prob,
|
||||
n_experts=n_experts,
|
||||
expert=linear,
|
||||
d_model=d_model,
|
||||
dim_feedforward=dim_feedforward,
|
||||
|
||||
Reference in New Issue
Block a user