diff --git a/switch/module.py b/switch/module.py index 036d66a..ec42d30 100644 --- a/switch/module.py +++ b/switch/module.py @@ -170,10 +170,10 @@ class SwitchTransformerEncoderLayer(nn.Module): linear = nn.Linear(d_model, dim_feedforward, **factory_kwargs) self.linear1 = SwitchFeedForward( - capacity_factor, - drop_tokens, - is_scale_prob, - n_experts, + capacity_factor=capacity_factor, + drop_tokens=drop_tokens, + is_scale_prob=is_scale_prob, + n_experts=n_experts, expert=linear, d_model=d_model, dim_feedforward=dim_feedforward,