added switch encoder layer

This commit is contained in:
Kashif Rasul
2022-06-06 11:04:43 +02:00
parent 40981abdb4
commit a429cef13d
+34 -9
View File
@@ -131,7 +131,7 @@ class SwitchFeedForward(nn.Module):
# Return
#
# * the final output
# * number of tokens routed to each expert
# * counts: number of tokens routed to each expert
# * sum of probabilities for each expert
# * number of tokens dropped.
# * routing probabilities of the selected experts
@@ -140,7 +140,7 @@ class SwitchFeedForward(nn.Module):
return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max
class TransformerEncoderLayer(nn.Module):
class SwitchTransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
@@ -148,7 +148,7 @@ class TransformerEncoderLayer(nn.Module):
self,
d_model: int,
nhead: int,
capacity_factor: int,
capacity_factor: float,
drop_tokens: bool,
is_scale_prob: bool,
n_experts: int = 1,
@@ -162,12 +162,13 @@ class TransformerEncoderLayer(nn.Module):
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
super(SwitchTransformerEncoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
)
# Implementation of Feedforward model
linear = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
self.linear1 = SwitchFeedForward(
capacity_factor,
drop_tokens,
@@ -195,7 +196,7 @@ class TransformerEncoderLayer(nn.Module):
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = F.relu
super(TransformerEncoderLayer, self).__setstate__(state)
super(SwitchTransformerEncoderLayer, self).__setstate__(state)
def forward(
self,
@@ -233,7 +234,8 @@ class TransformerEncoderLayer(nn.Module):
# feed forward block
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x, _, _, _, _ = self.linear1(x)
x = self.linear2(self.dropout(self.activation(x)))
return self.dropout2(x)
@@ -253,8 +255,13 @@ class SwitchTransformerModel(nn.Module):
num_encoder_layers: int,
num_decoder_layers: int,
dim_feedforward: int,
capacity_factor: float,
activation: str = "gelu",
dropout: float = 0.1,
layer_norm_eps: float = 1e-5,
drop_tokens: bool = False,
is_scale_prob: bool = True,
n_experts: int = 1,
# univariate input
input_size: int = 1,
embedding_dimension: Optional[List[int]] = None,
@@ -296,11 +303,29 @@ class SwitchTransformerModel(nn.Module):
self.distr_output = distr_output
self.param_proj = distr_output.get_args_proj(d_model)
# transformer enc-decoder and mask initializer
# switch-transformer enc
switch_encoder_layer = SwitchTransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
capacity_factor=capacity_factor,
drop_tokens=drop_tokens,
is_scale_prob=is_scale_prob,
n_experts=n_experts,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
)
switch_encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
switch_encoder = nn.TransformerEncoder(
switch_encoder_layer, num_encoder_layers, switch_encoder_norm
)
# vanilla decoder and mask initializer
self.transformer = nn.Transformer(
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
custom_encoder=switch_encoder,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
@@ -311,7 +336,7 @@ class SwitchTransformerModel(nn.Module):
# causal decoder tgt mask
self.register_buffer(
"tgt_mask",
self.transformer.generate_square_subsequent_mask(prediction_length),
nn.Transformer.generate_square_subsequent_mask(prediction_length),
)
@property