mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:31:19 +08:00
added switch encoder layer
This commit is contained in:
+34
-9
@@ -131,7 +131,7 @@ class SwitchFeedForward(nn.Module):
|
||||
# Return
|
||||
#
|
||||
# * the final output
|
||||
# * number of tokens routed to each expert
|
||||
# * counts: number of tokens routed to each expert
|
||||
# * sum of probabilities for each expert
|
||||
# * number of tokens dropped.
|
||||
# * routing probabilities of the selected experts
|
||||
@@ -140,7 +140,7 @@ class SwitchFeedForward(nn.Module):
|
||||
return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
class SwitchTransformerEncoderLayer(nn.Module):
|
||||
|
||||
__constants__ = ["batch_first", "norm_first"]
|
||||
|
||||
@@ -148,7 +148,7 @@ class TransformerEncoderLayer(nn.Module):
|
||||
self,
|
||||
d_model: int,
|
||||
nhead: int,
|
||||
capacity_factor: int,
|
||||
capacity_factor: float,
|
||||
drop_tokens: bool,
|
||||
is_scale_prob: bool,
|
||||
n_experts: int = 1,
|
||||
@@ -162,12 +162,13 @@ class TransformerEncoderLayer(nn.Module):
|
||||
dtype=None,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
super(SwitchTransformerEncoderLayer, self).__init__()
|
||||
self.self_attn = nn.MultiheadAttention(
|
||||
d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
|
||||
)
|
||||
# Implementation of Feedforward model
|
||||
linear = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
|
||||
|
||||
self.linear1 = SwitchFeedForward(
|
||||
capacity_factor,
|
||||
drop_tokens,
|
||||
@@ -195,7 +196,7 @@ class TransformerEncoderLayer(nn.Module):
|
||||
def __setstate__(self, state):
|
||||
if "activation" not in state:
|
||||
state["activation"] = F.relu
|
||||
super(TransformerEncoderLayer, self).__setstate__(state)
|
||||
super(SwitchTransformerEncoderLayer, self).__setstate__(state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -233,7 +234,8 @@ class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
# feed forward block
|
||||
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
||||
x, _, _, _, _ = self.linear1(x)
|
||||
x = self.linear2(self.dropout(self.activation(x)))
|
||||
return self.dropout2(x)
|
||||
|
||||
|
||||
@@ -253,8 +255,13 @@ class SwitchTransformerModel(nn.Module):
|
||||
num_encoder_layers: int,
|
||||
num_decoder_layers: int,
|
||||
dim_feedforward: int,
|
||||
capacity_factor: float,
|
||||
activation: str = "gelu",
|
||||
dropout: float = 0.1,
|
||||
layer_norm_eps: float = 1e-5,
|
||||
drop_tokens: bool = False,
|
||||
is_scale_prob: bool = True,
|
||||
n_experts: int = 1,
|
||||
# univariate input
|
||||
input_size: int = 1,
|
||||
embedding_dimension: Optional[List[int]] = None,
|
||||
@@ -296,11 +303,29 @@ class SwitchTransformerModel(nn.Module):
|
||||
self.distr_output = distr_output
|
||||
self.param_proj = distr_output.get_args_proj(d_model)
|
||||
|
||||
# transformer enc-decoder and mask initializer
|
||||
# switch-transformer enc
|
||||
switch_encoder_layer = SwitchTransformerEncoderLayer(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
capacity_factor=capacity_factor,
|
||||
drop_tokens=drop_tokens,
|
||||
is_scale_prob=is_scale_prob,
|
||||
n_experts=n_experts,
|
||||
dim_feedforward=dim_feedforward,
|
||||
dropout=dropout,
|
||||
activation=activation,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
)
|
||||
switch_encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
||||
switch_encoder = nn.TransformerEncoder(
|
||||
switch_encoder_layer, num_encoder_layers, switch_encoder_norm
|
||||
)
|
||||
|
||||
# vanilla decoder and mask initializer
|
||||
self.transformer = nn.Transformer(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
custom_encoder=switch_encoder,
|
||||
num_decoder_layers=num_decoder_layers,
|
||||
dim_feedforward=dim_feedforward,
|
||||
dropout=dropout,
|
||||
@@ -311,7 +336,7 @@ class SwitchTransformerModel(nn.Module):
|
||||
# causal decoder tgt mask
|
||||
self.register_buffer(
|
||||
"tgt_mask",
|
||||
self.transformer.generate_square_subsequent_mask(prediction_length),
|
||||
nn.Transformer.generate_square_subsequent_mask(prediction_length),
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user