mirror of
https://github.com/wassname/ETSformer.git
synced 2026-06-27 16:43:49 +08:00
97 lines
3.3 KiB
Python
97 lines
3.3 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from einops import reduce
|
|
|
|
from .modules import ETSEmbedding
|
|
from .encoder import EncoderLayer, Encoder
|
|
from .decoder import DecoderLayer, Decoder
|
|
|
|
|
|
class Transform:
    """Random augmentation applied to input series during training.

    ``transform`` applies, in this order: a random per-channel scale, a random
    per-channel shift, and element-wise Gaussian jitter. Every noise draw is
    ``N(0, sigma**2)``. The per-channel draws have length ``x.size(-1)`` and
    broadcast across all leading dimensions of ``x``.
    """

    def __init__(self, sigma):
        # Standard deviation shared by the scale, shift and jitter noise.
        self.sigma = sigma

    def _noise(self, size, ref):
        """Return ``N(0, sigma**2)`` noise of ``size`` on ``ref``'s device.

        The noise is drawn directly on the target device rather than on the
        CPU followed by a copy, which would cost a host-to-device transfer
        per call. It uses ``ref``'s dtype when that is floating point, so
        low-precision inputs are not silently promoted. Non-floating inputs
        fall back to the default float dtype, as ``torch.randn`` always did.
        """
        if ref.is_floating_point():
            dtype = ref.dtype
        else:
            dtype = torch.get_default_dtype()
        return torch.randn(size, device=ref.device, dtype=dtype) * self.sigma

    @torch.no_grad()
    def transform(self, x):
        """Return ``x`` with scale, shift and jitter noise applied (no autograd)."""
        return self.jitter(self.shift(self.scale(x)))

    def jitter(self, x):
        """Add independent Gaussian noise to every element of ``x``."""
        return x + self._noise(x.shape, x)

    def scale(self, x):
        """Multiply each last-dim channel of ``x`` by a random factor ``1 + N(0, sigma**2)``."""
        return x * (self._noise(x.size(-1), x) + 1)

    def shift(self, x):
        """Add a random per-channel offset (over the last dim) to ``x``."""
        return x + self._noise(x.size(-1), x)
|
|
|
|
|
|
class ETSformer(nn.Module):
    """ETSformer forecaster: ETS embedding -> encoder -> decoder.

    The forecast is the encoder's last level step plus the decoder's growth
    and season components. All hyper-parameters come from ``configs``.
    """

    def __init__(self, configs):
        super().__init__()
        self.configs = configs
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len

        assert configs.e_layers == configs.d_layers, "Encoder and decoder layers must be equal"

        # Submodules are constructed in a fixed sequence: embedding, encoder
        # layers, encoder, decoder layers, decoder. Parameter initialisation
        # presumably draws from the global RNG, so this order should be kept.
        self.enc_embedding = ETSEmbedding(configs.enc_in, configs.d_model, dropout=configs.dropout)

        encoder_layers = [
            EncoderLayer(
                configs.d_model,
                configs.n_heads,
                configs.c_out,
                configs.seq_len,
                configs.pred_len,
                configs.K,
                dim_feedforward=configs.d_ff,
                dropout=configs.dropout,
                activation=configs.activation,
                output_attention=configs.output_attention,
            )
            for _ in range(configs.e_layers)
        ]
        self.encoder = Encoder(encoder_layers)

        decoder_layers = [
            DecoderLayer(
                configs.d_model,
                configs.n_heads,
                configs.c_out,
                configs.pred_len,
                dropout=configs.dropout,
                output_attention=configs.output_attention,
            )
            for _ in range(configs.d_layers)
        ]
        self.decoder = Decoder(decoder_layers)

        # Training-time input augmentation (scale / shift / jitter noise).
        self.transform = Transform(sigma=configs.std)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None,
                decomposed=False, attention=False):
        """Produce a forecast from the encoder input ``x_enc``.

        ``x_mark_enc``, ``x_dec``, ``x_mark_dec``, ``dec_self_mask`` and
        ``dec_enc_mask`` are accepted for interface compatibility but are not
        used here. Only ``enc_self_mask`` is forwarded, to the encoder.

        Returns:
            * ``(level[:, -1:], growth, season)`` when ``decomposed`` is true;
            * otherwise ``preds = level[:, -1:] + growth + season``, or
              ``(preds, season_attn, growth_attn)`` when ``attention`` is true,
              where each attention map is averaged over layers and the
              stacked tensor's dimension 2.
        """
        if self.training:
            # Augment the raw input only while training; no gradients are
            # tracked through the random perturbation.
            with torch.no_grad():
                x_enc = self.transform.transform(x_enc)

        embedded = self.enc_embedding(x_enc)
        level, growths, seasons, season_attns, growth_attns = self.encoder(
            embedded, x_enc, attn_mask=enc_self_mask)
        growth, season, growth_dampings = self.decoder(growths, seasons)

        # Last level step; presumably broadcasts over the prediction horizon
        # when added to growth/season — TODO confirm shapes.
        last_level = level[:, -1:]
        if decomposed:
            return last_level, growth, season

        preds = last_level + growth + season
        if not attention:
            return preds

        # Combine each layer's growth attention with its damping weights.
        per_layer_growth = [
            torch.einsum('bth,oh->bhot', [layer_attn.squeeze(-1), damping])
            for layer_attn, damping in zip(growth_attns, growth_dampings)
        ]

        # NOTE(review): after stacking on dim 0, ``[:, :, -self.pred_len:]``
        # slices dimension 2. For the growth maps that is the einsum's 'h'
        # axis, not 'o'. Confirm this is the intended axis.
        stacked_season = torch.stack(season_attns, dim=0)[:, :, -self.pred_len:]
        season_attn_mean = reduce(stacked_season, 'l b d o t -> b o t', reduction='mean')

        stacked_growth = torch.stack(per_layer_growth, dim=0)[:, :, -self.pred_len:]
        growth_attn_mean = reduce(stacked_growth, 'l b d o t -> b o t', reduction='mean')

        return preds, season_attn_mean, growth_attn_mean
|