mirror of
https://github.com/wassname/ETSformer.git
synced 2026-06-27 17:46:35 +08:00
85 lines
2.8 KiB
Python
85 lines
2.8 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from einops import rearrange, reduce, repeat
|
|
|
|
|
|
class DampingLayer(nn.Module):
    """Extrapolate a single growth step over the forecast horizon with per-head damping.

    For horizon step ``k`` (1-based) and head ``h`` the repeated input is scaled by
    ``sum_{i=1..k} phi_h ** i`` where ``phi_h = sigmoid(_damping_factor[0, h])``.

    Args:
        pred_len: Number of horizon steps to produce.
        nhead: Number of heads; the feature dimension of the input is split into
            ``nhead`` equal groups, each damped by its own factor.
        dropout: Dropout probability applied to the repeated input before damping.
        output_attention: If True, ``forward`` also returns the cumulative damping
            factors instead of ``None``.
    """

    def __init__(self, pred_len, nhead, dropout=0.1, output_attention=False):
        super().__init__()
        self.pred_len = pred_len
        self.nhead = nhead
        self.output_attention = output_attention
        # Unconstrained logits; mapped into (0, 1) by the ``damping_factor`` property.
        self._damping_factor = nn.Parameter(torch.randn(1, nhead))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Repeat ``x`` over ``pred_len`` steps and apply cumulative damping.

        Args:
            x: Tensor matching the einops pattern ``'b 1 d'``; ``d`` must be
                divisible by ``nhead``.

        Returns:
            Tuple ``(out, damping)``. ``out`` has shape ``(b, pred_len, d)``.
            ``damping`` is the ``(pred_len, nhead)`` tensor of cumulative damping
            factors when ``output_attention`` is set, otherwise ``None``.
        """
        x = repeat(x, 'b 1 d -> b t d', t=self.pred_len)
        b, t, d = x.shape

        phi = self.damping_factor  # (1, nhead)
        # Exponents 1..pred_len, created directly on the parameter's device rather than
        # allocated on the CPU and copied over on every forward pass.
        powers = torch.arange(1, self.pred_len + 1, device=phi.device)
        powers = powers.view(self.pred_len, 1)
        # phi ** k for each step k, then the running sum yields sum_{i<=k} phi ** i.
        damping_factors = (phi ** powers).cumsum(dim=0)  # (pred_len, nhead)

        # Split features into heads so each head is scaled by its own damping curve.
        # reshape (not view): the output of ``repeat`` may be an expanded,
        # non-contiguous tensor.
        x = x.reshape(b, t, self.nhead, -1)
        x = self.dropout(x) * damping_factors.unsqueeze(-1)
        x = x.view(b, t, d)
        if self.output_attention:
            return x, damping_factors
        return x, None

    @property
    def damping_factor(self):
        """Per-head damping factors ``sigmoid(_damping_factor)``, shape ``(1, nhead)``."""
        return torch.sigmoid(self._damping_factor)
|
|
|
|
|
|
class DecoderLayer(nn.Module):
    """Build growth and seasonal horizons from one layer's representations.

    The growth path takes the last time step of ``growth`` (keeping a length-1
    time axis), extrapolates it over ``pred_len`` steps with a ``DampingLayer``
    and applies dropout. The seasonal path simply keeps the last ``pred_len``
    steps of ``season`` along dimension 1.

    Args:
        d_model: Model width, stored for use by the enclosing decoder.
        nhead: Number of damping heads.
        c_out: Output channel count, stored for use by the enclosing decoder.
        pred_len: Forecast horizon length.
        dropout: Dropout probability for the damping layer and the growth output.
        output_attention: If True, ``forward`` returns the damping factors as its
            third element instead of ``None``.
    """

    def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1, output_attention=False):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.c_out = c_out
        self.pred_len = pred_len
        self.output_attention = output_attention

        self.growth_damping = DampingLayer(
            pred_len=pred_len,
            nhead=nhead,
            dropout=dropout,
            output_attention=output_attention,
        )
        self.dropout1 = nn.Dropout(dropout)

    def forward(self, growth, season):
        """Return ``(growth_horizon, seasonal_horizon, damping_or_None)``."""
        last_step = growth[:, -1:]  # slice (not index) so the time axis survives
        damped, damping = self.growth_damping(last_step)
        growth_horizon = self.dropout1(damped)

        seasonal_horizon = season[:, -self.pred_len:]
        reported = damping if self.output_attention else None
        return growth_horizon, seasonal_horizon, reported
|
|
|
|
|
|
class Decoder(nn.Module):
    """Combine per-layer decoder outputs and project them to ``c_out`` channels.

    Layer ``i`` is called on ``growths[i]`` and ``seasons[i]``. The growth and
    seasonal horizons are each summed across layers, and both sums go through
    the same linear projection.

    Args:
        layers: Non-empty sequence of decoder layers; ``d_model``, ``c_out``,
            ``pred_len`` and ``nhead`` are read from the first one.
    """

    def __init__(self, layers):
        super().__init__()
        first = layers[0]
        self.d_model, self.c_out = first.d_model, first.c_out
        self.pred_len, self.nhead = first.pred_len, first.nhead

        self.layers = nn.ModuleList(layers)
        # One projection shared by the growth and seasonal sums.
        self.pred = nn.Linear(self.d_model, self.c_out)

    def forward(self, growths, seasons):
        """Return ``(projected_growth, projected_season, per_layer_dampings)``.

        ``per_layer_dampings`` holds each layer's third output in layer order.
        """
        growth_total = 0
        season_total = 0
        dampings = []
        for i, layer in enumerate(self.layers):
            g_horizon, s_horizon, damping = layer(growths[i], seasons[i])
            growth_total = growth_total + g_horizon
            season_total = season_total + s_horizon
            dampings.append(damping)
        return self.pred(growth_total), self.pred(season_total), dampings
|