mirror of
https://github.com/wassname/ETSformer.git
synced 2026-06-27 19:13:22 +08:00
return attention weights
This commit is contained in:
@@ -5,10 +5,11 @@ from einops import rearrange, reduce, repeat
|
||||
|
||||
class DampingLayer(nn.Module):
|
||||
|
||||
def __init__(self, pred_len, nhead, dropout=0.1):
|
||||
def __init__(self, pred_len, nhead, dropout=0.1, output_attention=False):
|
||||
super().__init__()
|
||||
self.pred_len = pred_len
|
||||
self.nhead = nhead
|
||||
self.output_attention = output_attention
|
||||
self._damping_factor = nn.Parameter(torch.randn(1, nhead))
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -22,7 +23,10 @@ class DampingLayer(nn.Module):
|
||||
damping_factors = damping_factors.cumsum(dim=0)
|
||||
x = x.view(b, t, self.nhead, -1)
|
||||
x = self.dropout(x) * damping_factors.unsqueeze(-1)
|
||||
return x.view(b, t, d)
|
||||
x = x.view(b, t, d)
|
||||
if self.output_attention:
|
||||
return x, damping_factors
|
||||
return x, None
|
||||
|
||||
@property
|
||||
def damping_factor(self):
|
||||
@@ -31,22 +35,26 @@ class DampingLayer(nn.Module):
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1):
|
||||
def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1, output_attention=False):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
self.c_out = c_out
|
||||
self.pred_len = pred_len
|
||||
self.output_attention = output_attention
|
||||
|
||||
self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout)
|
||||
self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout, output_attention=output_attention)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, growth, season):
|
||||
growth_horizon = self.growth_damping(growth[:, -1:])
|
||||
growth_horizon, growth_damping = self.growth_damping(growth[:, -1:])
|
||||
growth_horizon = self.dropout1(growth_horizon)
|
||||
|
||||
seasonal_horizon = season[:, -self.pred_len:]
|
||||
return growth_horizon, seasonal_horizon
|
||||
|
||||
if self.output_attention:
|
||||
return growth_horizon, seasonal_horizon, growth_damping
|
||||
return growth_horizon, seasonal_horizon, None
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
@@ -64,11 +72,13 @@ class Decoder(nn.Module):
|
||||
def forward(self, growths, seasons):
|
||||
growth_repr = []
|
||||
season_repr = []
|
||||
growth_dampings = []
|
||||
|
||||
for idx, layer in enumerate(self.layers):
|
||||
growth_horizon, season_horizon = layer(growths[idx], seasons[idx])
|
||||
growth_horizon, season_horizon, growth_damping = layer(growths[idx], seasons[idx])
|
||||
growth_repr.append(growth_horizon)
|
||||
season_repr.append(season_horizon)
|
||||
growth_dampings.append(growth_damping)
|
||||
growth_repr = sum(growth_repr)
|
||||
season_repr = sum(season_repr)
|
||||
return self.pred(growth_repr), self.pred(season_repr)
|
||||
return self.pred(growth_repr), self.pred(season_repr), growth_dampings
|
||||
|
||||
+66
-16
@@ -3,6 +3,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.fft as fft
|
||||
|
||||
import numpy as np
|
||||
from einops import rearrange, reduce, repeat
|
||||
import math, random
|
||||
|
||||
@@ -12,11 +13,12 @@ from .exponential_smoothing import ExponentialSmoothing
|
||||
|
||||
class GrowthLayer(nn.Module):
|
||||
|
||||
def __init__(self, d_model, nhead, d_head=None, dropout=0.1):
|
||||
def __init__(self, d_model, nhead, d_head=None, dropout=0.1, output_attention=False):
|
||||
super().__init__()
|
||||
self.d_head = d_head or (d_model // nhead)
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
self.output_attention = output_attention
|
||||
|
||||
self.z0 = nn.Parameter(torch.randn(self.nhead, self.d_head))
|
||||
self.in_proj = nn.Linear(self.d_model, self.d_head * self.nhead)
|
||||
@@ -37,20 +39,29 @@ class GrowthLayer(nn.Module):
|
||||
out = self.es(values)
|
||||
out = torch.cat([repeat(self.es.v0, '1 1 h d -> b 1 h d', b=b), out], dim=1)
|
||||
out = rearrange(out, 'b t h d -> b t (h d)')
|
||||
return self.out_proj(out)
|
||||
out = self.out_proj(out)
|
||||
|
||||
if self.output_attention:
|
||||
return out, self.es.get_exponential_weight(t)[1]
|
||||
return out, None
|
||||
|
||||
|
||||
class FourierLayer(nn.Module):
|
||||
|
||||
def __init__(self, d_model, pred_len, k=None, low_freq=1):
|
||||
def __init__(self, d_model, pred_len, k=None, low_freq=1, output_attention=False):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.pred_len = pred_len
|
||||
self.k = k
|
||||
self.low_freq = low_freq
|
||||
self.output_attention = output_attention
|
||||
|
||||
def forward(self, x):
|
||||
"""x: (b, t, d)"""
|
||||
|
||||
if self.output_attention:
|
||||
return self.dft_forward(x)
|
||||
|
||||
b, t, d = x.shape
|
||||
x_freq = fft.rfft(x, dim=1)
|
||||
|
||||
@@ -65,7 +76,7 @@ class FourierLayer(nn.Module):
|
||||
f = repeat(f, 'f -> b f d', b=x_freq.size(0), d=x_freq.size(2))
|
||||
f = rearrange(f[index_tuple], 'b f d -> b f () d').to(x_freq.device)
|
||||
|
||||
return self.extrapolate(x_freq, f, t)
|
||||
return self.extrapolate(x_freq, f, t), None
|
||||
|
||||
def extrapolate(self, x_freq, f, t):
|
||||
x_freq = torch.cat([x_freq, x_freq.conj()], dim=1)
|
||||
@@ -88,6 +99,41 @@ class FourierLayer(nn.Module):
|
||||
|
||||
return x_freq, index_tuple
|
||||
|
||||
def dft_forward(self, x):
|
||||
T = x.size(1)
|
||||
|
||||
dft_mat = fft.fft(torch.eye(T))
|
||||
i, j = torch.meshgrid(torch.arange(self.pred_len + T), torch.arange(T))
|
||||
omega = np.exp(2 * math.pi * 1j / T)
|
||||
idft_mat = (np.power(omega, i * j) / T).cfloat()
|
||||
|
||||
x_freq = torch.einsum('ft,btd->bfd', [dft_mat, x.cfloat()])
|
||||
|
||||
if T % 2 == 0:
|
||||
x_freq = x_freq[:, self.low_freq:T // 2]
|
||||
else:
|
||||
x_freq = x_freq[:, self.low_freq:T // 2 + 1]
|
||||
|
||||
_, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True)
|
||||
indices = indices + self.low_freq
|
||||
indices = torch.cat([indices, -indices], dim=1)
|
||||
|
||||
dft_mat = repeat(dft_mat, 'f t -> b f t d', b=x.shape[0], d=x.shape[-1])
|
||||
idft_mat = repeat(idft_mat, 't f -> b t f d', b=x.shape[0], d=x.shape[-1])
|
||||
|
||||
mesh_a, mesh_b = torch.meshgrid(torch.arange(x.size(0)), torch.arange(x.size(2)))
|
||||
|
||||
dft_mask = torch.zeros_like(dft_mat)
|
||||
dft_mask[mesh_a, indices, :, mesh_b] = 1
|
||||
dft_mat = dft_mat * dft_mask
|
||||
|
||||
idft_mask = torch.zeros_like(idft_mat)
|
||||
idft_mask[mesh_a, :, indices, mesh_b] = 1
|
||||
idft_mat = idft_mat * idft_mask
|
||||
|
||||
attn = torch.einsum('bofd,bftd->botd', [idft_mat, dft_mat]).real
|
||||
return torch.einsum('botd,btd->bod', [attn, x]), rearrange(attn, 'b o t d -> b d o t')
|
||||
|
||||
|
||||
class LevelLayer(nn.Module):
|
||||
|
||||
@@ -114,7 +160,7 @@ class LevelLayer(nn.Module):
|
||||
class EncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, d_model, nhead, c_out, seq_len, pred_len, k, dim_feedforward=None, dropout=0.1,
|
||||
activation='sigmoid', layer_norm_eps=1e-5):
|
||||
activation='sigmoid', layer_norm_eps=1e-5, output_attention=False):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
@@ -124,8 +170,8 @@ class EncoderLayer(nn.Module):
|
||||
dim_feedforward = dim_feedforward or 4 * d_model
|
||||
self.dim_feedforward = dim_feedforward
|
||||
|
||||
self.growth_layer = GrowthLayer(d_model, nhead, dropout=dropout)
|
||||
self.seasonal_layer = FourierLayer(d_model, pred_len, k=k)
|
||||
self.growth_layer = GrowthLayer(d_model, nhead, dropout=dropout, output_attention=output_attention)
|
||||
self.seasonal_layer = FourierLayer(d_model, pred_len, k=k, output_attention=output_attention)
|
||||
self.level_layer = LevelLayer(d_model, c_out, dropout=dropout)
|
||||
|
||||
# Implementation of Feedforward model
|
||||
@@ -137,23 +183,23 @@ class EncoderLayer(nn.Module):
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, res, level, attn_mask=None):
|
||||
season = self._season_block(res)
|
||||
season, season_attn = self._season_block(res)
|
||||
res = res - season[:, :-self.pred_len]
|
||||
growth = self._growth_block(res)
|
||||
growth, growth_attn = self._growth_block(res)
|
||||
res = self.norm1(res - growth[:, 1:])
|
||||
res = self.norm2(res + self.ff(res))
|
||||
|
||||
level = self.level_layer(level, growth[:, :-1], season[:, :-self.pred_len])
|
||||
|
||||
return res, level, growth, season
|
||||
return res, level, growth, season, season_attn, growth_attn
|
||||
|
||||
def _growth_block(self, x):
|
||||
x = self.growth_layer(x)
|
||||
return self.dropout1(x)
|
||||
x, growth_attn = self.growth_layer(x)
|
||||
return self.dropout1(x), growth_attn
|
||||
|
||||
def _season_block(self, x):
|
||||
x = self.seasonal_layer(x)
|
||||
return self.dropout2(x)
|
||||
x, season_attn = self.seasonal_layer(x)
|
||||
return self.dropout2(x), season_attn
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
@@ -165,9 +211,13 @@ class Encoder(nn.Module):
|
||||
def forward(self, res, level, attn_mask=None):
|
||||
growths = []
|
||||
seasons = []
|
||||
season_attns = []
|
||||
growth_attns = []
|
||||
for layer in self.layers:
|
||||
res, level, growth, season = layer(res, level, attn_mask=None)
|
||||
res, level, growth, season, season_attn, growth_attn = layer(res, level, attn_mask=None)
|
||||
growths.append(growth)
|
||||
seasons.append(season)
|
||||
season_attns.append(season_attn)
|
||||
growth_attns.append(growth_attn)
|
||||
|
||||
return level, growths, seasons
|
||||
return level, growths, seasons, season_attns, growth_attns
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import reduce
|
||||
|
||||
from .modules import ETSEmbedding
|
||||
from .encoder import EncoderLayer, Encoder
|
||||
@@ -47,6 +48,7 @@ class ETSformer(nn.Module):
|
||||
dim_feedforward=configs.d_ff,
|
||||
dropout=configs.dropout,
|
||||
activation=configs.activation,
|
||||
output_attention=configs.output_attention,
|
||||
) for _ in range(configs.e_layers)
|
||||
]
|
||||
)
|
||||
@@ -57,6 +59,7 @@ class ETSformer(nn.Module):
|
||||
DecoderLayer(
|
||||
configs.d_model, configs.n_heads, configs.c_out, configs.pred_len,
|
||||
dropout=configs.dropout,
|
||||
output_attention=configs.output_attention,
|
||||
) for _ in range(configs.d_layers)
|
||||
],
|
||||
)
|
||||
@@ -64,13 +67,29 @@ class ETSformer(nn.Module):
|
||||
self.transform = Transform(sigma=self.configs.std)
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
|
||||
enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
|
||||
enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None,
|
||||
decomposed=False, attention=False):
|
||||
with torch.no_grad():
|
||||
if self.training:
|
||||
x_enc = self.transform.transform(x_enc)
|
||||
res = self.enc_embedding(x_enc)
|
||||
level, growths, seasons = self.encoder(res, x_enc, attn_mask=enc_self_mask)
|
||||
level, growths, seasons, season_attns, growth_attns = self.encoder(res, x_enc, attn_mask=enc_self_mask)
|
||||
|
||||
growth, season, growth_dampings = self.decoder(growths, seasons)
|
||||
|
||||
if decomposed:
|
||||
return level[:, -1:], growth, season
|
||||
|
||||
growth, season = self.decoder(growths, seasons)
|
||||
preds = level[:, -1:] + growth + season
|
||||
|
||||
if attention:
|
||||
decoder_growth_attns = []
|
||||
for growth_attn, growth_damping in zip(growth_attns, growth_dampings):
|
||||
decoder_growth_attns.append(torch.einsum('bth,oh->bhot', [growth_attn.squeeze(-1), growth_damping]))
|
||||
|
||||
season_attns = torch.stack(season_attns, dim=0)[:, :, -self.pred_len:]
|
||||
season_attns = reduce(season_attns, 'l b d o t -> b o t', reduction='mean')
|
||||
decoder_growth_attns = torch.stack(decoder_growth_attns, dim=0)[:, :, -self.pred_len:]
|
||||
decoder_growth_attns = reduce(decoder_growth_attns, 'l b d o t -> b o t', reduction='mean')
|
||||
return preds, season_attns, decoder_growth_attns
|
||||
return preds
|
||||
|
||||
@@ -57,6 +57,7 @@ parser.add_argument('--std', type=float, default=0.2)
|
||||
|
||||
parser.add_argument('--smoothing_learning_rate', type=float, default=0, help='optimizer learning rate')
|
||||
parser.add_argument('--damping_learning_rate', type=float, default=0, help='optimizer learning rate')
|
||||
parser.add_argument('--output_attention', type=bool, default=False)
|
||||
|
||||
# optimization
|
||||
parser.add_argument('--optim', type=str, default='adam', help='optimizer')
|
||||
|
||||
Reference in New Issue
Block a user