mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
2066 lines
71 KiB
Python
2066 lines
71 KiB
Python
from math import sqrt
|
|
import math
|
|
from typing import List, Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from gluonts.core.component import validated
|
|
from gluonts.time_feature import get_lags_for_frequency
|
|
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
|
from gluonts.torch.modules.feature import FeatureEmbedder
|
|
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
|
import random
|
|
# from matplotlib import pyplot as plt
|
|
from torch.nn.functional import interpolate
|
|
|
|
from scipy.special import eval_legendre
|
|
from sympy import Poly, legendre, Symbol, chebyshevt
|
|
|
|
|
|
def legendreDer(k, x):
    """Evaluate the derivative of the Legendre polynomial P_k at ``x``.

    Uses the identity P'_k(x) = sum over i = k-1, k-3, ... (down to 0 or 1)
    of (2i + 1) * P_i(x).

    Args:
        k: Polynomial degree. For ``k <= 0`` the sum is empty.
        x: Scalar or array of evaluation points, forwarded to
            ``scipy.special.eval_legendre``.

    Returns:
        The derivative value(s); ``0`` when the sum is empty.
    """
    total = 0
    # Same degree order (k-1, k-3, ...) as the reference implementation, so the
    # floating-point accumulation order is unchanged.
    for degree in np.arange(k - 1, -1, -2):
        total += (2 * degree + 1) * eval_legendre(degree, x)
    return total
|
|
|
|
|
|
def phi_(phi_c, x, lb = 0, ub = 1):
    """Evaluate a polynomial restricted to the closed interval ``[lb, ub]``.

    Args:
        phi_c: Polynomial coefficients in ascending order of degree (the
            ``np.polynomial.polynomial.Polynomial`` convention).
        x: Scalar or array of evaluation points.
        lb: Lower end of the support (inclusive).
        ub: Upper end of the support (inclusive).

    Returns:
        The polynomial evaluated at ``x``, multiplied by a 0/1 mask that is 0
        wherever ``x < lb`` or ``x > ub``.
    """
    poly = np.polynomial.polynomial.Polynomial(phi_c)
    outside = np.logical_or(x < lb, x > ub) * 1.0
    inside = 1 - outside
    return poly(x) * inside
|
|
|
|
|
|
def _half_interval_integral(a, b, upper=False):
    """Integrate the product of two polynomials over one half of ``[0, 1]``.

    Args:
        a: Coefficients of the first polynomial, ascending order of degree.
        b: Coefficients of the second polynomial, ascending order of degree.
        upper: ``False`` integrates over ``[0, 1/2]``, ``True`` over ``[1/2, 1]``.

    Returns:
        The integral as a NumPy scalar. Product coefficients whose magnitude
        is below 1e-8 are zeroed before integrating (numerical clean-up).
    """
    prod_ = np.convolve(a, b)  # ascending coefficients of a(x) * b(x)
    prod_[np.abs(prod_) < 1e-8] = 0
    degrees = np.arange(len(prod_))
    # int_0^{1/2} x^n dx = 0.5^(n+1) / (n+1);  int_{1/2}^1 x^n dx = (1 - 0.5^(n+1)) / (n+1)
    half_pow = np.power(0.5, 1 + degrees)
    weight = (1 - half_pow) if upper else half_pow
    return (prod_ * 1/(degrees + 1) * weight).sum()


def get_phi_psi(k, base):
    """Build the polynomial scaling functions and wavelets of a multiwavelet basis.

    Args:
        k: Number of functions per family (polynomial degrees ``0 .. k-1``).
        base: ``'legendre'`` or ``'chebyshev'``.

    Returns:
        ``(phi, psi1, psi2)``: three lists of ``k`` callables on ``[0, 1]``.
        ``phi[i]`` is built from the degree-``i`` shifted polynomial.
        ``psi1[i]`` / ``psi2[i]`` are the pieces of the ``i``-th wavelet on
        the left / right half of ``[0, 1]``. For ``'legendre'`` every entry is
        an unmasked ``np.poly1d``, so the caller must pick the right piece for
        each half. For ``'chebyshev'`` every entry is a ``partial`` of
        ``phi_`` that is already zero outside its support.

    Raises:
        ValueError: if ``base`` is not ``'legendre'`` or ``'chebyshev'``
            (previously this surfaced as an ``UnboundLocalError``).
    """
    # `partial` is needed below but is not imported at module level in this file.
    from functools import partial

    if base not in ('legendre', 'chebyshev'):
        raise ValueError('Base not supported: {!r}'.format(base))

    x = Symbol('x')
    phi_coeff = np.zeros((k, k))     # row i: ascending coefficients of phi_i(x)
    phi_2x_coeff = np.zeros((k, k))  # row i: ascending coefficients of sqrt(2) * phi_i(2x)

    if base == 'legendre':
        # phi_i(x) = sqrt(2i+1) * P_i(2x - 1). sympy's all_coeffs() is in
        # descending order, hence the np.flip.
        for ki in range(k):
            coeff_ = Poly(legendre(ki, 2*x-1), x).all_coeffs()
            phi_coeff[ki, :ki+1] = np.flip(np.sqrt(2*ki+1) * np.array(coeff_).astype(np.float64))
            coeff_ = Poly(legendre(ki, 4*x-1), x).all_coeffs()
            phi_2x_coeff[ki, :ki+1] = np.flip(np.sqrt(2) * np.sqrt(2*ki+1) * np.array(coeff_).astype(np.float64))

        # Gram-Schmidt. psi1 holds the polynomial used on [0, 1/2] and starts
        # as sqrt(2) * phi_ki(2x). psi2 holds the one used on [1/2, 1] and
        # starts at 0. The sqrt(2) * phi_ki(2x) term lives on the left half
        # only, so its inner products need only the [0, 1/2] integral.
        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))
        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            # Remove the components along every scaling function phi_i ...
            for i in range(k):
                proj_ = _half_interval_integral(phi_2x_coeff[ki, :ki+1], phi_coeff[i, :i+1])
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
            # ... and along the wavelets built so far.
            for j in range(ki):
                proj_ = _half_interval_integral(phi_2x_coeff[ki, :ki+1], psi1_coeff[j, :])
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            # Normalise over the whole of [0, 1] (left piece + right piece).
            norm1 = _half_interval_integral(psi1_coeff[ki, :], psi1_coeff[ki, :])
            norm2 = _half_interval_integral(psi2_coeff[ki, :], psi2_coeff[ki, :], upper=True)
            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

        # np.poly1d expects descending coefficients.
        phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)]
        psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)]
        psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)]

    else:  # base == 'chebyshev'
        for ki in range(k):
            if ki == 0:
                phi_coeff[ki, :ki+1] = np.sqrt(2/np.pi)
                phi_2x_coeff[ki, :ki+1] = np.sqrt(2/np.pi) * np.sqrt(2)
            else:
                coeff_ = Poly(chebyshevt(ki, 2*x-1), x).all_coeffs()
                phi_coeff[ki, :ki+1] = np.flip(2/np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
                coeff_ = Poly(chebyshevt(ki, 4*x-1), x).all_coeffs()
                phi_2x_coeff[ki, :ki+1] = np.flip(np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))

        # phi_ masks to [0, 1] by default, so phi[ki](2*x_m) below vanishes for x_m > 1/2.
        phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)]

        # Quadrature on the roots of T_{2k}(2x - 1) with constant weight wm.
        # kUse is even and T_n(0) = 0 only for odd n, so x = 0.5 (the boundary
        # between the two halves) is never a node.
        kUse = 2*k
        roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = np.pi / kUse / 2

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))
        psi1 = [[] for _ in range(k)]
        psi2 = [[] for _ in range(k)]

        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            for i in range(k):
                proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]

            for j in range(ki):
                proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            # Provisional (unnormalised) pieces, used only to compute the norm.
            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1)

            norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
            norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()

            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

            # Final pieces. The tiny offset makes 0.5 belong to the left piece only.
            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5+1e-16)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5+1e-16, ub=1)

    return phi, psi1, psi2
