From e7650a3e3ff78845e4a5187da877c7dcbe201ec7 Mon Sep 17 00:00:00 2001 From: gorold Date: Wed, 28 Sep 2022 12:50:29 +0800 Subject: [PATCH] return attention weights --- models/etsformer/decoder.py | 26 ++++++++---- models/etsformer/encoder.py | 82 +++++++++++++++++++++++++++++-------- models/etsformer/model.py | 25 +++++++++-- run.py | 1 + 4 files changed, 107 insertions(+), 27 deletions(-) diff --git a/models/etsformer/decoder.py b/models/etsformer/decoder.py index 8a7ab32..61496da 100644 --- a/models/etsformer/decoder.py +++ b/models/etsformer/decoder.py @@ -5,10 +5,11 @@ from einops import rearrange, reduce, repeat class DampingLayer(nn.Module): - def __init__(self, pred_len, nhead, dropout=0.1): + def __init__(self, pred_len, nhead, dropout=0.1, output_attention=False): super().__init__() self.pred_len = pred_len self.nhead = nhead + self.output_attention = output_attention self._damping_factor = nn.Parameter(torch.randn(1, nhead)) self.dropout = nn.Dropout(dropout) @@ -22,7 +23,10 @@ class DampingLayer(nn.Module): damping_factors = damping_factors.cumsum(dim=0) x = x.view(b, t, self.nhead, -1) x = self.dropout(x) * damping_factors.unsqueeze(-1) - return x.view(b, t, d) + x = x.view(b, t, d) + if self.output_attention: + return x, damping_factors + return x, None @property def damping_factor(self): @@ -31,22 +35,26 @@ class DampingLayer(nn.Module): class DecoderLayer(nn.Module): - def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1): + def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1, output_attention=False): super().__init__() self.d_model = d_model self.nhead = nhead self.c_out = c_out self.pred_len = pred_len + self.output_attention = output_attention - self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout) + self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout, output_attention=output_attention) self.dropout1 = nn.Dropout(dropout) def forward(self, growth, season): - growth_horizon = self.growth_damping(growth[:, -1:]) + growth_horizon, growth_damping = self.growth_damping(growth[:, -1:]) growth_horizon = self.dropout1(growth_horizon) seasonal_horizon = season[:, -self.pred_len:] - return growth_horizon, seasonal_horizon + + if self.output_attention: + return growth_horizon, seasonal_horizon, growth_damping + return growth_horizon, seasonal_horizon, None class Decoder(nn.Module): @@ -64,11 +72,13 @@ class Decoder(nn.Module): def forward(self, growths, seasons): growth_repr = [] season_repr = [] + growth_dampings = [] for idx, layer in enumerate(self.layers): - growth_horizon, season_horizon = layer(growths[idx], seasons[idx]) + growth_horizon, season_horizon, growth_damping = layer(growths[idx], seasons[idx]) growth_repr.append(growth_horizon) season_repr.append(season_horizon) + growth_dampings.append(growth_damping) growth_repr = sum(growth_repr) season_repr = sum(season_repr) - return self.pred(growth_repr), self.pred(season_repr) + return self.pred(growth_repr), self.pred(season_repr), growth_dampings diff --git a/models/etsformer/encoder.py b/models/etsformer/encoder.py index d3f7173..bcac64e 100644 --- a/models/etsformer/encoder.py +++ b/models/etsformer/encoder.py @@ -3,6 +3,7 @@ import torch.nn as nn import torch.nn.functional as F import torch.fft as fft +import numpy as np from einops import rearrange, reduce, repeat import math, random @@ -12,11 +13,12 @@ from .exponential_smoothing import ExponentialSmoothing class GrowthLayer(nn.Module): - def __init__(self, d_model, nhead, d_head=None, dropout=0.1): + def __init__(self, d_model, nhead, d_head=None, dropout=0.1, output_attention=False): super().__init__() self.d_head = d_head or (d_model // nhead) self.d_model = d_model self.nhead = nhead + self.output_attention = output_attention self.z0 = nn.Parameter(torch.randn(self.nhead, self.d_head)) self.in_proj = nn.Linear(self.d_model, self.d_head * self.nhead) @@ -37,20 +39,29 @@ class GrowthLayer(nn.Module): out = self.es(values) out = torch.cat([repeat(self.es.v0, '1 1 h d -> b 1 h d', b=b), out], dim=1) out = rearrange(out, 'b t h d -> b t (h d)') - return self.out_proj(out) + out = self.out_proj(out) + + if self.output_attention: + return out, self.es.get_exponential_weight(t)[1] + return out, None class FourierLayer(nn.Module): - def __init__(self, d_model, pred_len, k=None, low_freq=1): + def __init__(self, d_model, pred_len, k=None, low_freq=1, output_attention=False): super().__init__() self.d_model = d_model self.pred_len = pred_len self.k = k self.low_freq = low_freq + self.output_attention = output_attention def forward(self, x): """x: (b, t, d)""" + + if self.output_attention: + return self.dft_forward(x) + b, t, d = x.shape x_freq = fft.rfft(x, dim=1) @@ -65,7 +76,7 @@ class FourierLayer(nn.Module): f = repeat(f, 'f -> b f d', b=x_freq.size(0), d=x_freq.size(2)) f = rearrange(f[index_tuple], 'b f d -> b f () d').to(x_freq.device) - return self.extrapolate(x_freq, f, t) + return self.extrapolate(x_freq, f, t), None def extrapolate(self, x_freq, f, t): x_freq = torch.cat([x_freq, x_freq.conj()], dim=1) @@ -88,6 +99,41 @@ class FourierLayer(nn.Module): return x_freq, index_tuple + def dft_forward(self, x): + T = x.size(1) + + dft_mat = fft.fft(torch.eye(T)) + i, j = torch.meshgrid(torch.arange(self.pred_len + T), torch.arange(T)) + omega = np.exp(2 * math.pi * 1j / T) + idft_mat = (np.power(omega, i * j) / T).cfloat() + + x_freq = torch.einsum('ft,btd->bfd', [dft_mat, x.cfloat()]) + + if T % 2 == 0: + x_freq = x_freq[:, self.low_freq:T // 2] + else: + x_freq = x_freq[:, self.low_freq:T // 2 + 1] + + _, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True) + indices = indices + self.low_freq + indices = torch.cat([indices, -indices], dim=1) + + dft_mat = repeat(dft_mat, 'f t -> b f t d', b=x.shape[0], d=x.shape[-1]) + idft_mat = repeat(idft_mat, 't f -> b t f d', b=x.shape[0], d=x.shape[-1]) + + mesh_a, mesh_b = torch.meshgrid(torch.arange(x.size(0)), torch.arange(x.size(2))) + + dft_mask = torch.zeros_like(dft_mat) + dft_mask[mesh_a, indices, :, mesh_b] = 1 + dft_mat = dft_mat * dft_mask + + idft_mask = torch.zeros_like(idft_mat) + idft_mask[mesh_a, :, indices, mesh_b] = 1 + idft_mat = idft_mat * idft_mask + + attn = torch.einsum('bofd,bftd->botd', [idft_mat, dft_mat]).real + return torch.einsum('botd,btd->bod', [attn, x]), rearrange(attn, 'b o t d -> b d o t') + class LevelLayer(nn.Module): @@ -114,7 +160,7 @@ class LevelLayer(nn.Module): class EncoderLayer(nn.Module): def __init__(self, d_model, nhead, c_out, seq_len, pred_len, k, dim_feedforward=None, dropout=0.1, - activation='sigmoid', layer_norm_eps=1e-5): + activation='sigmoid', layer_norm_eps=1e-5, output_attention=False): super().__init__() self.d_model = d_model self.nhead = nhead @@ -124,8 +170,8 @@ class EncoderLayer(nn.Module): dim_feedforward = dim_feedforward or 4 * d_model self.dim_feedforward = dim_feedforward - self.growth_layer = GrowthLayer(d_model, nhead, dropout=dropout) - self.seasonal_layer = FourierLayer(d_model, pred_len, k=k) + self.growth_layer = GrowthLayer(d_model, nhead, dropout=dropout, output_attention=output_attention) + self.seasonal_layer = FourierLayer(d_model, pred_len, k=k, output_attention=output_attention) self.level_layer = LevelLayer(d_model, c_out, dropout=dropout) # Implementation of Feedforward model @@ -137,23 +183,23 @@ class EncoderLayer(nn.Module): self.dropout2 = nn.Dropout(dropout) def forward(self, res, level, attn_mask=None): - season = self._season_block(res) + season, season_attn = self._season_block(res) res = res - season[:, :-self.pred_len] - growth = self._growth_block(res) + growth, growth_attn = self._growth_block(res) res = self.norm1(res - growth[:, 1:]) res = self.norm2(res + self.ff(res)) level = self.level_layer(level, growth[:, :-1], season[:, :-self.pred_len]) - return res, level, growth, season + return res, level, growth, season, season_attn, growth_attn def _growth_block(self, x): - x = self.growth_layer(x) - return self.dropout1(x) + x, growth_attn = self.growth_layer(x) + return self.dropout1(x), growth_attn def _season_block(self, x): - x = self.seasonal_layer(x) - return self.dropout2(x) + x, season_attn = self.seasonal_layer(x) + return self.dropout2(x), season_attn class Encoder(nn.Module): @@ -165,9 +211,13 @@ class Encoder(nn.Module): def forward(self, res, level, attn_mask=None): growths = [] seasons = [] + season_attns = [] + growth_attns = [] for layer in self.layers: - res, level, growth, season = layer(res, level, attn_mask=None) + res, level, growth, season, season_attn, growth_attn = layer(res, level, attn_mask=None) growths.append(growth) seasons.append(season) + season_attns.append(season_attn) + growth_attns.append(growth_attn) - return level, growths, seasons + return level, growths, seasons, season_attns, growth_attns diff --git a/models/etsformer/model.py b/models/etsformer/model.py index 89d0b2b..3da1488 100644 --- a/models/etsformer/model.py +++ b/models/etsformer/model.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from einops import reduce from .modules import ETSEmbedding from .encoder import EncoderLayer, Encoder @@ -47,6 +48,7 @@ class ETSformer(nn.Module): dim_feedforward=configs.d_ff, dropout=configs.dropout, activation=configs.activation, + output_attention=configs.output_attention, ) for _ in range(configs.e_layers) ] ) @@ -57,6 +59,7 @@ class ETSformer(nn.Module): DecoderLayer( configs.d_model, configs.n_heads, configs.c_out, configs.pred_len, dropout=configs.dropout, + output_attention=configs.output_attention, ) for _ in range(configs.d_layers) ], ) @@ -64,13 +67,29 @@ class ETSformer(nn.Module): self.transform = Transform(sigma=self.configs.std) def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, - enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): + enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None, + decomposed=False, attention=False): with torch.no_grad(): if self.training: x_enc = self.transform.transform(x_enc) res = self.enc_embedding(x_enc) - level, growths, seasons = self.encoder(res, x_enc, attn_mask=enc_self_mask) + level, growths, seasons, season_attns, growth_attns = self.encoder(res, x_enc, attn_mask=enc_self_mask) + + growth, season, growth_dampings = self.decoder(growths, seasons) + + if decomposed: + return level[:, -1:], growth, season - growth, season = self.decoder(growths, seasons) preds = level[:, -1:] + growth + season + + if attention: + decoder_growth_attns = [] + for growth_attn, growth_damping in zip(growth_attns, growth_dampings): + decoder_growth_attns.append(torch.einsum('bth,oh->bhot', [growth_attn.squeeze(-1), growth_damping])) + + season_attns = torch.stack(season_attns, dim=0)[:, :, -self.pred_len:] + season_attns = reduce(season_attns, 'l b d o t -> b o t', reduction='mean') + decoder_growth_attns = torch.stack(decoder_growth_attns, dim=0)[:, :, -self.pred_len:] + decoder_growth_attns = reduce(decoder_growth_attns, 'l b d o t -> b o t', reduction='mean') + return preds, season_attns, decoder_growth_attns return preds diff --git a/run.py b/run.py index 7e8ea5c..a4da50f 100644 --- a/run.py +++ b/run.py @@ -57,6 +57,7 @@ parser.add_argument('--std', type=float, default=0.2) parser.add_argument('--smoothing_learning_rate', type=float, default=0, help='optimizer learning rate') parser.add_argument('--damping_learning_rate', type=float, default=0, help='optimizer learning rate') +parser.add_argument('--output_attention', type=bool, default=False) # optimization parser.add_argument('--optim', type=str, default='adam', help='optimizer')