Files
2022-11-28 14:10:49 +08:00

97 lines
3.3 KiB
Python

import torch
import torch.nn as nn
from einops import reduce
from .modules import ETSEmbedding
from .encoder import EncoderLayer, Encoder
from .decoder import DecoderLayer, Decoder
class Transform:
    """Random input augmentation: scale, then shift, then jitter.

    Every step draws standard-normal noise ``eps`` and multiplies it by
    ``sigma``:

    * ``scale``  -> ``x * (1 + sigma * eps)`` with one ``eps`` per entry of the
      last dimension (broadcast over all leading dimensions);
    * ``shift``  -> ``x + sigma * eps`` with one ``eps`` per entry of the last
      dimension (broadcast over all leading dimensions);
    * ``jitter`` -> ``x + sigma * eps`` with an independent ``eps`` per element.

    With ``sigma == 0`` every step returns values equal to its input.

    Args:
        sigma: Noise standard deviation.
    """

    def __init__(self, sigma):
        self.sigma = sigma

    @staticmethod
    def _noise(shape, like):
        """Draw standard-normal noise of ``shape`` directly on ``like``'s device.

        Floating-point ``like`` tensors get noise in their own dtype, so that
        e.g. half-precision inputs are not upcast to float32. Other dtypes fall
        back to torch's default float dtype, which is what ``torch.randn`` would use.
        """
        dtype = like.dtype if like.is_floating_point() else None
        # Allocating on the target device avoids a host allocation followed by
        # a host->device copy on every call.
        return torch.randn(shape, device=like.device, dtype=dtype)

    @torch.no_grad()
    def transform(self, x):
        """Return ``jitter(shift(scale(x)))``, computed without autograd tracking."""
        return self.jitter(self.shift(self.scale(x)))

    def jitter(self, x):
        """Add independent Gaussian noise (std ``sigma``) to every element of ``x``."""
        return x + self._noise(x.shape, x) * self.sigma

    def scale(self, x):
        """Multiply ``x`` by ``1 + sigma * eps``, one factor per last-dimension entry."""
        return x * (self._noise(x.size(-1), x) * self.sigma + 1)

    def shift(self, x):
        """Add ``sigma * eps`` to ``x``, one offset per last-dimension entry."""
        return x + self._noise(x.size(-1), x) * self.sigma
class ETSformer(nn.Module):
    """ETSformer forecaster: embedding -> encoder -> decoder -> level + growth + season.

    ``configs`` is an attribute bag. This class reads ``seq_len``, ``label_len``,
    ``pred_len``, ``e_layers``, ``d_layers``, ``enc_in``, ``d_model``,
    ``dropout``, ``n_heads``, ``c_out``, ``K``, ``d_ff``, ``activation``,
    ``output_attention`` and ``std`` from it. The encoder and decoder must have
    the same number of layers.
    """

    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.configs = configs

        assert configs.e_layers == configs.d_layers, "Encoder and decoder layers must be equal"

        # Submodules are built in a fixed order (embedding, encoder layers,
        # encoder, decoder layers, decoder). Parameter initialisation therefore
        # consumes the RNG deterministically, and state_dict keys keep their order.
        self.enc_embedding = ETSEmbedding(configs.enc_in, configs.d_model, dropout=configs.dropout)

        encoder_layers = [
            EncoderLayer(
                configs.d_model, configs.n_heads, configs.c_out, configs.seq_len, configs.pred_len, configs.K,
                dim_feedforward=configs.d_ff,
                dropout=configs.dropout,
                activation=configs.activation,
                output_attention=configs.output_attention,
            )
            for _ in range(configs.e_layers)
        ]
        self.encoder = Encoder(encoder_layers)

        decoder_layers = [
            DecoderLayer(
                configs.d_model, configs.n_heads, configs.c_out, configs.pred_len,
                dropout=configs.dropout,
                output_attention=configs.output_attention,
            )
            for _ in range(configs.d_layers)
        ]
        self.decoder = Decoder(decoder_layers)

        # Training-time input augmentation; the noise scale comes from configs.std.
        self.transform = Transform(sigma=self.configs.std)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None,
                decomposed=False, attention=False):
        """Forecast from the encoder input ``x_enc``.

        Only ``x_enc`` and ``enc_self_mask`` are used here. ``x_mark_enc``,
        ``x_dec``, ``x_mark_dec``, ``dec_self_mask`` and ``dec_enc_mask`` are
        accepted but ignored.

        Returns:
            * If ``decomposed``: ``(level[:, -1:], growth, season)``. Their sum
              is the forecast. This flag takes precedence over ``attention``.
            * If ``attention``: ``(preds, season_attns, decoder_growth_attns)``,
              where both attention outputs are averaged by ``_average_attention``.
            * Otherwise: ``preds = level[:, -1:] + growth + season``.
        """
        if self.training:
            # Augment only in training mode, outside the autograd graph.
            with torch.no_grad():
                x_enc = self.transform.transform(x_enc)

        embedded = self.enc_embedding(x_enc)
        level, growths, seasons, season_attns, growth_attns = self.encoder(
            embedded, x_enc, attn_mask=enc_self_mask
        )
        growth, season, growth_dampings = self.decoder(growths, seasons)

        last_level = level[:, -1:]
        if decomposed:
            return last_level, growth, season

        preds = last_level + growth + season
        if not attention:
            return preds

        # Combine each encoder growth attention with the matching decoder damping.
        decoder_growth_attns = [
            torch.einsum('bth,oh->bhot', [growth_attn.squeeze(-1), growth_damping])
            for growth_attn, growth_damping in zip(growth_attns, growth_dampings)
        ]
        return (
            preds,
            self._average_attention(season_attns),
            self._average_attention(decoder_growth_attns),
        )

    def _average_attention(self, attns):
        """Stack per-layer attention maps, keep the last ``pred_len`` along axis 2, average to ``(b, o, t)``.

        The maps are averaged over the layer axis and over the ``d`` axis.
        """
        # NOTE(review): in the einops pattern below, axis 2 of the stacked tensor
        # is ``d``, not ``o``. The slice is kept as-is; confirm it is intended.
        stacked = torch.stack(attns, dim=0)[:, :, -self.pred_len:]
        return reduce(stacked, 'l b d o t -> b o t', reduction='mean')