|
|
|
|
|
|
def get_filter(base, k):
    """Compute the two-scale filter matrices of a multiwavelet basis.

    Args:
        base: ``'legendre'`` or ``'chebyshev'``.
        k: Number of scaling functions / wavelets (passed to ``get_phi_psi``).

    Returns:
        ``(H0, H1, G0, G1, PHI0, PHI1)``, each a ``(k, k)`` float array.
        With quadrature nodes ``x_m`` and weights ``wm`` for the chosen base:

        - ``H0[i, j] = 1/sqrt(2) * sum_m wm * phi_i(x_m / 2) * phi_j(x_m)``.
        - ``H1`` is the same with ``phi_i((x_m + 1) / 2)``.
        - ``G0`` / ``G1`` are the same with the stitched wavelet ``psi_i`` in
          place of ``phi_i``.
        - ``PHI0`` / ``PHI1`` are identity matrices for ``'legendre'``. For
          ``'chebyshev'`` they are ``2 * sum_m wm * phi_i(2 x_m) * phi_j(2 x_m)``
          and the same with ``2 x_m - 1``.

        Entries with magnitude below 1e-8 are set to 0.

    Raises:
        ValueError: if ``base`` is unsupported. ``ValueError`` subclasses
            ``Exception``, so handlers written for the previous bare
            ``Exception`` still catch it.
    """
    def psi(psi1, psi2, i, inp):
        # Stitch the i-th wavelet: left piece where inp <= 0.5, right piece elsewhere.
        mask = (inp <= 0.5) * 1.0
        return psi1[i](inp) * mask + psi2[i](inp) * (1 - mask)

    if base not in ['legendre', 'chebyshev']:
        raise ValueError('Base not supported')

    x = Symbol('x')
    H0 = np.zeros((k, k))
    H1 = np.zeros((k, k))
    G0 = np.zeros((k, k))
    G1 = np.zeros((k, k))
    phi, psi1, psi2 = get_phi_psi(k, base)

    # Quadrature nodes x_m on [0, 1] and their weights wm.
    if base == 'legendre':
        # Gauss-Legendre: roots of P_k(2x - 1) with the matching weights.
        roots = Poly(legendre(k, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = 1/k/legendreDer(k, 2*x_m-1)/eval_legendre(k-1, 2*x_m-1)
    else:
        # Roots of T_{2k}(2x - 1) with a constant weight.
        kUse = 2*k
        roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = np.pi / kUse / 2

    for ki in range(k):
        for kpi in range(k):
            H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum()
            G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum()
            H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum()
            G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum()

    if base == 'legendre':
        PHI0 = np.eye(k)
        PHI1 = np.eye(k)
    else:
        PHI0 = np.zeros((k, k))
        PHI1 = np.zeros((k, k))
        for ki in range(k):
            for kpi in range(k):
                PHI0[ki, kpi] = (wm * phi[ki](2*x_m) * phi[kpi](2*x_m)).sum() * 2
                PHI1[ki, kpi] = (wm * phi[ki](2*x_m-1) * phi[kpi](2*x_m-1)).sum() * 2
        PHI0[np.abs(PHI0) < 1e-8] = 0
        PHI1[np.abs(PHI1) < 1e-8] = 0

    for mat in (H0, H1, G0, G1):
        mat[np.abs(mat) < 1e-8] = 0

    return H0, H1, G0, G1, PHI0, PHI1
|
|
|
|
|
|
|
|
|
|
class TriangularCausalMask:
    """Boolean causal mask of shape ``[B, 1, L, L]``.

    An entry is True where the column index is greater than the row index
    (strictly above the diagonal) and False elsewhere.
    """

    def __init__(self, B, L, device="cpu"):
        shape = [B, 1, L, L]
        with torch.no_grad():
            all_true = torch.ones(shape, dtype=torch.bool)
            self._mask = all_true.triu(diagonal=1).to(device)

    @property
    def mask(self):
        """The ``[B, 1, L, L]`` boolean mask tensor."""
        return self._mask
|
|
|
|
|
|
class ProbMask:
    """Causal mask rows for a selected subset of query positions.

    Builds an ``[L, S]`` boolean mask that is True strictly above the diagonal
    (``S = scores.shape[-1]``). For every batch ``b`` and head ``h`` it picks
    rows ``index[b, h, :]`` and reshapes the result to ``scores.shape``.
    """

    def __init__(self, B, H, L, index, scores, device="cpu"):
        n_keys = scores.shape[-1]
        causal = torch.ones(L, n_keys, dtype=torch.bool).to(device).triu(1)
        expanded = causal[None, None, :].expand(B, H, L, n_keys)
        batch_idx = torch.arange(B)[:, None, None]
        head_idx = torch.arange(H)[None, :, None]
        picked = expanded[batch_idx, head_idx, index, :].to(device)
        self._mask = picked.view(scores.shape).to(device)

    @property
    def mask(self):
        """The boolean mask, shaped like ``scores``."""
        return self._mask
|
|
|
|
|
|
class LocalMask:
    """Boolean mask of shape ``[B, 1, L, S]`` for local causal attention.

    An entry ``(i, j)`` is True (masked) when it lies in the future
    (``j > i``) or more than ``self.len = ceil(log2(L))`` steps in the past
    (``j < i - self.len``).
    """

    def __init__(self, B, L, S, device="cpu"):
        shape = [B, 1, L, S]
        with torch.no_grad():
            self.len = math.ceil(np.log2(L))
            future = torch.ones(shape, dtype=torch.bool).triu(diagonal=1).to(device)
            window = torch.ones(shape, dtype=torch.bool).triu(diagonal=-self.len).to(device)
            self._mask1 = future
            self._mask2 = ~window
            self._mask = self._mask1 | self._mask2

    @property
    def mask(self):
        """The combined ``[B, 1, L, S]`` boolean mask."""
        return self._mask
|
|
|
|
|
|
def adjust_learning_rate(optimizer, epoch, args):
    """Set every param group's learning rate for ``epoch`` per ``args.lradj``.

    Schedules:
        ``"type1"``: ``args.learning_rate * 0.5 ** ((epoch - 1) // 1)``.
        ``"type2"``: fixed table {2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
            10: 5e-7, 15: 1e-7, 20: 5e-8}; other epochs leave the rate untouched.
        ``"type3"``: ``args.learning_rate`` at every epoch.
        ``"type4"``: ``args.learning_rate * 0.9 ** ((epoch - 1) // 1)``.

    Args:
        optimizer: Object exposing ``param_groups`` (list of dicts with ``"lr"``).
        epoch: Current epoch number.
        args: Object with ``lradj`` and ``learning_rate`` attributes.

    Raises:
        ValueError: if ``args.lradj`` is not one of the schedules above
            (previously this surfaced as an ``UnboundLocalError``).
    """
    # lr = args.learning_rate * (0.2 ** (epoch // 2))
    if args.lradj == "type1":
        lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
    elif args.lradj == "type2":
        lr_adjust = {2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 10: 5e-7, 15: 1e-7, 20: 5e-8}
    elif args.lradj == "type3":
        lr_adjust = {epoch: args.learning_rate}
    elif args.lradj == "type4":
        lr_adjust = {epoch: args.learning_rate * (0.9 ** ((epoch - 1) // 1))}
    else:
        raise ValueError("Unknown lradj schedule: {!r}".format(args.lradj))

    if epoch in lr_adjust:
        lr = lr_adjust[epoch]
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        print("Updating learning rate to {}".format(lr))
|
|
|
|
|
|
class EarlyStopping:
    """Signal early stopping once the validation loss stops improving.

    Tracks the best score (``-val_loss``). Each time it improves, the model's
    ``state_dict`` is saved to ``<path>/checkpoint.pth``. After ``patience``
    consecutive calls without improvement, ``early_stop`` becomes True.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If True, print a message whenever a checkpoint is saved.
        delta: Minimum increase of the score over the best one that counts
            as an improvement.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, path):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Save ``model.state_dict()`` to ``<path>/checkpoint.pth`` and record ``val_loss``."""
        import os  # not imported at module level in this file

        if self.verbose:
            print(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
            )
        torch.save(model.state_dict(), os.path.join(path, "checkpoint.pth"))
        self.val_loss_min = val_loss
|
|
|
|
|
|
class dotdict(dict):
    """A ``dict`` whose keys can also be read, set and deleted as attributes.

    Reading a missing attribute returns ``None`` (like ``dict.get``) instead
    of raising ``AttributeError``. Deleting a missing one raises ``KeyError``.
    """

    def __getattr__(self, key, default=None):
        return dict.get(self, key, default)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)
|
|
|
|
|
|
class StandardScaler:
    """Affine normaliser using a precomputed mean and standard deviation."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def transform(self, data):
        """Return ``(data - mean) / std``."""
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """Undo ``transform``: return ``data * std + mean``."""
        rescaled = data * self.std
        return rescaled + self.mean
|
|
|
|
|
|
def visual(true, preds=None, name="./pic/test.pdf"):
    """
    Results visualization: plot the ground truth (and optionally the
    predictions) on one figure and save it to ``name``.

    Args:
        true: Sequence of ground-truth values.
        preds: Optional sequence of predictions drawn on the same axes.
        name: Output path passed to ``plt.savefig``.

    NOTE(review): this relies on a module-level ``plt``. The matplotlib import
    near the top of this file is commented out, so it must be restored before
    this function can run.
    """
    fig = plt.figure()
    plt.plot(true, label="GroundTruth", linewidth=2)
    if preds is not None:
        plt.plot(preds, label="Prediction", linewidth=2)
    plt.legend()
    plt.savefig(name, bbox_inches="tight")
    # Release the figure; otherwise every call keeps one more figure alive in
    # pyplot's figure manager.
    plt.close(fig)
|
|
|
|
|
|
def decor_time(func):
    """Decorator that prints how long each call to ``func`` takes.

    The wrapped function's return value is passed through unchanged, and its
    metadata (``__name__``, ``__doc__``, ...) is preserved with
    ``functools.wraps``. Timing uses ``time.perf_counter``.
    """
    import functools
    import time  # `time` is not imported at module level in this file

    @functools.wraps(func)
    def func2(*args, **kw):
        now = time.perf_counter()
        y = func(*args, **kw)
        t = time.perf_counter() - now
        print("call <{}>, time={}".format(func.__name__, t))
        return y

    return func2
|
|
|
|
|
|
class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.

    ``forward(queries, keys, values, attn_mask)`` takes ``[B, L, H, E]``
    queries and ``[B, S, H, D]`` keys/values and returns
    ``(V, corr_or_None)`` with ``V`` shaped ``[B, L, H, D]``.
    ``attn_mask`` is accepted for interface compatibility but not used.

    Args:
        mask_flag: Stored as ``self.mask_flag``; not read by this class.
        factor: The number of aggregated delays is ``int(factor * log(length))``.
        scale: Stored as ``self.scale``; not read by this class.
        attention_dropout: Probability for ``self.dropout``. The module is
            created but not applied in ``forward``.
        output_attention: If True, ``forward`` also returns the
            autocorrelation permuted to ``[B, L, H, E]``.
        wavelet: Selects the mode:
            - ``0``/``False``: plain FFT autocorrelation.
            - ``1`` (``True`` compares equal to 1): queries and keys are
              augmented with wavelet sub-bands before the FFT.
            - ``2``: aggregation runs separately on each wavelet sub-band.

            NOTE(review): modes 1 and 2 use ``self.dwt1d``, ``self.dwt1div``
            and ``self.j_list``. This class never sets them, so they must be
            attached externally or ``forward`` raises ``AttributeError``.
    """

    def __init__(
        self,
        mask_flag=True,
        factor=1,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
        wavelet=False,
    ):
        super(AutoCorrelation, self).__init__()
        print("Autocorrelation used !")
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        self.agg = None
        self.use_wavelet = wavelet

    @staticmethod
    def _base_index(values):
        """Return int64 positions ``0..L-1`` broadcast to ``values.shape`` on ``values.device``.

        Replaces a hard-coded ``.cuda()``, which failed on CPU-only machines
        and ignored the device ``values`` actually lives on.
        """
        batch, head, channel, length = values.shape
        return (
            torch.arange(length, device=values.device)
            .view(1, 1, 1, length)
            .repeat(batch, head, channel, 1)
        )

    # @decor_time
    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        The ``top_k`` delays are chosen from the correlation averaged over
        heads, channels and the batch, so every sample shares the same
        delays. The softmax weights are still per sample.

        Args:
            values: ``[B, H, C, L]`` tensor.
            corr: ``[B, H, C', L]`` autocorrelation (``C'`` may differ from ``C``).

        Returns:
            ``[B, H, C, L]`` float tensor: weighted sum of circularly shifted values.
        """
        length = values.shape[3]
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)  # [B, L]
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = mean_value[:, index]  # [B, top_k]
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * tmp_corr[:, i][:, None, None, None]
        return delays_agg  # size=[B, H, d, S]

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Each sample picks its own ``top_k`` delays from the correlation
        averaged over heads and channels.

        Args:
            values: ``[B, H, C, L]`` tensor.
            corr: ``[B, H, C', L]`` autocorrelation (``C'`` may differ from ``C``).

        Returns:
            ``[B, H, C, L]`` float tensor: weighted sum of circularly shifted values.
        """
        init_index = self._base_index(values)
        # find top k
        top_k = int(self.factor * math.log(values.shape[3]))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)  # [B, L]
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # Doubling the series lets gather(t + delay) emulate a circular shift.
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i][:, None, None, None]
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[:, i][:, None, None, None]
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation: every (batch, head, channel)
        series picks its own ``top_k`` delays.

        Args:
            values: ``[B, H, C, L]`` tensor.
            corr: ``[B, H, C, L]`` autocorrelation.

        Returns:
            ``[B, H, C, L]`` float tensor: weighted sum of circularly shifted values.
        """
        init_index = self._base_index(values)
        # find top k
        top_k = int(self.factor * math.log(values.shape[3]))
        weights, delay = torch.topk(corr, top_k, dim=-1)  # [B, H, C, top_k]
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def _aggregate(self, values, corr):
        """Time-delay aggregation of ``[B, L, H, D]`` values; returns ``[B, L, H, D]``.

        Uses the training variant in training mode and the inference variant otherwise.
        """
        agg = self.time_delay_agg_training if self.training else self.time_delay_agg_inference
        return agg(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        # Bring keys/values to the query length: zero-pad or truncate.
        if L > S:
            zeros = torch.zeros_like(queries[:, : (L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        if self.use_wavelet != 2:
            if self.use_wavelet == 1:
                j_list = self.j_list
                queries = queries.reshape([B, L, -1])
                keys = keys.reshape([B, L, -1])
                Ql, Qh_list = self.dwt1d(queries.transpose(1, 2))  # [B, H*D, L]
                Kl, Kh_list = self.dwt1d(keys.transpose(1, 2))
                qs = [queries.transpose(1, 2)] + Qh_list + [Ql]  # [B, H*D, L]
                ks = [keys.transpose(1, 2)] + Kh_list + [Kl]
                q_list = []
                k_list = []
                for q, k, j in zip(qs, ks, j_list):
                    q_list += [interpolate(q, scale_factor=j, mode="linear")[:, :, -L:]]
                    k_list += [interpolate(k, scale_factor=j, mode="linear")[:, :, -L:]]
                queries = (
                    torch.stack([i.reshape([B, H, E, L]) for i in q_list], dim=3)
                    .reshape([B, H, -1, L])
                    .permute(0, 3, 1, 2)
                )
                keys = (
                    torch.stack([i.reshape([B, H, E, L]) for i in k_list], dim=3)
                    .reshape([B, H, -1, L])
                    .permute(0, 3, 1, 2)
                )
            q_fft = torch.fft.rfft(
                queries.permute(0, 2, 3, 1).contiguous(), dim=-1
            )  # size=[B, H, E, L//2+1]
            k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
            res = q_fft * torch.conj(k_fft)
            # n=L is required: without it irfft returns 2*(L//2) samples, which
            # drops a lag and aliases the correlation when L is odd.
            corr = torch.fft.irfft(res, n=L, dim=-1)  # size=[B, H, E, L]

            # time delay agg
            V = self._aggregate(values, corr)  # [B, L, H, D]
        else:
            V_list = []
            queries = queries.reshape([B, L, -1])
            keys = keys.reshape([B, L, -1])
            values = values.reshape([B, L, -1])
            Ql, Qh_list = self.dwt1d(queries.transpose(1, 2))  # [B, H*D, L]
            Kl, Kh_list = self.dwt1d(keys.transpose(1, 2))
            Vl, Vh_list = self.dwt1d(values.transpose(1, 2))
            qs = Qh_list + [Ql]  # [B, H*D, L]
            ks = Kh_list + [Kl]
            vs = Vh_list + [Vl]
            for q, k, v in zip(qs, ks, vs):
                q = q.reshape([B, H, E, -1])
                k = k.reshape([B, H, E, -1])
                v = v.reshape([B, H, E, -1]).permute(0, 3, 1, 2)
                q_fft = torch.fft.rfft(q.contiguous(), dim=-1)
                k_fft = torch.fft.rfft(k.contiguous(), dim=-1)
                res = q_fft * torch.conj(k_fft)
                corr = torch.fft.irfft(res, n=q.shape[-1], dim=-1)  # [B, H, E, L_band]
                V_list += [self._aggregate(v, corr)]
            Vl = V_list[-1].reshape([B, -1, H * E]).transpose(1, 2)
            Vh_list = [i.reshape([B, -1, H * E]).transpose(1, 2) for i in V_list[:-1]]
            V = self.dwt1div((Vl, Vh_list)).reshape([B, H, E, -1]).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))  # size = [B, L, H, E]
        else:
            return (V.contiguous(), None)
|
|
|
|
|
|
class AutoCorrelationLayer(nn.Module):
    """Multi-head wrapper: project the inputs, apply ``correlation``, project back.

    Args:
        correlation: Callable ``(queries, keys, values, attn_mask) -> (out, attn)``
            operating on ``[B, L, H, d]`` tensors.
        d_model: Width of the inputs and of the output.
        n_heads: Number of heads ``H``.
        d_keys: Per-head query/key width (default ``d_model // n_heads``).
        d_values: Per-head value width (default ``d_model // n_heads``).
    """

    def __init__(self, correlation, d_model, n_heads, d_keys=None, d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        """Return ``(out, attn)`` with ``out`` shaped ``[B, L, d_model]``.

        ``queries`` is ``[B, L, d_model]``; ``keys`` and ``values`` are
        ``[B, S, d_model]``.
        """
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_correlation(queries, keys, values, attn_mask)

        # reshape (not view) so a non-contiguous result from the inner module also works.
        out = out.reshape(B, L, -1)
        return self.out_projection(out), attn
|
|
|
|
|
|
class my_Layernorm(nn.Module):
    """
    Special designed layernorm for the seasonal part.

    Applies ``nn.LayerNorm(channels)`` over the last dimension, then
    subtracts the mean over dimension 1 (presumably the time axis of a
    ``[B, L, D]`` input — confirm with callers). Broadcasting replaces the
    former ``repeat``, which allocated a copy and only worked for 3-D inputs.
    """

    def __init__(self, channels):
        super(my_Layernorm, self).__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        x_hat = self.layernorm(x)
        bias = torch.mean(x_hat, dim=1, keepdim=True)
        return x_hat - bias
|
|
|
|
|
|
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    Expects a 3-D input ``[B, L, C]`` and averages along dimension 1. Both
    ends are padded by repeating the first / last step, so with ``stride=1``
    the output length equals the input length.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # Split the kernel_size - 1 padding steps; the front gets the extra one when odd.
        pad_back = (self.kernel_size - 1) // 2
        pad_front = self.kernel_size - 1 - pad_back
        head = x[:, :1, :].repeat(1, pad_front, 1)
        tail = x[:, -1:, :].repeat(1, pad_back, 1)
        padded = torch.cat([head, x, tail], dim=1)
        # AvgPool1d pools over the last dimension, so move time there and back.
        pooled = self.avg(padded.transpose(1, 2))
        return pooled.transpose(1, 2)
|
|
|
|
|
|
class series_decomp(nn.Module):
    """
    Series decomposition block.

    Splits ``x`` into a residual and a trend, where the trend is a stride-1
    ``moving_avg`` of width ``kernel_size``. ``forward`` returns
    ``(x - trend, trend)``.
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        return x - trend, trend
|
|
|
|
|
|
class series_decomp_multi(nn.Module):
    """
    Series decomposition block with several moving-average kernels.

    Each kernel in ``kernel_size`` gives a trend estimate. The estimates are
    mixed with per-element softmax weights produced by a learned
    ``Linear(1, len(kernel_size))`` applied to each input value.
    ``forward`` returns ``(x - trend, trend)``.
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        # nn.ModuleList (not a plain list) registers the averaging blocks as
        # submodules, so they show up in .modules(), .train()/.eval() and repr.
        # They hold no parameters or buffers, so state_dict keys are unchanged.
        self.moving_avg = nn.ModuleList(
            [moving_avg(kernel, stride=1) for kernel in kernel_size]
        )
        self.layer = torch.nn.Linear(1, len(kernel_size))

    def forward(self, x):
        # Stack the trend estimates along a new trailing axis of size len(kernel_size).
        trends = torch.cat([avg(x).unsqueeze(-1) for avg in self.moving_avg], dim=-1)
        mix = torch.softmax(self.layer(x.unsqueeze(-1)), dim=-1)
        moving_mean = torch.sum(trends * mix, dim=-1)
        res = x - moving_mean
        return res, moving_mean
|
|
|
|
|
|
class FourierDecomp(nn.Module):
    """Unfinished Fourier decomposition stub.

    ``forward`` computes ``torch.fft.rfft`` over the last dimension, discards
    the result and returns ``None``.
    """

    def __init__(self):
        super(FourierDecomp, self).__init__()

    def forward(self, x):
        torch.fft.rfft(x, dim=-1)  # result currently unused
        return None
|
|
|
|
|
|
class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture.

    ``forward`` runs the following steps in order:
    1. ``attention`` with a residual connection.
    2. Trend removal (``decomp1``).
    3. A feed-forward block of two 1x1 convolutions with a residual connection.
    4. Trend removal again (``decomp2``).

    It returns the remaining (seasonal) part together with the attention output.

    Args:
        attention: Callable ``(q, k, v, attn_mask=...) -> (out, attn)``.
        d_model: Model width.
        d_ff: Hidden width of the feed-forward block (default ``4 * d_model``).
        moving_avg: Kernel size, or a list of kernel sizes for a multi-kernel decomposition.
        dropout: Dropout probability.
        activation: ``"relu"``; any other value selects GELU.
    """

    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(
            in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False
        )
        self.conv2 = nn.Conv1d(
            in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False
        )

        decomp_cls = series_decomp_multi if isinstance(moving_avg, list) else series_decomp
        self.decomp1 = decomp_cls(moving_avg)
        self.decomp2 = decomp_cls(moving_avg)

        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        new_x, attn = self.attention(x, x, x, attn_mask=attn_mask)
        x, _ = self.decomp1(x + self.dropout(new_x))
        # Conv1d works on [B, C, L], hence the transposes around the 1x1 convolutions.
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(hidden).transpose(-1, 1))
        res, _ = self.decomp2(x + y)
        return res, attn
|
|
|
|
|
|
class Encoder(nn.Module):
    """
    Autoformer encoder: a stack of attention layers, optionally interleaved
    with convolution layers, followed by an optional normalisation layer.

    When ``conv_layers`` is given:
    - Attention and conv layers are paired with ``zip``, so pairing stops at
      the shorter list.
    - ``attn_layers[-1]`` is then applied to the result without
      ``attn_mask``. Presumably this is because the conv layers may change
      the sequence length — confirm.

    ``forward`` returns ``(x, attns)``, where ``attns`` collects each
    attention call's second output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is None:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)
        else:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)
        return x, attns
|
|
|
|
|
|
class DecoderLayer(nn.Module):
    """
    Autoformer decoder layer with the progressive decomposition architecture.

    ``forward`` applies three stages, each followed by a trend-removing
    decomposition:
    1. Self-attention with a residual connection.
    2. Cross-attention against ``cross`` with a residual connection.
    3. A feed-forward block of two 1x1 convolutions with a residual connection.

    It returns ``(x, residual_trend)``. ``residual_trend`` is the sum of the
    three removed trends, projected to ``c_out`` channels by a circular
    kernel-3 convolution.

    Args:
        self_attention: Callable ``(q, k, v, attn_mask=...) -> (out, attn)``.
        cross_attention: Callable with the same signature, used with ``cross``.
        d_model: Model width.
        c_out: Number of output channels of the trend projection.
        d_ff: Hidden width of the feed-forward block (default ``4 * d_model``).
        moving_avg: Kernel size, or a list of kernel sizes for a multi-kernel decomposition.
        dropout: Dropout probability.
        activation: ``"relu"``; any other value selects GELU.
    """

    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        c_out,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(
            in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False
        )
        self.conv2 = nn.Conv1d(
            in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False
        )

        decomp_cls = series_decomp_multi if isinstance(moving_avg, list) else series_decomp
        self.decomp1 = decomp_cls(moving_avg)
        self.decomp2 = decomp_cls(moving_avg)
        self.decomp3 = decomp_cls(moving_avg)

        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(
            in_channels=d_model,
            out_channels=c_out,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        attended = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x, trend1 = self.decomp1(x + self.dropout(attended))

        crossed = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x, trend2 = self.decomp2(x + self.dropout(crossed))

        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(hidden).transpose(-1, 1))
        x, trend3 = self.decomp3(x + y)

        summed_trend = trend1 + trend2 + trend3
        residual_trend = self.projection(summed_trend.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend
|
|
|
|
|
|
class Decoder(nn.Module):
    """
    Autoformer decoder: runs a stack of decoder layers and accumulates the
    trend component each layer returns.

    Args:
        layers: Decoder layers called as
            ``layer(x, cross, x_mask=..., cross_mask=...)`` that return
            ``(x, residual_trend)``.
        norm_layer: Optional module applied to the final ``x``.
        projection: Optional module applied after ``norm_layer``.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        """Return ``(x, trend)``.

        ``trend`` is the given initial trend plus every layer's residual
        trend. With ``trend=None``, accumulation starts from the first
        layer's residual trend instead of raising ``TypeError``. With no
        layers and ``trend=None``, the returned trend is ``None``.
        """
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            trend = residual_trend if trend is None else trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x, trend
|
|
|
|
|
|
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Precomputes a ``[1, max_len, d_model]`` buffer with ``sin`` in the even
    columns and ``cos`` in the odd columns. ``forward(x)`` returns the first
    ``x.size(1)`` positions. Odd ``d_model`` is supported.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space. The tensor is
        # registered as a buffer below, so it is never a trainable parameter.
        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        return self.pe[:, : x.size(1)]
|
|
|
|
|
|
class TokenEmbedding(nn.Module):
    """Embed values with a circularly padded, bias-free ``Conv1d`` (kernel 3).

    Maps ``[B, L, c_in]`` to ``[B, L', d_model]``; with the padding chosen
    for torch >= 1.5, ``L' == L``. The weights use Kaiming-normal
    initialisation (fan_in, leaky_relu).
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # Version-dependent padding, presumably compensating for a change in
        # circular-padding behaviour in torch 1.5.
        padding = 1 if torch.__version__ >= "1.5.0" else 2
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=padding,
            padding_mode="circular",
            bias=False,
        )
        # tokenConv is the only Conv1d in this module.
        nn.init.kaiming_normal_(
            self.tokenConv.weight, mode="fan_in", nonlinearity="leaky_relu"
        )

    def forward(self, x):
        channels_first = x.permute(0, 2, 1)
        return self.tokenConv(channels_first).transpose(1, 2)
|
|
|
|
|
|
class FixedEmbedding(nn.Module):
    """Frozen embedding table filled with sinusoidal codes.

    Row ``i`` holds ``sin`` (even columns) and ``cos`` (odd columns) of ``i``
    at geometrically spaced frequencies. The weight has
    ``requires_grad=False`` and outputs are detached. Odd ``d_model`` is
    supported.
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than frequencies.
        w[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach()
|
|
|
|
|
|
class TemporalEmbedding(nn.Module):
    """Sum of embeddings of calendar features.

    ``x`` is indexed as ``x[:, :, c]`` and cast to ``long``, with columns
    ``0`` month, ``1`` day, ``2`` weekday and ``3`` hour. Column ``4``
    (minute) is used only when ``freq == "t"``. ``embed_type == "fixed"``
    uses frozen ``FixedEmbedding`` tables; any other value uses learnable
    ``nn.Embedding``.
    """

    def __init__(self, d_model, embed_type="fixed", freq="h"):
        super(TemporalEmbedding, self).__init__()

        make_table = FixedEmbedding if embed_type == "fixed" else nn.Embedding
        # Minute table (4 entries, presumably 15-minute buckets) only for freq "t".
        if freq == "t":
            self.minute_embed = make_table(4, d_model)
        # Registration order is kept so initialisation and state_dict order are unchanged.
        for attr, table_size in (
            ("hour_embed", 24),
            ("weekday_embed", 7),
            ("day_embed", 32),
            ("month_embed", 13),
        ):
            setattr(self, attr, make_table(table_size, d_model))

    def forward(self, x):
        idx = x.long()
        if hasattr(self, "minute_embed"):
            minute_x = self.minute_embed(idx[:, :, 4])
        else:
            minute_x = 0.0
        hour_x = self.hour_embed(idx[:, :, 3])
        weekday_x = self.weekday_embed(idx[:, :, 2])
        day_x = self.day_embed(idx[:, :, 1])
        month_x = self.month_embed(idx[:, :, 0])
        return hour_x + weekday_x + day_x + month_x + minute_x
|
|
|
|
|
|
class TimeFeatureEmbedding(nn.Module):
    """Bias-free linear projection of continuous time features to ``d_model``.

    The expected number of input features depends on ``freq``: h=4, t=5,
    s=6, m=1, a=1, w=2, d=3, b=3. An unknown ``freq`` raises ``KeyError``.
    ``embed_type`` is accepted for interface compatibility but unused.
    """

    def __init__(self, d_model, embed_type="timeF", freq="h"):
        super(TimeFeatureEmbedding, self).__init__()

        num_features = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3}[freq]
        self.embed = nn.Linear(num_features, d_model, bias=False)

    def forward(self, x):
        projected = self.embed(x)
        return projected
|
|
|
|
|
|
class DataEmbedding(nn.Module):
    """Value + temporal + positional embedding, followed by dropout.

    ``embed_type == "timeF"`` selects ``TimeFeatureEmbedding`` for the
    temporal part; any other value selects ``TemporalEmbedding``.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == "timeF":
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        value_part = self.value_embedding(x)
        summed = value_part + self.temporal_embedding(x_mark) + self.position_embedding(x)
        return self.dropout(summed)
|
|
|
|
|
|
class DataEmbedding_onlypos(nn.Module):
    """Value + positional embedding (no temporal features), followed by dropout.

    ``embed_type`` and ``freq`` are accepted for interface compatibility but
    unused. ``forward`` ignores ``x_mark``.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super(DataEmbedding_onlypos, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        embedded = self.value_embedding(x)
        embedded = embedded + self.position_embedding(x)
        return self.dropout(embedded)
|
|
|
|
|
|
class DataEmbedding_wo_pos(nn.Module):
    """Value + temporal embedding (no positional term), followed by dropout.

    ``position_embedding`` is constructed but not used in ``forward``.
    Presumably it is kept so the module's buffers / state_dict keys match
    the other embedding variants — confirm before removing.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super(DataEmbedding_wo_pos, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == "timeF":
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        embedded = self.value_embedding(x) + self.temporal_embedding(x_mark)
        return self.dropout(embedded)
|
|
|
|
|
|
def get_frequency_modes(seq_len, modes=64, mode_select_method="random"):
    """Choose which rFFT frequency indices a Fourier layer keeps.

    At most ``seq_len // 2`` modes are selected.

    Args:
        seq_len: length of the time-domain sequence.
        modes: requested number of modes (capped at ``seq_len // 2``).
        mode_select_method: ``"random"`` samples indices from
            ``[0, seq_len // 2)`` with NumPy's global RNG; any other value
            takes the lowest ``modes`` indices.

    Returns:
        Sorted list of selected frequency indices.
    """
    n_candidates = seq_len // 2
    modes = min(modes, n_candidates)

    if mode_select_method != "random":
        # lowest frequencies; range() is already ascending
        return list(range(modes))

    candidates = list(range(n_candidates))
    np.random.shuffle(candidates)
    return sorted(candidates[:modes])
|
|
|
|
|
|
# ########## fourier layer #############
|
|
class FourierBlock(nn.Module):
    """1D Fourier block: FFT, per-mode complex linear transform, inverse FFT.

    The input modes to keep are chosen once at construction with
    ``get_frequency_modes(seq_len, modes, mode_select_method)``. With the
    default ``modes=0`` no mode is selected and ``forward`` returns zeros.

    Args:
        n_heads: number of heads; the channel dimensions are split per head.
        in_channels: total input channels (split as ``in_channels // n_heads``).
        out_channels: total output channels (split as ``out_channels // n_heads``).
        seq_len: sequence length used to pick frequency modes.
        modes: number of frequency modes to keep.
        mode_select_method: ``"random"`` or anything else for the lowest modes.
    """

    def __init__(
        self,
        n_heads,
        in_channels,
        out_channels,
        seq_len,
        modes=0,
        mode_select_method="random",
    ):
        super(FourierBlock, self).__init__()
        print("fourier enhanced block used!")
        # get modes on frequency domain
        self.index = get_frequency_modes(
            seq_len, modes=modes, mode_select_method=mode_select_method
        )
        print("modes={}, index={}".format(modes, self.index))

        self.scale = 1 / (in_channels * out_channels)
        # one complex (in/head x out/head) matrix per head and per kept mode
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(
                n_heads,
                in_channels // n_heads,
                out_channels // n_heads,
                len(self.index),
                dtype=torch.cfloat,
            )
        )

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        """Per-head channel mixing: (B, H, I) x (H, I, O) -> (B, H, O)."""
        return torch.einsum("bhi,hio->bho", input, weights)

    def forward(self, q, k, v, mask):
        """Transform ``q`` in the frequency domain; ``k``, ``v`` and ``mask`` are unused.

        Args:
            q: tensor of shape [B, L, H, E].

        Returns:
            ``(x, None)`` where ``x`` has shape [B, H, E, L].
        """
        B, L, H, E = q.shape
        x = q.permute(0, 2, 3, 1)  # [B, H, E, L]
        # Compute Fourier coefficients
        x_ft = torch.fft.rfft(x, dim=-1)

        # Perform Fourier neural operations.
        # NOTE: input mode ``i`` is written to output slot ``wi`` (its position
        # in ``self.index``), as in the reference FEDformer implementation.
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
        for wi, i in enumerate(self.index):
            out_ft[:, :, :, wi] = self.compl_mul1d(
                x_ft[:, :, :, i], self.weights1[:, :, :, wi]
            )
        # Return to time domain
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return (x, None)
|
|
|
|
|
|
# ########## Fourier Cross Former ####################
|
|
class FourierCrossAttention(nn.Module):
    """1D Fourier cross attention: FFT, attention over modes, linear transform, inverse FFT.

    Query and key modes are chosen once at construction with
    ``get_frequency_modes``. Attention scores are computed between the kept
    query and key modes, passed through ``activation`` ("tanh" or
    "softmax"), and applied to the key modes. ``v`` is not used: the key
    spectrum doubles as the value spectrum, as in the reference FEDformer
    implementation.
    """

    def __init__(
        self,
        n_heads,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes=64,
        mode_select_method="random",
        activation="tanh",
        policy=0,
    ):
        super(FourierCrossAttention, self).__init__()
        print(" fourier enhanced cross attention used!")
        self.activation = activation
        self.in_channels = in_channels
        self.out_channels = out_channels
        # get modes for queries and keys (& values) on frequency domain
        self.index_q = get_frequency_modes(
            seq_len_q, modes=modes, mode_select_method=mode_select_method
        )
        self.index_kv = get_frequency_modes(
            seq_len_kv, modes=modes, mode_select_method=mode_select_method
        )

        self.scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(
                n_heads,
                in_channels // n_heads,
                out_channels // n_heads,
                len(self.index_q),
                dtype=torch.cfloat,
            )
        )

    # Complex multiplication (kept for API compatibility; not used by forward)
    def compl_mul1d(self, input, weights):
        return torch.einsum("bhi,hoi->boi", input, weights)

    def forward(self, q, k, v, mask):
        """Cross-attend ``q`` to ``k`` in the frequency domain.

        Args:
            q: tensor of shape [B, L, H, E].
            k: tensor of shape [B, S, H, E].
            v, mask: unused.

        Returns:
            ``(out, None)`` where ``out`` has shape [B, H, E, L].

        Raises:
            Exception: if ``self.activation`` is neither "tanh" nor "softmax".
        """
        B, L, H, E = q.shape
        xq = q.permute(0, 2, 3, 1)  # size = [B, H, E, L]
        xk = k.permute(0, 2, 3, 1)

        # Compute Fourier coefficients and gather the selected modes
        xq_ft = torch.fft.rfft(xq, dim=-1)
        xq_ft_ = torch.zeros(
            B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat
        )
        for i, j in enumerate(self.index_q):
            xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]

        xk_ft = torch.fft.rfft(xk, dim=-1)
        xk_ft_ = torch.zeros(
            B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat
        )
        for i, j in enumerate(self.index_kv):
            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]

        # perform attention mechanism on frequency domain
        xqk_ft = torch.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_)
        if self.activation == "tanh":
            xqk_ft = xqk_ft.tanh()
        elif self.activation == "softmax":
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise Exception(
                "{} activation function is not implemented".format(self.activation)
            )
        xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_)
        xqkvw = torch.einsum("bhex,heox->bhox", xqkv_ft, self.weights1)

        # scatter the transformed modes back to their original frequency slots
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        for i, j in enumerate(self.index_q):
            out_ft[:, :, :, j] = xqkvw[:, :, :, i]
        # Return to time domain
        out = torch.fft.irfft(
            out_ft / self.in_channels / self.out_channels, n=xq.size(-1)
        )
        return (out, None)
|
|
|
|
|
|
class MultiWaveletTransform(nn.Module):
    """
    1D multiwavelet block.

    Values are projected from ``ich`` to ``c * k`` channels, passed through
    ``nCZ`` stacked ``MWT_CZ1d`` layers (ReLU between them, not after the
    last), and projected back to ``ich`` channels. ``keys`` is aligned to the
    query length but not otherwise used; ``attn_mask`` and
    ``attention_dropout`` are unused.
    """

    def __init__(
        self,
        ich=1,
        k=8,
        alpha=16,
        c=128,
        nCZ=1,
        L=0,
        base="legendre",
        attention_dropout=0.1,
    ):
        super(MultiWaveletTransform, self).__init__()
        self.k, self.c, self.L, self.nCZ = k, c, L, nCZ
        self.Lk0 = nn.Linear(ich, c * k)
        self.Lk1 = nn.Linear(c * k, ich)
        self.ich = ich
        self.MWT_CZ = nn.ModuleList(
            [MWT_CZ1d(k, alpha, L, c, base) for _ in range(nCZ)]
        )

    def forward(self, queries, keys, values, attn_mask):
        """Return ``(V, None)`` with ``V`` of shape (B, L, -1, D).

        ``queries`` is (B, L, H, E) and ``values`` is (B, S, H, D). Values
        are zero-padded (when L > S) or truncated to length L.
        """
        B, L, H, E = queries.shape
        _, S, _, D = values.shape

        shortfall = L - S
        if shortfall > 0:
            pad = torch.zeros_like(queries[:, :shortfall, :]).float()
            values = torch.cat([values, pad], dim=1)
            keys = torch.cat([keys, pad], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        hidden = self.Lk0(values.view(B, L, -1)).view(B, L, self.c, -1)
        last = self.nCZ - 1
        for depth in range(self.nCZ):
            hidden = self.MWT_CZ[depth](hidden)
            if depth < last:
                hidden = F.relu(hidden)

        hidden = self.Lk1(hidden.view(B, L, -1)).view(B, L, -1, D)
        return (hidden.contiguous(), None)
|
|
|
|
|
|
class MultiWaveletCross(nn.Module):
    """
    1D Multiwavelet Cross Attention layer.

    ``q``, ``k``, ``v`` are projected from ``ich`` to ``c * k`` channels and
    decomposed with the multiwavelet transform. At each level:

    - the detail and smooth coefficients are attended with ``attn1`` and
      ``attn2`` and summed;
    - the detail coefficients are also attended with ``attn3``.

    The coarsest coefficients are attended with ``attn4``. The result is then
    reconstructed and projected back to ``ich`` channels.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes,
        c=64,
        k=8,
        ich=512,
        L=0,
        base="legendre",
        mode_select_method="random",
        initializer=None,
        activation="tanh",
        **kwargs,
    ):
        super(MultiWaveletCross, self).__init__()

        self.c = c
        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        # flush numerical noise in the reconstruction filters to exact zeros
        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        def make_attn():
            return FourierCrossAttentionW(
                in_channels=in_channels,
                out_channels=out_channels,
                seq_len_q=seq_len_q,
                seq_len_kv=seq_len_kv,
                modes=modes,
                activation=activation,
                mode_select_method=mode_select_method,
            )

        self.attn1 = make_attn()
        self.attn2 = make_attn()
        self.attn3 = make_attn()
        self.attn4 = make_attn()
        self.T0 = nn.Linear(k, k)
        # decomposition filters: (2k, k) each
        self.register_buffer("ec_s", torch.Tensor(np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer("ec_d", torch.Tensor(np.concatenate((G0.T, G1.T), axis=0)))

        # reconstruction filters for even/odd output positions: (2k, k) each
        self.register_buffer("rc_e", torch.Tensor(np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer("rc_o", torch.Tensor(np.concatenate((H1r, G1r), axis=0)))

        self.Lk = nn.Linear(ich, c * k)
        self.Lq = nn.Linear(ich, c * k)
        self.Lv = nn.Linear(ich, c * k)
        self.out = nn.Linear(c * k, ich)
        self.modes1 = modes

    def forward(self, q, k, v, mask=None):
        """Cross-attend ``q`` to ``k``/``v``; returns ``(out, None)``, out is (B, N, ich).

        ``q`` is (B, N, H, E) and ``k``/``v`` are (B, S, H, E), with
        ``H * E == ich``. Keys/values are zero-padded or truncated to length N.
        """
        B, N, H, E = q.shape  # (B, N, H, E)
        _, S, _, _ = k.shape  # (B, S, H, E)

        q = q.view(q.shape[0], q.shape[1], -1)
        k = k.view(k.shape[0], k.shape[1], -1)
        v = v.view(v.shape[0], v.shape[1], -1)
        q = self.Lq(q)
        q = q.view(q.shape[0], q.shape[1], self.c, self.k)
        k = self.Lk(k)
        k = k.view(k.shape[0], k.shape[1], self.c, self.k)
        v = self.Lv(v)
        v = v.view(v.shape[0], v.shape[1], self.c, self.k)

        if N > S:
            zeros = torch.zeros_like(q[:, : (N - S), :]).float()
            v = torch.cat([v, zeros], dim=1)
            k = torch.cat([k, zeros], dim=1)
        else:
            v = v[:, :N, :, :]
            k = k[:, :N, :, :]

        # pad to the next power of two by repeating leading steps
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        extra_q = q[:, 0 : nl - N, :, :]
        extra_k = k[:, 0 : nl - N, :, :]
        extra_v = v[:, 0 : nl - N, :, :]
        q = torch.cat([q, extra_q], 1)
        k = torch.cat([k, extra_k], 1)
        v = torch.cat([v, extra_v], 1)

        # decompose (plain lists: the previous ``torch.jit.annotate`` calls
        # referenced ``Tuple``/``Tensor``, which are not imported here)
        levels = ns - self.L
        Ud_q, Us_q, q = self._decompose(q, levels)
        Ud_k, Us_k, k = self._decompose(k, levels)
        Ud_v, Us_v, v = self._decompose(v, levels)

        Ud: List[torch.Tensor] = []
        Us: List[torch.Tensor] = []
        for i in range(levels):
            dq, dk, dv = Ud_q[i], Ud_k[i], Ud_v[i]
            sq, sk, sv = Us_q[i], Us_k[i], Us_v[i]
            Ud.append(
                self.attn1(dq[0], dk[0], dv[0], mask)[0]
                + self.attn2(dq[1], dk[1], dv[1], mask)[0]
            )
            Us.append(self.attn3(sq, sk, sv, mask)[0])
        v = self.attn4(q, k, v, mask)[0]

        # reconstruct
        for i in range(levels - 1, -1, -1):
            v = v + Us[i]
            v = torch.cat((v, Ud[i]), -1)
            v = self.evenOdd(v)
        v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1))
        return (v.contiguous(), None)

    def _decompose(self, x, levels):
        """Run ``levels`` wavelet steps; return ([(d, s)...], [d...], coarsest s)."""
        pairs = []
        details = []
        for _ in range(levels):
            d, x = self.wavelet_transform(x)
            pairs.append((d, x))
            details.append(d)
        return pairs, details, x

    def wavelet_transform(self, x):
        """Split (B, N, c, k) into detail and smooth coefficients of length N/2."""
        xa = torch.cat(
            [
                x[:, ::2, :, :],
                x[:, 1::2, :, :],
            ],
            -1,
        )
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """Inverse step: (B, N, c, 2k) -> (B, 2N, c, k), interleaving even/odd outputs."""
        B, N, c, ich = x.shape  # (B, N, c, k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        # match the computation dtype instead of silently downcasting to float32
        x = torch.zeros(B, N * 2, c, self.k, device=x.device, dtype=x_e.dtype)
        x[..., ::2, :, :] = x_e
        x[..., 1::2, :, :] = x_o
        return x
|
|
|
|
|
|
class FourierCrossAttentionW(nn.Module):
    """Fourier cross attention used inside the multiwavelet cross layer.

    The lowest ``min(length // 2, modes)`` frequencies of the queries and keys
    are kept. Attention scores between query and key modes go through
    ``activation`` ("tanh" or "softmax") and are applied to the key modes.
    The key spectrum doubles as the value spectrum, as in the reference
    implementation. There are no learned weights.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes=16,
        activation="tanh",
        mode_select_method="random",
    ):
        super(FourierCrossAttentionW, self).__init__()
        print("cross fourier correlation used!")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes
        self.activation = activation

    def forward(self, q, k, v, mask):
        """Return ``(out, None)``; ``out`` has the same [B, L, E, H] layout as ``q``.

        Raises:
            Exception: if ``self.activation`` is neither "tanh" nor "softmax".
        """
        B, L, E, H = q.shape

        xq = q.permute(0, 3, 2, 1)  # size = [B, H, E, L]
        xk = k.permute(0, 3, 2, 1)
        xv = v.permute(0, 3, 2, 1)
        # mode lists depend on the runtime lengths; also stored on self for
        # compatibility with code that inspects them after a call
        self.index_q = list(range(0, min(int(L // 2), self.modes1)))
        self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1)))

        # Compute Fourier coefficients
        xq_ft = torch.fft.rfft(xq, dim=-1)
        xq_ft_ = torch.zeros(
            B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat
        )
        for i, j in enumerate(self.index_q):
            xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]

        xk_ft = torch.fft.rfft(xk, dim=-1)
        xk_ft_ = torch.zeros(
            B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat
        )
        for i, j in enumerate(self.index_k_v):
            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]

        xqk_ft = torch.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_)
        if self.activation == "tanh":
            xqk_ft = xqk_ft.tanh()
        elif self.activation == "softmax":
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise Exception(
                "{} activation function is not implemented".format(self.activation)
            )
        xqkvw = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_)

        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        for i, j in enumerate(self.index_q):
            out_ft[:, :, :, j] = xqkvw[:, :, :, i]

        out = torch.fft.irfft(
            out_ft / self.in_channels / self.out_channels, n=xq.size(-1)
        ).permute(0, 3, 2, 1)
        # size = [B, L, E, H]
        return (out, None)
|
|
|
|
|
|
class sparseKernelFT1d(nn.Module):
    """Learned complex linear operator on the lowest ``alpha`` Fourier modes.

    The input (B, N, c, k) is flattened to ``c * k`` channels and
    transformed with ``torch.fft.rfft``. Each of the first
    ``min(alpha, N // 2 + 1)`` modes is mixed by a ``(c*k, c*k)`` complex
    matrix; all other modes are zeroed. The result is transformed back and
    reshaped to (B, N, c, k). ``nl``, ``initializer`` and ``kwargs`` are
    unused.
    """

    def __init__(self, k, alpha, c=1, nl=1, initializer=None, **kwargs):
        super(sparseKernelFT1d, self).__init__()

        self.modes1 = alpha
        width = c * k
        self.scale = 1 / (width * width)
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(width, width, self.modes1, dtype=torch.cfloat)
        )
        self.k = k

    def compl_mul1d(self, x, weights):
        """(B, I, M) x (I, O, M) -> (B, O, M), mixing channels per mode."""
        return torch.einsum("bix,iox->box", x, weights)

    def forward(self, x):
        B, N, c, k = x.shape  # (B, N, c, k)
        channels_first = x.view(B, N, -1).permute(0, 2, 1)  # (B, c*k, N)
        spectrum = torch.fft.rfft(channels_first)

        n_freq = N // 2 + 1
        kept = min(self.modes1, n_freq)
        out_ft = torch.zeros(B, c * k, n_freq, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :kept] = self.compl_mul1d(
            spectrum[:, :, :kept], self.weights1[:, :, :kept]
        )

        signal = torch.fft.irfft(out_ft, n=N)
        return signal.permute(0, 2, 1).view(B, N, c, k)
|
|
|
|
|
|
# ##
|
|
class MWT_CZ1d(nn.Module):
    """Multiwavelet transform layer: decompose, apply sparse Fourier kernels, reconstruct.

    The input (B, N, c, k) is padded to the next power of two by repeating
    its leading steps. It is then decomposed over ``floor(log2(N)) - L``
    levels. At each level the detail ``d`` and smooth ``s`` coefficients
    give ``A(d) + B(s)`` and ``C(d)``. The coarsest smooth part goes through
    ``T0``. The signal is then reconstructed and truncated back to length N.
    """

    def __init__(
        self, k=3, alpha=64, L=0, c=1, base="legendre", initializer=None, **kwargs
    ):
        super(MWT_CZ1d, self).__init__()

        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        # flush numerical noise in the reconstruction filters to exact zeros
        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        self.A = sparseKernelFT1d(k, alpha, c)
        self.B = sparseKernelFT1d(k, alpha, c)
        self.C = sparseKernelFT1d(k, alpha, c)

        self.T0 = nn.Linear(k, k)

        # decomposition filters: (2k, k) each
        self.register_buffer("ec_s", torch.Tensor(np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer("ec_d", torch.Tensor(np.concatenate((G0.T, G1.T), axis=0)))

        # reconstruction filters for even/odd output positions: (2k, k) each
        self.register_buffer("rc_e", torch.Tensor(np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer("rc_o", torch.Tensor(np.concatenate((H1r, G1r), axis=0)))

    def forward(self, x):
        B, N, c, k = x.shape  # (B, N, c, k)
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        extra_x = x[:, 0 : nl - N, :, :]
        x = torch.cat([x, extra_x], 1)

        # plain lists: the previous ``torch.jit.annotate(List[Tensor], [])``
        # referenced ``Tensor``, which is not imported in this module
        Ud: List[torch.Tensor] = []
        Us: List[torch.Tensor] = []
        # decompose
        for _ in range(ns - self.L):
            d, x = self.wavelet_transform(x)
            Ud.append(self.A(d) + self.B(x))
            Us.append(self.C(d))
        x = self.T0(x)  # coarsest scale transform

        # reconstruct
        for i in range(ns - 1 - self.L, -1, -1):
            x = x + Us[i]
            x = torch.cat((x, Ud[i]), -1)
            x = self.evenOdd(x)
        x = x[:, :N, :, :]

        return x

    def wavelet_transform(self, x):
        """Split (B, N, c, k) into detail and smooth coefficients of length N/2."""
        xa = torch.cat(
            [
                x[:, ::2, :, :],
                x[:, 1::2, :, :],
            ],
            -1,
        )
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """Inverse step: (B, N, c, 2k) -> (B, 2N, c, k), interleaving even/odd outputs."""
        B, N, c, ich = x.shape  # (B, N, c, k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        # match the computation dtype instead of silently downcasting to float32
        x = torch.zeros(B, N * 2, c, self.k, device=x.device, dtype=x_e.dtype)
        x[..., ::2, :, :] = x_e
        x[..., 1::2, :, :] = x_o
        return x
|
|
|
|
|
|
class FullAttention(nn.Module):
    """Scaled dot-product attention over [B, L, H, E] queries and [B, S, H, *] keys/values.

    ``scale`` defaults to ``1 / sqrt(E)`` when falsy. With ``mask_flag``
    set, positions where ``attn_mask.mask`` is True are filled with -inf
    before the softmax; if no mask is given, ``TriangularCausalMask`` is
    built. ``factor`` is accepted for API compatibility and unused.
    """

    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask):
        """Return ``(V, A)`` if ``output_attention`` else ``(V, None)``; V is [B, L, H, D]."""
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1.0 / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            mask = (
                attn_mask
                if attn_mask is not None
                else TriangularCausalMask(B, L, device=queries.device)
            )
            scores.masked_fill_(mask.mask, -np.inf)

        weights = self.dropout(torch.softmax(scale * scores, dim=-1))
        out = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()
        return (out, weights if self.output_attention else None)
|
|
|
|
|
|
class FEDformerModel(nn.Module):
|
|
@validated()
def __init__(
    self,
    freq: str,
    prediction_length: int,
    nhead: int,
    num_encoder_layers: int,
    num_decoder_layers: int,
    dim_feedforward: int = 16,  # dimension of fcn
    version: str = "Fourier",  # Fourier, Wavelets
    features: str = "M",  # options:[M, S, MS]; M:multivariate predict multivariate, 'S':univariate predict univariate, MS:multivariate predict univariate'
    modes: int = 64,
    mode_select: str = "random",
    base: str = "legendre",
    cross_activation: str = "tanh",
    L: int = 3,
    # forecasting task
    context_length: Optional[int] = None,  # seq_len : input sequence length
    label_length: Optional[int] = 48,  # start token length
    # model argument
    input_size: int = 1,  # encoder input size
    # dec_in: int = 7, #decoder input size
    c_out: int = 7,  # output size
    moving_avg: Optional[List[int]] = None,
    factor: int = 1,
    scaling: bool = True,
    activation: str = "gelu",
    dropout: float = 0.05,
    embed: str = "timeF",  # options:[timeF, fixed, learned]
    output_attention: bool = True,  # whether to output attention in encoder
    num_feat_dynamic_real: int = 0,
    num_feat_static_cat: int = 0,
    num_feat_static_real: int = 0,
    cardinality: Optional[List[int]] = None,
    embedding_dimension: Optional[List[int]] = None,
    distr_output: DistributionOutput = StudentTOutput(),
    lags_seq: Optional[List[int]] = None,
    num_parallel_samples: int = 100,
) -> None:
    """Build the FEDformer encoder/decoder and its GluonTS input/output plumbing.

    The model width ``d_model`` is derived from the inputs:
    ``input_size * len(lags_seq) + _number_of_features``.

    Args:
        freq: frequency string; used to derive default lags.
        prediction_length: forecast horizon.
        context_length: encoder window; defaults to ``prediction_length``
            when ``None``.
        version: ``"Wavelets"`` selects the multiwavelet attention blocks;
            anything else selects the Fourier blocks.
        moving_avg: kernel size(s) for the series decomposition; a list
            selects ``series_decomp_multi``, anything else ``series_decomp``.
            NOTE(review): the ``None`` default is passed straight to
            ``series_decomp`` — confirm that is supported.
        cardinality / embedding_dimension: passed to ``FeatureEmbedder``.
            NOTE(review): ``cardinality=None`` presumably fails inside
            ``FeatureEmbedder`` — callers appear to need to supply it.
    """
    super().__init__()

    # Without an explicit context window, look back as far as we forecast
    # (the usual GluonTS default); ``None`` used to crash on the
    # ``context_length + max(self.lags_seq)`` line below.
    if context_length is None:
        context_length = prediction_length

    self.input_size = input_size

    self.target_shape = distr_output.event_shape
    self.num_feat_dynamic_real = num_feat_dynamic_real
    self.num_feat_static_cat = num_feat_static_cat
    self.num_feat_static_real = num_feat_static_real
    self.embedding_dimension = (
        embedding_dimension
        if embedding_dimension is not None or cardinality is None
        else [min(50, (cat + 1) // 2) for cat in cardinality]
    )
    self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
    self.num_parallel_samples = num_parallel_samples
    self.history_length = context_length + max(self.lags_seq)
    self.embedder = FeatureEmbedder(
        cardinalities=cardinality,
        embedding_dims=self.embedding_dimension,
    )
    if scaling:
        self.scaler = MeanScaler(dim=1, keepdim=True)
    else:
        self.scaler = NOPScaler(dim=1, keepdim=True)

    # total feature size: lagged targets plus covariates
    d_model = self.input_size * len(self.lags_seq) + self._number_of_features

    self.context_length = context_length
    self.prediction_length = prediction_length
    self.distr_output = distr_output
    self.param_proj = distr_output.get_args_proj(d_model)
    self.moving_avg = moving_avg
    self.version = version
    self.mode_select = mode_select
    self.modes = modes
    self.label_length = label_length
    self.output_attention = output_attention
    self.dim_feedforward = dim_feedforward
    # Decomp
    kernel_size = self.moving_avg
    if isinstance(kernel_size, list):
        self.decomp = series_decomp_multi(kernel_size)
    else:
        self.decomp = series_decomp(kernel_size)

    # Embedding: the series-wise connection inherently contains the
    # sequential information, so no separate embedding layer is used here.

    # decoder sequence length: half the context plus the forecast horizon
    dec_seq_len = self.context_length // 2 + self.prediction_length

    if self.version == "Wavelets":
        encoder_self_att = MultiWaveletTransform(ich=d_model, L=L, base=base)
        decoder_self_att = MultiWaveletTransform(ich=d_model, L=L, base=base)
        decoder_cross_att = MultiWaveletCross(
            in_channels=d_model,
            out_channels=d_model,
            seq_len_q=dec_seq_len,
            seq_len_kv=self.context_length,
            modes=modes,
            ich=d_model,
            base=base,
            activation=cross_activation,
        )
    else:
        encoder_self_att = FourierBlock(
            n_heads=nhead,
            in_channels=d_model,
            out_channels=d_model,
            seq_len=self.context_length,
            modes=modes,
            mode_select_method=mode_select,
        )
        decoder_self_att = FourierBlock(
            n_heads=nhead,
            in_channels=d_model,
            out_channels=d_model,
            seq_len=dec_seq_len,
            modes=modes,
            mode_select_method=mode_select,
        )
        decoder_cross_att = FourierCrossAttention(
            n_heads=nhead,
            in_channels=d_model,
            out_channels=d_model,
            seq_len_q=dec_seq_len,
            seq_len_kv=self.context_length,
            modes=modes,
            mode_select_method=mode_select,
        )

    # Encoder
    self.encoder = Encoder(
        [
            EncoderLayer(
                AutoCorrelationLayer(encoder_self_att, d_model, n_heads=nhead),
                d_model,
                d_ff=self.dim_feedforward,
                moving_avg=moving_avg,
                dropout=dropout,
                activation=activation,
            )
            for _ in range(num_encoder_layers)
        ],
        norm_layer=my_Layernorm(d_model),
    )

    # Decoder
    self.decoder = Decoder(
        [
            DecoderLayer(
                AutoCorrelationLayer(decoder_self_att, d_model, n_heads=nhead),
                AutoCorrelationLayer(decoder_cross_att, d_model, n_heads=nhead),
                d_model,
                c_out=c_out,
                d_ff=self.dim_feedforward,
                moving_avg=self.moving_avg,
                dropout=dropout,
                activation=activation,
            )
            for _ in range(num_decoder_layers)
        ],
        norm_layer=my_Layernorm(d_model),
        projection=nn.Linear(d_model, c_out, bias=True),
    )
|
|
|
|
@property
|
|
def _number_of_features(self) -> int:
|
|
return (
|
|
sum(self.embedding_dimension)
|
|
+ self.num_feat_dynamic_real
|
|
+ self.num_feat_static_real
|
|
+ 1 # the log(scale)
|
|
)
|
|
|
|
@property
|
|
def _past_length(self) -> int:
|
|
return self.context_length + max(self.lags_seq)
|
|
|
|
def get_lagged_subsequences(
    self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
) -> torch.Tensor:
    """
    Returns lagged subsequences of a given sequence.

    Parameters
    ----------
    sequence : Tensor
        the sequence from which lagged subsequences should be extracted.
        Shape: (N, T, C).
    subsequences_length : int
        length of the subsequences to be extracted.
    shift: int
        shift the lags by this amount back.

    Returns
    --------
    lagged : Tensor
        a tensor of shape (N, S, C, I), where S = subsequences_length and
        I = len(indices), containing lagged subsequences. Specifically,
        lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
    """
    total_length = sequence.shape[1]
    lag_indices = [lag - shift for lag in self.lags_seq]
    deepest = max(lag_indices)

    assert deepest + subsequences_length <= total_length, (
        f"lags cannot go further than history length, found lag {deepest} "
        f"while history length is only {total_length}"
    )

    windows = []
    for lag in lag_indices:
        start = -lag - subsequences_length
        # a non-positive lag reaches the end of the sequence
        stop = None if lag <= 0 else -lag
        windows.append(sequence[:, start:stop, ...])
    return torch.stack(windows, dim=-1)
|
|
|
|
def _check_shapes(
|
|
self,
|
|
prior_input: torch.Tensor,
|
|
inputs: torch.Tensor,
|
|
features: Optional[torch.Tensor],
|
|
) -> None:
|
|
assert len(prior_input.shape) == len(inputs.shape)
|
|
assert (
|
|
len(prior_input.shape) == 2 and self.input_size == 1
|
|
) or prior_input.shape[2] == self.input_size
|
|
assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
|
|
-1
|
|
] == self.input_size
|
|
assert (
|
|
features is None or features.shape[2] == self._number_of_features
|
|
), f"{features.shape[2]}, expected {self._number_of_features}"
|
|
|
|
def create_network_inputs(
    self,
    feat_static_cat: torch.Tensor,
    feat_static_real: torch.Tensor,
    past_time_feat: torch.Tensor,
    past_target: torch.Tensor,
    past_observed_values: torch.Tensor,
    future_time_feat: Optional[torch.Tensor] = None,
    future_target: Optional[torch.Tensor] = None,
):
    """Assemble the per-time-step network input.

    Targets are divided by the scale the scaler computes on the last
    ``context_length`` steps. When ``future_target`` is given, the past and
    future series are concatenated so the output covers
    ``context_length + prediction_length`` steps; otherwise it covers
    ``context_length`` steps.

    Returns:
        ``(transformer_inputs, scale, static_feat)``. ``transformer_inputs``
        concatenates, along the last dim, the flattened lagged targets, the
        static features broadcast over time, and the time features.
    """
    start = self._past_length - self.context_length
    past_feat = past_time_feat[:, start:, ...]
    has_future = future_target is not None

    # time feature
    if has_future:
        time_feat = torch.cat((past_feat, future_time_feat), dim=1)
    else:
        time_feat = past_feat

    # target scaling from the context window only
    context = past_target[:, -self.context_length :]
    observed_context = past_observed_values[:, -self.context_length :]
    _, scale = self.scaler(context, observed_context)

    if has_future:
        inputs = torch.cat((past_target, future_target), dim=1) / scale
        inputs_length = self._past_length + self.prediction_length
        subsequences_length = self.context_length + self.prediction_length
    else:
        inputs = past_target / scale
        inputs_length = self._past_length
        subsequences_length = self.context_length
    assert inputs.shape[1] == inputs_length

    # static features (embeddings, real-valued, log-scale) broadcast over time
    embedded_cat = self.embedder(feat_static_cat)
    static_feat = torch.cat((embedded_cat, feat_static_real, scale.log()), dim=1)
    expanded_static_feat = static_feat.unsqueeze(1).expand(
        -1, time_feat.shape[1], -1
    )
    features = torch.cat((expanded_static_feat, time_feat), dim=-1)

    lagged = self.get_lagged_subsequences(
        sequence=inputs,
        subsequences_length=subsequences_length,
    )
    n_batch, n_steps = lagged.shape[0], lagged.shape[1]
    flat_lags = lagged.reshape(n_batch, n_steps, -1)

    transformer_inputs = torch.cat((flat_lags, features), dim=-1)
    return transformer_inputs, scale, static_feat
|
|
|
|
def output_params(self, transformer_inputs):
    """Run encoder/decoder and project to distribution parameters.

    The first ``context_length`` steps of ``transformer_inputs`` feed the
    encoder; the remaining steps feed the decoder together with the encoder
    output. (A debug ``print`` of the encoder input shape on every call has
    been removed.)
    """
    enc_input = transformer_inputs[:, : self.context_length, ...]
    dec_input = transformer_inputs[:, self.context_length :, ...]

    enc_out, _ = self.encoder(enc_input)
    dec_output = self.decoder(dec_input, enc_out)

    return self.param_proj(dec_output)
|
|
|
|
@torch.jit.ignore
def output_distribution(
    self, params, scale=None, trailing_n=None
) -> torch.distributions.Distribution:
    """Build the output distribution from ``params``.

    When ``trailing_n`` is given, each parameter tensor is first restricted
    to its last ``trailing_n`` steps along dim 1.
    """
    if trailing_n is None:
        return self.distr_output.distribution(params, scale=scale)
    trimmed = [p[:, -trailing_n:] for p in params]
    return self.distr_output.distribution(trimmed, scale=scale)
|
|
|
|
# for prediction
|
|
def forward(
    self,
    feat_static_cat: torch.Tensor,
    feat_static_real: torch.Tensor,
    past_time_feat: torch.Tensor,
    past_target: torch.Tensor,
    past_observed_values: torch.Tensor,
    future_time_feat: torch.Tensor,
    num_parallel_samples: Optional[int] = None,
) -> torch.Tensor:
    """Draw sample paths over the prediction window by autoregressive decoding.

    The context is encoded once. For each of the ``self.prediction_length``
    future steps, the decoder is re-run on the whole growing prefix of lagged
    (scaled) targets plus features. One value per sample path is then drawn
    from the output distribution at the newest position and appended to the
    sequence.

    Args:
        feat_static_cat, feat_static_real, past_time_feat, past_target,
        past_observed_values: passed unchanged to ``create_network_inputs``
            to obtain the encoder input, the scale and the static features.
        future_time_feat: time features for the prediction window; at step
            ``k`` only its first ``k + 1`` entries along dimension 1 are used.
        num_parallel_samples: number of sample paths per series; defaults to
            ``self.num_parallel_samples``. Previously this argument was
            ignored; it now controls the tiling and the output shape.

    Returns:
        Tensor reshaped to
        ``(-1, num_parallel_samples, prediction_length) + self.target_shape``.
        The distribution receives ``scale``, so the samples are presumably in
        the original (unscaled) target units.
    """
    if num_parallel_samples is None:
        num_parallel_samples = self.num_parallel_samples

    encoder_inputs, scale, static_feat = self.create_network_inputs(
        feat_static_cat,
        feat_static_real,
        past_time_feat,
        past_target,
        past_observed_values,
    )
    enc_out, _ = self.encoder(encoder_inputs)

    # Tile each per-series tensor so every sample path gets its own row;
    # repeat_interleave keeps the rows of one series contiguous.
    repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)
    repeated_past_target = (
        past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
        / repeated_scale
    )

    expanded_static_feat = static_feat.unsqueeze(1).expand(
        -1, future_time_feat.shape[1], -1
    )
    features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
    repeated_features = features.repeat_interleave(
        repeats=num_parallel_samples, dim=0
    )
    repeated_enc_out = enc_out.repeat_interleave(
        repeats=num_parallel_samples, dim=0
    )

    future_samples = []

    # Greedy autoregressive decoding: one new step per iteration.
    for k in range(self.prediction_length):
        lagged_sequence = self.get_lagged_subsequences(
            sequence=repeated_past_target,
            subsequences_length=1 + k,
            shift=1,
        )
        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        decoder_input = torch.cat(
            (reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1
        )
        output = self.decoder(decoder_input, repeated_enc_out)

        # Only the newest decoder position is used to predict the next value.
        params = self.param_proj(output[:, -1:])
        distr = self.output_distribution(params, scale=repeated_scale)
        next_sample = distr.sample()

        # The lag sequence lives in scaled space, so divide before appending.
        repeated_past_target = torch.cat(
            (repeated_past_target, next_sample / repeated_scale), dim=1
        )
        future_samples.append(next_sample)

    concat_future_samples = torch.cat(future_samples, dim=1)
    return concat_future_samples.reshape(
        (-1, num_parallel_samples, self.prediction_length) + self.target_shape,
    )