import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat


class DampingLayer(nn.Module):
    """Extend a single growth step over ``pred_len`` steps with learned damping.

    Each of the ``nhead`` heads owns one learnable scalar. After a sigmoid it
    becomes a damping factor ``phi`` in (0, 1); horizon step ``k`` (1-based)
    is scaled by ``phi**1 + phi**2 + ... + phi**k``.

    Args:
        pred_len: Number of horizon steps to produce.
        nhead: Number of heads the feature dimension is split into; the
            feature dimension of the input must be divisible by it.
        dropout: Dropout probability applied to the repeated input.
        output_attention: When true, ``forward`` also returns the damping
            weights; otherwise it returns ``None`` in their place.
    """

    def __init__(self, pred_len, nhead, dropout=0.1, output_attention=False):
        super().__init__()
        self.pred_len = pred_len
        self.nhead = nhead
        self.output_attention = output_attention
        # Unconstrained parameter; squashed into (0, 1) by `damping_factor`.
        self._damping_factor = nn.Parameter(torch.randn(1, nhead))
        self.dropout = nn.Dropout(dropout)

    @property
    def damping_factor(self):
        """Per-head damping factor of shape ``(1, nhead)``, each in (0, 1)."""
        return torch.sigmoid(self._damping_factor)

    def forward(self, x):
        """Repeat ``x`` along the horizon and apply cumulative damping.

        Args:
            x: Tensor of shape ``(b, 1, d)``.

        Returns:
            A pair ``(out, weights)``. ``out`` has shape ``(b, pred_len, d)``.
            ``weights`` has shape ``(pred_len, nhead)`` when
            ``output_attention`` is set, else it is ``None``.
        """
        x = repeat(x, 'b 1 d -> b t d', t=self.pred_len)
        b, t, d = x.shape

        # Exponents 1..pred_len as a column, created directly on the
        # parameter's device so no host-to-device copy happens per call.
        powers = torch.arange(
            1, self.pred_len + 1, device=self._damping_factor.device
        ).view(self.pred_len, 1)

        # (1, nhead) ** (pred_len, 1) broadcasts to (pred_len, nhead); the
        # running sum over the horizon gives phi + phi^2 + ... + phi^k.
        damping_factors = (self.damping_factor ** powers).cumsum(dim=0)

        # Split features into heads so each head gets its own damping weight.
        x = x.view(b, t, self.nhead, -1)
        x = self.dropout(x) * damping_factors.unsqueeze(-1)
        x = x.view(b, t, d)

        return x, (damping_factors if self.output_attention else None)


class DecoderLayer(nn.Module):
    """Produce horizon representations from one layer's growth and season.

    Args:
        d_model: Feature width; stored for use by ``Decoder``.
        nhead: Number of heads passed to the damping layer.
        c_out: Output width; stored for use by ``Decoder``.
        pred_len: Number of horizon steps.
        dropout: Dropout probability (used in the damping layer and again on
            its output).
        output_attention: When true, ``forward`` also returns the damping
            weights.
    """

    def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1,
                 output_attention=False):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.c_out = c_out
        self.pred_len = pred_len
        self.output_attention = output_attention

        self.growth_damping = DampingLayer(
            pred_len, nhead, dropout=dropout, output_attention=output_attention
        )
        self.dropout1 = nn.Dropout(dropout)

    def forward(self, growth, season):
        """Build the growth and seasonal horizons.

        Args:
            growth: Tensor whose last entry along dim 1 is damped and
                extended; presumably ``(b, t, d_model)`` -- confirm with the
                caller.
            season: Tensor whose last ``pred_len`` entries along dim 1 are
                used as the seasonal horizon.

        Returns:
            ``(growth_horizon, seasonal_horizon, damping)`` where ``damping``
            is ``None`` unless ``output_attention`` is set.
        """
        # Only the most recent growth step is extrapolated.
        growth_horizon, growth_damping = self.growth_damping(growth[:, -1:])
        growth_horizon = self.dropout1(growth_horizon)

        seasonal_horizon = season[:, -self.pred_len:]

        return (
            growth_horizon,
            seasonal_horizon,
            growth_damping if self.output_attention else None,
        )


class Decoder(nn.Module):
    """Sum per-layer horizons and project them to ``c_out`` features.

    Args:
        layers: Non-empty sequence of layer modules. The first one supplies
            ``d_model``, ``c_out``, ``pred_len`` and ``nhead``.
    """

    def __init__(self, layers):
        super().__init__()
        first = layers[0]
        self.d_model = first.d_model
        self.c_out = first.c_out
        self.pred_len = first.pred_len
        self.nhead = first.nhead

        self.layers = nn.ModuleList(layers)
        # One projection shared by the growth and seasonal branches.
        self.pred = nn.Linear(self.d_model, self.c_out)

    def forward(self, growths, seasons):
        """Run every layer on its own growth/season pair and combine them.

        Args:
            growths: Indexable collection; entry ``i`` is fed to layer ``i``.
            seasons: Indexable collection; entry ``i`` is fed to layer ``i``.

        Returns:
            ``(growth_pred, season_pred, growth_dampings)``: the projected
            sum of growth horizons, the projected sum of seasonal horizons,
            and a list with each layer's damping output (``None`` entries
            when a layer does not report them).
        """
        growth_repr = []
        season_repr = []
        growth_dampings = []

        # Indexing (rather than zip) keeps the IndexError when fewer inputs
        # than layers are supplied.
        for idx, layer in enumerate(self.layers):
            growth_horizon, season_horizon, growth_damping = layer(
                growths[idx], seasons[idx]
            )
            growth_repr.append(growth_horizon)
            season_repr.append(season_horizon)
            growth_dampings.append(growth_damping)

        growth_repr = sum(growth_repr)
        season_repr = sum(season_repr)
        return self.pred(growth_repr), self.pred(season_repr), growth_dampings