Files
2022-11-06 01:07:30 -05:00

2066 lines
71 KiB
Python

import math
import random
import time
from functools import partial, wraps
from math import sqrt
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
from scipy.special import eval_legendre
from sympy import Poly, legendre, Symbol, chebyshevt
from torch.nn.functional import interpolate

# from matplotlib import pyplot as plt
def legendreDer(k, x):
    """Evaluate the derivative of the Legendre polynomial P_k at ``x``.

    Uses the identity P_k'(x) = sum of (2i + 1) * P_i(x) over
    i = k-1, k-3, ... (down to 0 or 1).

    Args:
        k: polynomial degree (non-negative integer).
        x: scalar or numpy array of evaluation points.

    Returns:
        The derivative value(s); the integer 0 when ``k == 0``.
    """
    total = 0
    for degree in range(k - 1, -1, -2):
        total += (2 * degree + 1) * eval_legendre(degree, x)
    return total
def phi_(phi_c, x, lb=0, ub=1):
    """Evaluate a polynomial on ``[lb, ub]`` and return 0 outside it.

    Args:
        phi_c: polynomial coefficients in increasing-degree order, as accepted
            by ``numpy.polynomial.polynomial.Polynomial``.
        x: evaluation point(s), scalar or numpy array.
        lb: inclusive lower bound of the support.
        ub: inclusive upper bound of the support.

    Returns:
        Polynomial values where ``lb <= x <= ub`` and 0 elsewhere.
    """
    poly = np.polynomial.polynomial.Polynomial(phi_c)
    inside = 1.0 - np.logical_or(x < lb, x > ub) * 1.0
    return poly(x) * inside
def get_phi_psi(k, base):
    """Build the order-``k`` scaling functions and multiwavelets on [0, 1].

    Args:
        k: number of basis functions (polynomial degrees 0 .. k-1).
        base: ``'legendre'`` or ``'chebyshev'``.

    Returns:
        ``(phi, psi1, psi2)``: three lists of ``k`` callables.
        ``phi[i]`` is the i-th scaling function. ``psi1[i]`` and ``psi2[i]``
        are the pieces of the i-th wavelet on [0, 0.5] and [0.5, 1].
        For ``'legendre'`` all three lists hold ``np.poly1d`` objects. The
        psi pieces are NOT zeroed outside their half-interval, so callers
        must select the piece themselves.
        For ``'chebyshev'`` they are ``partial(phi_, ...)`` objects that
        return 0 outside their interval.

    NOTE(review): any other ``base`` leaves ``phi``/``psi1``/``psi2`` unbound,
    so the final ``return`` raises ``UnboundLocalError``. ``get_filter``
    validates ``base`` before calling this function. The chebyshev branch
    requires ``partial`` (``functools.partial``) in module scope.
    """
    x = Symbol('x')
    phi_coeff = np.zeros((k,k))
    phi_2x_coeff = np.zeros((k,k))
    if base == 'legendre':
        for ki in range(k):
            # phi_coeff[ki]: increasing-degree coefficients of sqrt(2ki+1) * P_ki(2x - 1)
            coeff_ = Poly(legendre(ki, 2*x-1), x).all_coeffs()
            phi_coeff[ki,:ki+1] = np.flip(np.sqrt(2*ki+1) * np.array(coeff_).astype(np.float64))
            # phi_2x_coeff[ki]: coefficients of sqrt(2) * phi_ki(2x), since P_ki(2*(2x) - 1) = P_ki(4x - 1)
            coeff_ = Poly(legendre(ki, 4*x-1), x).all_coeffs()
            phi_2x_coeff[ki,:ki+1] = np.flip(np.sqrt(2) * np.sqrt(2*ki+1) * np.array(coeff_).astype(np.float64))

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))
        for ki in range(k):
            # Gram-Schmidt: start from sqrt(2)*phi_ki(2x) and subtract its projections
            # onto every phi_i and onto the wavelets psi_j built so far.
            psi1_coeff[ki,:] = phi_2x_coeff[ki,:]
            for i in range(k):
                a = phi_2x_coeff[ki,:ki+1]
                b = phi_coeff[i, :i+1]
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_)<1e-8] = 0
                # sum_n c_n * 0.5^(n+1) / (n+1) is the integral of the product polynomial over [0, 0.5]
                proj_ = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum()
                psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:]
                psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:]
            for j in range(ki):
                a = phi_2x_coeff[ki,:ki+1]
                b = psi1_coeff[j, :]
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_)<1e-8] = 0
                proj_ = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum()
                psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:]
                psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:]

            # Squared norm of psi1 over [0, 0.5]
            a = psi1_coeff[ki,:]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_)<1e-8] = 0
            norm1 = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum()

            # Squared norm of psi2 over [0.5, 1]: (1 - 0.5^(n+1)) / (n+1) integrates x^n over that interval
            a = psi2_coeff[ki,:]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_)<1e-8] = 0
            norm2 = (prod_ * 1/(np.arange(len(prod_))+1) * (1-np.power(0.5, 1+np.arange(len(prod_))))).sum()
            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki,:] /= norm_
            psi2_coeff[ki,:] /= norm_
            # Drop round-off sized coefficients before they feed later projections
            psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0

        # np.poly1d expects highest-degree-first coefficients, hence the flips
        phi = [np.poly1d(np.flip(phi_coeff[i,:])) for i in range(k)]
        psi1 = [np.poly1d(np.flip(psi1_coeff[i,:])) for i in range(k)]
        psi2 = [np.poly1d(np.flip(psi2_coeff[i,:])) for i in range(k)]

    elif base == 'chebyshev':
        for ki in range(k):
            if ki == 0:
                phi_coeff[ki,:ki+1] = np.sqrt(2/np.pi)
                phi_2x_coeff[ki,:ki+1] = np.sqrt(2/np.pi) * np.sqrt(2)
            else:
                coeff_ = Poly(chebyshevt(ki, 2*x-1), x).all_coeffs()
                phi_coeff[ki,:ki+1] = np.flip(2/np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
                coeff_ = Poly(chebyshevt(ki, 4*x-1), x).all_coeffs()
                phi_2x_coeff[ki,:ki+1] = np.flip(np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))

        # phi_ zeroes the polynomial outside its default support [0, 1]
        phi = [partial(phi_, phi_coeff[i,:]) for i in range(k)]

        # Quadrature nodes: roots of T_{2k}(2x - 1) in [0, 1] with equal weights pi / (2 * kUse);
        # presumably Gauss-Chebyshev quadrature mapped to [0, 1].
        x = Symbol('x')
        kUse = 2*k
        roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
        # not needed for our purpose here, we use even k always to avoid
        wm = np.pi / kUse / 2

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))

        psi1 = [[] for _ in range(k)]
        psi2 = [[] for _ in range(k)]

        for ki in range(k):
            # Same Gram-Schmidt as the legendre branch, with inner products by quadrature
            psi1_coeff[ki,:] = phi_2x_coeff[ki,:]
            for i in range(k):
                proj_ = (wm * phi[i](x_m) * np.sqrt(2)* phi[ki](2*x_m)).sum()
                psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:]
                psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:]

            for j in range(ki):
                proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum()
                psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:]
                psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:]

            # Temporary half-interval pieces, used only to compute the normalisation
            psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5)
            psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5, ub = 1)

            norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
            norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()

            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki,:] /= norm_
            psi2_coeff[ki,:] /= norm_
            psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0

            # Final pieces. 0.5+1e-16 rounds to the next double above 0.5, so x == 0.5
            # lies in psi1's support only.
            psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5+1e-16)
            psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5+1e-16, ub = 1)

    return phi, psi1, psi2
def get_filter(base, k):
    """Compute the two-scale filter matrices of the order-``k`` multiwavelet basis.

    Args:
        base: ``'legendre'`` or ``'chebyshev'``.
        k: number of basis functions.

    Returns:
        ``(H0, H1, G0, G1, PHI0, PHI1)``, each a ``(k, k)`` numpy array.
        ``H0[i, j]`` and ``H1[i, j]`` are 1/sqrt(2) times quadrature sums of
        ``phi_i(x/2) * phi_j(x)`` and ``phi_i((x+1)/2) * phi_j(x)`` over the
        nodes in [0, 1].
        ``G0`` and ``G1`` are the same sums with the piecewise wavelet
        ``psi_i`` in place of ``phi_i``.
        ``PHI0`` and ``PHI1`` are identity matrices for ``'legendre'``. For
        ``'chebyshev'`` they are 2x quadrature sums of
        ``phi_i(2x) * phi_j(2x)`` and ``phi_i(2x-1) * phi_j(2x-1)``.
        Entries with magnitude below 1e-8 are set to 0.

    Raises:
        Exception: if ``base`` is not ``'legendre'`` or ``'chebyshev'``.
    """
    def psi(psi1, psi2, i, inp):
        # Evaluate the i-th wavelet piecewise: psi1 where inp <= 0.5, psi2 elsewhere.
        mask = (inp<=0.5) * 1.0
        return psi1[i](inp) * mask + psi2[i](inp) * (1-mask)

    if base not in ['legendre', 'chebyshev']:
        raise Exception('Base not supported')

    x = Symbol('x')
    H0 = np.zeros((k,k))
    H1 = np.zeros((k,k))
    G0 = np.zeros((k,k))
    G1 = np.zeros((k,k))
    PHI0 = np.zeros((k,k))
    PHI1 = np.zeros((k,k))
    phi, psi1, psi2 = get_phi_psi(k, base)
    if base == 'legendre':
        # Gauss-Legendre nodes (roots of P_k(2x-1), i.e. mapped to [0, 1]) and the
        # matching weights 1 / (k * P_k'(t) * P_{k-1}(t)) with t = 2x - 1.
        roots = Poly(legendre(k, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = 1/k/legendreDer(k,2*x_m-1)/eval_legendre(k-1,2*x_m-1)

        for ki in range(k):
            for kpi in range(k):
                H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum()
                G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum()
                H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum()
                G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum()

        PHI0 = np.eye(k)
        PHI1 = np.eye(k)

    elif base == 'chebyshev':
        # Nodes: roots of T_{2k}(2x-1) with equal weights pi / (2 * kUse);
        # presumably Gauss-Chebyshev quadrature mapped to [0, 1].
        x = Symbol('x')
        kUse = 2*k
        roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
        # not needed for our purpose here, we use even k always to avoid
        wm = np.pi / kUse / 2

        for ki in range(k):
            for kpi in range(k):
                H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum()
                G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum()
                H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum()
                G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum()

                # The chebyshev phi are not orthonormal under this quadrature, so PHI0/PHI1
                # are computed instead of being identities.
                PHI0[ki, kpi] = (wm * phi[ki](2*x_m) * phi[kpi](2*x_m)).sum() * 2
                PHI1[ki, kpi] = (wm * phi[ki](2*x_m-1) * phi[kpi](2*x_m-1)).sum() * 2

        PHI0[np.abs(PHI0)<1e-8] = 0
        PHI1[np.abs(PHI1)<1e-8] = 0

    # Suppress round-off noise in the filters
    H0[np.abs(H0)<1e-8] = 0
    H1[np.abs(H1)<1e-8] = 0
    G0[np.abs(G0)<1e-8] = 0
    G1[np.abs(G1)<1e-8] = 0

    return H0, H1, G0, G1, PHI0, PHI1
class TriangularCausalMask:
    """Boolean causal mask of shape ``[B, 1, L, L]``.

    Entry ``[b, 0, i, j]`` is True when ``j > i``, i.e. a future position
    that attention should not see.
    """

    def __init__(self, B, L, device="cpu"):
        with torch.no_grad():
            ones = torch.ones([B, 1, L, L], dtype=torch.bool)
            self._mask = ones.triu(diagonal=1).to(device)

    @property
    def mask(self):
        """The precomputed boolean mask tensor."""
        return self._mask
class ProbMask:
    """Causal mask rows gathered at the query positions listed in ``index``.

    Builds the ``[L, S]`` upper-triangular (``j > i``) mask, where ``S`` is
    ``scores.shape[-1]``. It then picks, for every ``(batch, head)``, the
    rows named by ``index`` and reshapes the result to ``scores.shape``.

    Args:
        B: batch size.
        H: number of heads.
        L: number of rows in the full causal mask.
        index: integer row indices; presumably shaped [B, H, u] — TODO confirm.
        scores: tensor whose shape the mask is reshaped to.
        device: device of the produced mask.
    """

    def __init__(self, B, H, L, index, scores, device="cpu"):
        n_cols = scores.shape[-1]
        full = torch.ones(L, n_cols, dtype=torch.bool).to(device).triu(1)
        expanded = full[None, None, :].expand(B, H, L, n_cols)
        batch_idx = torch.arange(B)[:, None, None]
        head_idx = torch.arange(H)[None, :, None]
        picked = expanded[batch_idx, head_idx, index, :].to(device)
        self._mask = picked.view(scores.shape).to(device)

    @property
    def mask(self):
        """The gathered boolean mask tensor."""
        return self._mask
class LocalMask:
    """Boolean mask ``[B, 1, L, S]`` allowing only a short causal window.

    Position ``(i, j)`` is True (masked) when ``j > i`` (future) or when
    ``j < i - ceil(log2(L))`` (further back than the window).
    """

    def __init__(self, B, L, S, device="cpu"):
        shape = [B, 1, L, S]
        with torch.no_grad():
            self.len = math.ceil(np.log2(L))
            self._mask1 = torch.ones(shape, dtype=torch.bool).triu(diagonal=1).to(device)
            in_window = torch.ones(shape, dtype=torch.bool).triu(diagonal=-self.len)
            self._mask2 = (~in_window).to(device)
            # Logical OR of "future" and "too far in the past"
            self._mask = self._mask1 | self._mask2

    @property
    def mask(self):
        """The combined boolean mask tensor."""
        return self._mask
def adjust_learning_rate(optimizer, epoch, args):
    """Set every param group's learning rate according to ``args.lradj``.

    Schedules:
        * ``"type1"``: ``args.learning_rate * 0.5 ** (epoch - 1)`` (factor 1 at epoch 1).
        * ``"type2"``: fixed values at epochs 2, 4, 6, 8, 10, 15 and 20 only.
        * ``"type3"``: ``args.learning_rate`` at every epoch.
        * ``"type4"``: ``args.learning_rate * 0.9 ** (epoch - 1)``.

    The learning rate is overwritten (and a message printed) only when the
    schedule defines a value for ``epoch``.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts with an
            ``"lr"`` key (e.g. a ``torch.optim.Optimizer``).
        epoch: current epoch number.
        args: namespace with ``lradj`` and ``learning_rate`` attributes.

    Raises:
        ValueError: if ``args.lradj`` is not a supported schedule name.
    """
    # lr = args.learning_rate * (0.2 ** (epoch // 2))
    if args.lradj == "type1":
        lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
    elif args.lradj == "type2":
        lr_adjust = {2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 10: 5e-7, 15: 1e-7, 20: 5e-8}
    elif args.lradj == "type3":
        lr_adjust = {epoch: args.learning_rate}
    elif args.lradj == "type4":
        lr_adjust = {epoch: args.learning_rate * (0.9 ** ((epoch - 1) // 1))}
    else:
        # Fail with a clear message instead of an UnboundLocalError on lr_adjust below.
        raise ValueError(f"Unknown learning-rate schedule: {args.lradj!r}")
    if epoch in lr_adjust:
        lr = lr_adjust[epoch]
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        print("Updating learning rate to {}".format(lr))
class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Each call compares ``-val_loss`` with the best score seen so far.
    - Improvement: a call whose score is at least ``best + delta`` saves
      ``model.state_dict()`` to ``<path>/checkpoint.pth`` and resets the
      counter.
    - No improvement: any other call increments the counter.
    - Stop: ``early_stop`` becomes True once the counter reaches
      ``patience``.

    Args:
        patience: number of non-improving calls tolerated.
        verbose: print a message whenever a checkpoint is saved.
        delta: minimum score improvement that counts as progress.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works on every version.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, path):
        """Record ``val_loss`` and checkpoint ``model`` into ``path`` on improvement."""
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Save ``model.state_dict()`` to ``<path>/checkpoint.pth`` and remember ``val_loss``."""
        if self.verbose:
            print(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
            )
        torch.save(model.state_dict(), path + "/" + "checkpoint.pth")
        self.val_loss_min = val_loss
class dotdict(dict):
    """A ``dict`` whose keys can also be read, set and deleted as attributes.

    Reading a missing attribute returns None (``dict.get`` semantics)
    instead of raising ``AttributeError``; deleting a missing one raises
    ``KeyError``.
    """

    def __getattr__(self, key):
        return dict.get(self, key)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)
class StandardScaler:
    """Affine scaler computing ``(data - mean) / std``, with its inverse.

    ``mean`` and ``std`` are used as-is in the arithmetic. They may be
    scalars, or arrays/tensors that broadcast against ``data``.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def transform(self, data):
        """Standardize ``data``."""
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """Undo ``transform``."""
        rescaled = data * self.std
        return rescaled + self.mean
def visual(true, preds=None, name="./pic/test.pdf"):
    """
    Results visualization: plot the ground truth (and optionally the
    predictions) on one set of axes and save the figure to ``name``.

    Args:
        true: sequence of ground-truth values.
        preds: optional sequence of predicted values.
        name: output file path.
    """
    # Imported here because the module-level matplotlib import is commented
    # out, so a bare ``plt`` reference raised NameError.
    from matplotlib import pyplot as plt

    fig = plt.figure()
    try:
        plt.plot(true, label="GroundTruth", linewidth=2)
        if preds is not None:
            plt.plot(preds, label="Prediction", linewidth=2)
        plt.legend()
        plt.savefig(name, bbox_inches="tight")
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def decor_time(func):
    """Decorator that prints the wall-clock duration of each call to ``func``.

    The wrapped function's return value is passed through unchanged, and its
    name/docstring are preserved via ``functools.wraps``.
    """

    @wraps(func)
    def func2(*args, **kw):
        # perf_counter is monotonic, so the interval is not skewed by clock changes.
        now = time.perf_counter()
        y = func(*args, **kw)
        t = time.perf_counter() - now
        print("call <{}>, time={}".format(func.__name__, t))
        return y

    return func2
class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.

    Args:
        mask_flag: stored; not used by the code in this class.
        factor: top-k multiplier; ``int(factor * ln(length))`` delays are kept.
        scale: stored; not used by the code in this class.
        attention_dropout: rate of ``self.dropout`` (not applied in ``forward``).
        output_attention: if True, ``forward`` also returns the correlation.
        wavelet: selects the correlation path in ``forward``:
            - ``0``/``False``: plain FFT autocorrelation.
            - ``1`` (``True`` also compares equal to 1): queries and keys are
              augmented with wavelet sub-bands before the FFT.
            - ``2``: autocorrelation is computed per sub-band, and the
              aggregated values are reconstructed with the inverse transform.

    NOTE(review): the wavelet paths read ``self.j_list``, ``self.dwt1d`` and
    ``self.dwt1div``, which ``__init__`` does not create; they must be
    attached to the instance elsewhere before ``wavelet`` 1 or 2 is used.
    """

    def __init__(
        self,
        mask_flag=True,
        factor=1,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
        wavelet=False,
    ):
        super(AutoCorrelation, self).__init__()
        print("Autocorrelation used !")
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        self.agg = None
        self.use_wavelet = wavelet

    # @decor_time
    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        One set of top-k delays is chosen from the correlation averaged over
        batch, heads and channels; each delayed (rolled) copy of ``values`` is
        weighted by a per-sample softmax over the mean correlations at those
        delays. ``values`` is [B, H, d, S].
        """
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * (
                tmp_corr[:, i]
                .unsqueeze(1)
                .unsqueeze(1)
                .unsqueeze(1)
                .repeat(1, head, channel, length)
            )
        return delays_agg  # size=[B, H, d, S]

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Unlike the training version, top-k delays are chosen per sample.
        ``values`` is [B, H, d, S].
        """
        batch, head, channel, length = values.shape
        # index init -- on the same device as ``values`` (a hard-coded
        # ``.cuda()`` crashed on CPU and mismatched non-default GPUs).
        init_index = (
            torch.arange(length, device=values.device)
            .unsqueeze(0)
            .unsqueeze(0)
            .unsqueeze(0)
            .repeat(batch, head, channel, 1)
        )
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: doubling along time lets gather read index t + delay < 2 * length
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(
                1
            ).repeat(1, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (
                tmp_corr[:, i]
                .unsqueeze(1)
                .unsqueeze(1)
                .unsqueeze(1)
                .repeat(1, head, channel, length)
            )
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation: top-k delays and weights are
        chosen independently for every (batch, head, channel).
        """
        batch, head, channel, length = values.shape
        # index init on the device of ``values`` (previously hard-coded ``.cuda()``)
        init_index = (
            torch.arange(length, device=values.device)
            .unsqueeze(0)
            .unsqueeze(0)
            .unsqueeze(0)
            .repeat(batch, head, channel, 1)
        )
        # find top k
        top_k = int(self.factor * math.log(length))
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        """Aggregate ``values`` by the strongest query/key autocorrelation delays.

        Args:
            queries: [B, L, H, E].
            keys: [B, S, H, E].
            values: [B, S, H, D]; zero-padded or truncated to length L (the
                zero padding is built from ``queries``, so it requires D == E).
            attn_mask: unused.

        Returns:
            ``(V, corr)`` with ``V`` shaped [B, L, H, D] and ``corr`` the last
            computed correlation permuted by (0, 3, 1, 2) when
            ``output_attention`` is set; otherwise ``(V, None)``.
        """
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            zeros = torch.zeros_like(queries[:, : (L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        if self.use_wavelet != 2:
            if self.use_wavelet == 1:
                j_list = self.j_list
                queries = queries.reshape([B, L, -1])
                keys = keys.reshape([B, L, -1])
                Ql, Qh_list = self.dwt1d(queries.transpose(1, 2))  # [B, H*D, L]
                Kl, Kh_list = self.dwt1d(keys.transpose(1, 2))
                qs = [queries.transpose(1, 2)] + Qh_list + [Ql]  # [B, H*D, L]
                ks = [keys.transpose(1, 2)] + Kh_list + [Kl]
                q_list = []
                k_list = []
                for q, k, j in zip(qs, ks, j_list):
                    q_list += [interpolate(q, scale_factor=j, mode="linear")[:, :, -L:]]
                    k_list += [interpolate(k, scale_factor=j, mode="linear")[:, :, -L:]]
                queries = (
                    torch.stack([i.reshape([B, H, E, L]) for i in q_list], dim=3)
                    .reshape([B, H, -1, L])
                    .permute(0, 3, 1, 2)
                )
                keys = (
                    torch.stack([i.reshape([B, H, E, L]) for i in k_list], dim=3)
                    .reshape([B, H, -1, L])
                    .permute(0, 3, 1, 2)
                )
            q_fft = torch.fft.rfft(
                queries.permute(0, 2, 3, 1).contiguous(), dim=-1
            )  # size=[B, H, E, L]
            k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
            res = q_fft * torch.conj(k_fft)
            # n=L: without it irfft returns L-1 samples for odd L and the
            # circular correlation is wrong.
            corr = torch.fft.irfft(res, n=L, dim=-1)  # size=[B, H, E, L]

            # time delay agg
            if self.training:
                V = self.time_delay_agg_training(
                    values.permute(0, 2, 3, 1).contiguous(), corr
                ).permute(
                    0, 3, 1, 2
                )  # [B, L, H, E], [B, H, E, L] -> [B, L, H, E]
            else:
                V = self.time_delay_agg_inference(
                    values.permute(0, 2, 3, 1).contiguous(), corr
                ).permute(0, 3, 1, 2)
        else:
            V_list = []
            queries = queries.reshape([B, L, -1])
            keys = keys.reshape([B, L, -1])
            values = values.reshape([B, L, -1])
            Ql, Qh_list = self.dwt1d(queries.transpose(1, 2))  # [B, H*D, L]
            Kl, Kh_list = self.dwt1d(keys.transpose(1, 2))
            Vl, Vh_list = self.dwt1d(values.transpose(1, 2))
            qs = Qh_list + [Ql]  # [B, H*D, L]
            ks = Kh_list + [Kl]
            vs = Vh_list + [Vl]
            for q, k, v in zip(qs, ks, vs):
                q = q.reshape([B, H, E, -1])
                k = k.reshape([B, H, E, -1])
                v = v.reshape([B, H, E, -1]).permute(0, 3, 1, 2)
                q_fft = torch.fft.rfft(q.contiguous(), dim=-1)
                k_fft = torch.fft.rfft(k.contiguous(), dim=-1)
                res = q_fft * torch.conj(k_fft)
                # Keep the sub-band length exact for odd lengths too.
                corr = torch.fft.irfft(res, n=q.shape[-1], dim=-1)  # [B, H, E, L]
                if self.training:
                    V = self.time_delay_agg_training(
                        v.permute(0, 2, 3, 1).contiguous(), corr
                    ).permute(0, 3, 1, 2)
                else:
                    V = self.time_delay_agg_inference(
                        v.permute(0, 2, 3, 1).contiguous(), corr
                    ).permute(0, 3, 1, 2)
                V_list += [V]
            Vl = V_list[-1].reshape([B, -1, H * E]).transpose(1, 2)
            Vh_list = [i.reshape([B, -1, H * E]).transpose(1, 2) for i in V_list[:-1]]
            V = self.dwt1div((Vl, Vh_list)).reshape([B, H, E, -1]).permute(0, 3, 1, 2)
            # corr = self.dwt1div((V_list[-1], V_list[:-1]))

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))  # size = [B, L, H, E]
        else:
            return (V.contiguous(), None)
class AutoCorrelationLayer(nn.Module):
    """Multi-head wrapper around an inner correlation/attention callable.

    Projects queries, keys and values into ``n_heads`` heads and calls
    ``correlation(queries, keys, values, attn_mask)``. The returned head
    outputs are projected back to ``d_model``.

    Args:
        correlation: callable taking ``(queries, keys, values, attn_mask)``
            shaped [B, L|S, H, d] and returning ``(out, attn)``.
        d_model: input/output width.
        n_heads: number of heads.
        d_keys: per-head query/key width (default ``d_model // n_heads``).
        d_values: per-head value width (default ``d_model // n_heads``).
    """

    def __init__(self, correlation, d_model, n_heads, d_keys=None, d_values=None):
        super(AutoCorrelationLayer, self).__init__()
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)
        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        """Return ``(output [B, L, d_model], attn)`` for inputs [B, L|S, d_model]."""
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        # Debug prints removed: they spammed stdout on every call and ran the
        # query projection a second time.
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)
        out, attn = self.inner_correlation(queries, keys, values, attn_mask)
        out = out.view(B, L, -1)
        return self.out_projection(out), attn
class my_Layernorm(nn.Module):
    """
    Special designed layernorm for the seasonal part.

    Applies ``LayerNorm`` over the last dimension, then subtracts the mean
    taken over dim 1, so the output has zero mean along that axis.
    Expects a 3-D input.
    """

    def __init__(self, channels):
        super().__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        normed = self.layernorm(x)
        seq_len = x.shape[1]
        offset = normed.mean(dim=1).unsqueeze(1).repeat(1, seq_len, 1)
        return normed - offset
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    ``x`` is [batch, time, channels]. Both ends are padded by repeating the
    first/last time step, so with ``stride=1`` the output keeps the input
    length.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        tail_pad = (self.kernel_size - 1) // 2
        head_pad = self.kernel_size - 1 - tail_pad
        head = x[:, :1, :].repeat(1, head_pad, 1)
        tail = x[:, -1:, :].repeat(1, tail_pad, 1)
        padded = torch.cat([head, x, tail], dim=1)
        # AvgPool1d pools over the last axis, so move time there and back
        return self.avg(padded.transpose(1, 2)).transpose(1, 2)
class series_decomp(nn.Module):
    """
    Series decomposition block: splits ``x`` into a seasonal residual and a
    moving-average trend, returned as ``(residual, trend)``.
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        return x - trend, trend
class series_decomp_multi(nn.Module):
    """
    Series decomposition block with several moving-average kernels.

    The trends from each kernel size are mixed with per-element weights: a
    softmax over ``Linear(1, len(kernel_size))`` applied to ``x``.
    Returns ``(residual, trend)``.

    NOTE(review): the averaging modules live in a plain list, so they are not
    registered as submodules (they have no parameters, so state_dict is
    unaffected).
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        self.moving_avg = [moving_avg(kernel, stride=1) for kernel in kernel_size]
        self.layer = torch.nn.Linear(1, len(kernel_size))

    def forward(self, x):
        trends = torch.cat([avg(x).unsqueeze(-1) for avg in self.moving_avg], dim=-1)
        mix = nn.Softmax(-1)(self.layer(x.unsqueeze(-1)))
        trend = torch.sum(trends * mix, dim=-1)
        return x - trend, trend
class FourierDecomp(nn.Module):
    """Placeholder for a Fourier-based series decomposition.

    NOTE(review): ``forward`` computes the real FFT of ``x`` along the last
    dimension, discards it, and returns None; the decomposition is not
    implemented.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        torch.fft.rfft(x, dim=-1)
        return None
class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture.

    The layer runs these steps in order:
    1. attention, added residually to the input;
    2. series decomposition;
    3. position-wise feed-forward (two bias-free 1x1 convolutions), added
       residually;
    4. series decomposition.
    Only the seasonal parts are kept; extracted trends are discarded.
    ``moving_avg`` may be a list of kernel sizes (multi-kernel decomposition)
    or a single size.
    """

    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        decomp_cls = series_decomp_multi if isinstance(moving_avg, list) else series_decomp
        self.decomp1 = decomp_cls(moving_avg)
        self.decomp2 = decomp_cls(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask)
        x, _ = self.decomp1(x + self.dropout(attended))
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        res, _ = self.decomp2(x + hidden)
        return res, attn
class Encoder(nn.Module):
    """
    Autoformer encoder: a stack of attention layers, optionally interleaved
    with convolution layers, followed by an optional norm layer.

    ``forward`` returns the final representation and the list of attentions
    returned by each attention layer.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is None:
            for layer in self.attn_layers:
                x, attn = layer(x, attn_mask=attn_mask)
                attns.append(attn)
        else:
            for layer, conv in zip(self.attn_layers, self.conv_layers):
                x, attn = layer(x, attn_mask=attn_mask)
                x = conv(x)
                attns.append(attn)
            # The final attention layer is called without ``attn_mask``.
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)
        if self.norm is not None:
            x = self.norm(x)
        return x, attns
class DecoderLayer(nn.Module):
    """
    Autoformer decoder layer with the progressive decomposition architecture.

    Self-attention, cross-attention and the feed-forward block are each added
    residually and followed by a series decomposition. The three extracted
    trends are summed and projected to ``c_out`` channels by a circular
    3-tap convolution. ``forward`` returns ``(seasonal, residual_trend)``.
    """

    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        c_out,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        decomp_cls = series_decomp_multi if isinstance(moving_avg, list) else series_decomp
        self.decomp1 = decomp_cls(moving_avg)
        self.decomp2 = decomp_cls(moving_avg)
        self.decomp3 = decomp_cls(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(
            in_channels=d_model,
            out_channels=c_out,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        self_out = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x, trend1 = self.decomp1(x + self.dropout(self_out))
        cross_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x, trend2 = self.decomp2(x + self.dropout(cross_out))
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        x, trend3 = self.decomp3(x + hidden)
        summed_trend = trend1 + trend2 + trend3
        residual_trend = self.projection(summed_trend.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend
class Decoder(nn.Module):
    """
    Autoformer decoder.

    Runs each layer in turn and adds the residual trend it returns onto
    ``trend``. The optional norm and then the optional projection are
    applied to the seasonal output. Returns ``(x, trend)``.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        for layer in self.layers:
            x, delta = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            trend = trend + delta
        for post in (self.norm, self.projection):
            if post is not None:
                x = post(x)
        return x, trend
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    ``pe[0, pos, 2i] = sin(pos / 10000^(2i / d_model))`` and
    ``pe[0, pos, 2i+1] = cos(pos / 10000^(2i / d_model))``.
    ``forward(x)`` returns the first ``x.size(1)`` positions with shape
    ``[1, x.size(1), d_model]``.

    Args:
        d_model: embedding width; odd values are supported.
        max_len: number of precomputed positions.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer: saved in the state_dict but never trained.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return self.pe[:, : x.size(1)]
class TokenEmbedding(nn.Module):
    """Embed raw input values with a 3-tap circular 1-D convolution.

    ``x`` is [batch, time, c_in]; the output is [batch, time, d_model].
    Conv weights use Kaiming-normal initialisation (fan_in, leaky_relu gain).
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # The old check compared version *strings*, so "1.10.0" >= "1.5.0" was
        # False. That chose padding=2 on torch 1.10-1.13 and lengthened the
        # output. Compare numeric (major, minor) instead; the original code used
        # 2 only for torch < 1.5.
        version = self._torch_version()
        padding = 2 if version is not None and version < (1, 5) else 1
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=padding,
            padding_mode="circular",
            bias=False,
        )
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_in", nonlinearity="leaky_relu"
                )

    @staticmethod
    def _torch_version():
        """Return the installed torch ``(major, minor)`` as ints, or None if unparseable."""
        import re

        match = re.match(r"(\d+)\.(\d+)", str(torch.__version__))
        if match is None:
            return None
        return int(match.group(1)), int(match.group(2))

    def forward(self, x):
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x
class FixedEmbedding(nn.Module):
    """Frozen sinusoidal lookup table: an ``nn.Embedding`` with sin/cos rows.

    Row ``p`` holds ``sin(p / 10000^(2i / d_model))`` in even columns and the
    matching cosines in odd columns. The weights never receive gradients, and
    the output is detached.

    Args:
        c_in: number of rows (distinct integer inputs).
        d_model: embedding width; odd values are supported.
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()
        w = torch.zeros(c_in, d_model).float()
        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()
        w[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer cosine column than sine column.
        w[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach()
class TemporalEmbedding(nn.Module):
    """Sum of embeddings of integer calendar features.

    ``x`` is read as [batch, time, features] with columns 0=month, 1=day,
    2=weekday, 3=hour and, only when ``freq == "t"``, 4=minute.
    ``embed_type == "fixed"`` uses frozen sinusoidal tables; any other value
    uses learnable ``nn.Embedding`` tables.
    """

    def __init__(self, d_model, embed_type="fixed", freq="h"):
        super(TemporalEmbedding, self).__init__()
        table_sizes = {"minute": 4, "hour": 24, "weekday": 7, "day": 32, "month": 13}
        Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding
        if freq == "t":
            self.minute_embed = Embed(table_sizes["minute"], d_model)
        self.hour_embed = Embed(table_sizes["hour"], d_model)
        self.weekday_embed = Embed(table_sizes["weekday"], d_model)
        self.day_embed = Embed(table_sizes["day"], d_model)
        self.month_embed = Embed(table_sizes["month"], d_model)

    def forward(self, x):
        idx = x.long()
        if hasattr(self, "minute_embed"):
            minute_part = self.minute_embed(idx[:, :, 4])
        else:
            minute_part = 0.0
        parts = [
            self.hour_embed(idx[:, :, 3]),
            self.weekday_embed(idx[:, :, 2]),
            self.day_embed(idx[:, :, 1]),
            self.month_embed(idx[:, :, 0]),
        ]
        total = parts[0]
        for part in parts[1:]:
            total = total + part
        return total + minute_part
class TimeFeatureEmbedding(nn.Module):
    """Bias-free linear projection of continuous time features to ``d_model``.

    The number of input features depends on ``freq``. An unknown ``freq``
    raises ``KeyError``.
    """

    def __init__(self, d_model, embed_type="timeF", freq="h"):
        super(TimeFeatureEmbedding, self).__init__()
        features_per_freq = dict(h=4, t=5, s=6, m=1, a=1, w=2, d=3, b=3)
        n_features = features_per_freq[freq]
        self.embed = nn.Linear(n_features, d_model, bias=False)

    def forward(self, x):
        return self.embed(x)
class DataEmbedding(nn.Module):
    """Value + temporal + positional embedding, followed by dropout.

    ``embed_type == "timeF"`` selects the linear ``TimeFeatureEmbedding``;
    any other value selects the calendar-lookup ``TemporalEmbedding``.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super(DataEmbedding, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == "timeF":
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        values = self.value_embedding(x)
        temporal = self.temporal_embedding(x_mark)
        positions = self.position_embedding(x)
        return self.dropout(values + temporal + positions)
class DataEmbedding_onlypos(nn.Module):
    """Value + positional embedding (no temporal features), followed by dropout.

    ``embed_type`` and ``freq`` are accepted for interface compatibility and
    ignored.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super(DataEmbedding_onlypos, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        # x_mark is unused; it is kept so all embeddings share one signature.
        embedded = self.value_embedding(x)
        embedded = embedded + self.position_embedding(x)
        return self.dropout(embedded)
class DataEmbedding_wo_pos(nn.Module):
    """Value + temporal embedding without positional encoding, then dropout.

    ``position_embedding`` is still constructed (its buffer remains part of
    the state_dict) but is not used in ``forward``.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super(DataEmbedding_wo_pos, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == "timeF":
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        values = self.value_embedding(x)
        temporal = self.temporal_embedding(x_mark)
        return self.dropout(values + temporal)
def get_frequency_modes(seq_len, modes=64, mode_select_method="random"):
    """
    Pick which frequency-domain modes to keep.

    Args:
        seq_len: time-domain sequence length; candidate modes are
            ``0 .. seq_len // 2 - 1``.
        modes: maximum number of modes (capped at ``seq_len // 2``).
        mode_select_method: ``"random"`` samples a random subset using NumPy's
            global RNG (``np.random.shuffle``); any other value keeps the
            lowest modes.

    Returns:
        Sorted list of the selected mode indices.
    """
    n_modes = min(modes, seq_len // 2)
    if mode_select_method != "random":
        return list(range(n_modes))
    candidates = list(range(seq_len // 2))
    np.random.shuffle(candidates)
    return sorted(candidates[:n_modes])
# ########## fourier layer #############
class FourierBlock(nn.Module):
    """
    1D Fourier block. It performs representation learning on frequency domain,
    it does FFT, linear transform, and Inverse FFT.

    Only the modes listed in ``self.index`` are transformed. Each is
    multiplied by a learned complex per-head ``[in // n_heads, out // n_heads]``
    matrix.

    Args:
        n_heads: number of heads.
        in_channels: total input width (split across heads).
        out_channels: total output width (split across heads).
        seq_len: sequence length used to choose the modes.
        modes: number of modes to keep.
        mode_select_method: see ``get_frequency_modes``.
    """

    def __init__(
        self,
        n_heads,
        in_channels,
        out_channels,
        seq_len,
        modes=0,
        mode_select_method="random",
    ):
        super(FourierBlock, self).__init__()
        print("fourier enhanced block used!")
        # get modes on frequency domain
        self.index = get_frequency_modes(
            seq_len, modes=modes, mode_select_method=mode_select_method
        )
        print("modes={}, index={}".format(modes, self.index))

        self.scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(
                n_heads,
                in_channels // n_heads,
                out_channels // n_heads,
                len(self.index),
                dtype=torch.cfloat,
            )
        )

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        # (batch, head, in_channel), (head, in_channel, out_channel) -> (batch, head, out_channel)
        return torch.einsum("bhi,hio->bho", input, weights)

    def forward(self, q, k, v, mask):
        """Transform ``q`` in the frequency domain; ``k``, ``v`` and ``mask`` are unused.

        ``q`` is [B, L, H, E]. Returns ``(x, None)`` with ``x`` shaped
        [B, H, E, L].
        """
        B, L, H, E = q.shape
        x = q.permute(0, 2, 3, 1)  # [B, H, E, L]
        # Compute Fourier coefficients
        # (per-call debug prints of shapes/index removed: they flooded stdout)
        x_ft = torch.fft.rfft(x, dim=-1)
        # Perform Fourier neural operations
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
        # NOTE(review): selected mode ``i`` is written to output bin ``wi`` (its
        # rank in ``self.index``), not back to bin ``i``; confirm this is intended.
        for wi, i in enumerate(self.index):
            out_ft[:, :, :, wi] = self.compl_mul1d(
                x_ft[:, :, :, i], self.weights1[:, :, :, wi]
            )
        # Return to time domain
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return (x, None)
# ########## Fourier Cross Former ####################
class FourierCrossAttention(nn.Module):
    """
    1D Fourier Cross Attention layer. It does FFT, linear transform, attention
    mechanism and Inverse FFT.

    The forward pass proceeds as follows:
    1. Queries and keys are reduced to a subset of frequency modes
       (``index_q`` / ``index_kv``).
    2. Scores between query and key modes (tanh or softmax of magnitudes)
       are applied to the key spectrum; ``v`` is not used.
    3. The result is mixed by learned complex per-head weights, scattered
       back to the query modes, and inverse-FFT'd.
    """

    def __init__(
        self,
        n_heads,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes=64,
        mode_select_method="random",
        activation="tanh",
        policy=0,
    ):
        super(FourierCrossAttention, self).__init__()
        print(" fourier enhanced cross attention used!")
        self.activation = activation
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Modes for queries and for keys (& values) on the frequency domain;
        # computed in this order so the random draws stay the same.
        self.index_q = get_frequency_modes(
            seq_len_q, modes=modes, mode_select_method=mode_select_method
        )
        self.index_kv = get_frequency_modes(
            seq_len_kv, modes=modes, mode_select_method=mode_select_method
        )
        self.scale = 1 / (in_channels * out_channels)
        weight_shape = (
            n_heads,
            in_channels // n_heads,
            out_channels // n_heads,
            len(self.index_q),
        )
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(*weight_shape, dtype=torch.cfloat)
        )

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        # NOTE(review): not used by ``forward``.
        return torch.einsum("bhi,hoi->boi", input, weights)

    @staticmethod
    def _gather_modes(spectrum, modes, lead_shape, device):
        """Copy the listed frequency bins of ``spectrum`` into a new cfloat tensor."""
        picked = torch.zeros(*lead_shape, len(modes), device=device, dtype=torch.cfloat)
        for slot, mode in enumerate(modes):
            picked[:, :, :, slot] = spectrum[:, :, :, mode]
        return picked

    def forward(self, q, k, v, mask):
        """Cross-attend ``q`` to ``k`` in the frequency domain.

        ``q`` and ``k`` are [B, L, H, E]; ``v`` and ``mask`` are unused.
        Returns ``(out, None)`` with ``out`` shaped [B, H, E, L].
        """
        B, L, H, E = q.shape
        xq = q.permute(0, 2, 3, 1)  # [B, H, E, L]
        xk = k.permute(0, 2, 3, 1)
        lead = (B, H, E)

        q_modes = self._gather_modes(torch.fft.rfft(xq, dim=-1), self.index_q, lead, xq.device)
        k_modes = self._gather_modes(torch.fft.rfft(xk, dim=-1), self.index_kv, lead, xq.device)

        # perform attention mechanism on frequency domain
        scores = torch.einsum("bhex,bhey->bhxy", q_modes, k_modes)
        if self.activation == "tanh":
            scores = scores.tanh()
        elif self.activation == "softmax":
            scores = torch.softmax(abs(scores), dim=-1)
            scores = torch.complex(scores, torch.zeros_like(scores))
        else:
            raise Exception(
                "{} actiation function is not implemented".format(self.activation)
            )
        attended = torch.einsum("bhxy,bhey->bhex", scores, k_modes)
        mixed = torch.einsum("bhex,heox->bhox", attended, self.weights1)

        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        for slot, mode in enumerate(self.index_q):
            out_ft[:, :, :, mode] = mixed[:, :, :, slot]
        # Return to time domain
        out = torch.fft.irfft(
            out_ft / self.in_channels / self.out_channels, n=xq.size(-1)
        )
        return (out, None)
class MultiWaveletTransform(nn.Module):
    """
    1D multiwavelet block.

    Values are flattened to (B, L, H*D), lifted to ``c * k`` channels, passed
    through ``nCZ`` stacked ``MWT_CZ1d`` blocks (ReLU between blocks, none
    after the last) and projected back to ``ich`` channels. The queries only
    set the target length L. Keys are trimmed or padded alongside the values
    but never read afterwards. ``attn_mask`` and ``attention_dropout`` are
    unused.
    """

    def __init__(
        self,
        ich=1,
        k=8,
        alpha=16,
        c=128,
        nCZ=1,
        L=0,
        base="legendre",
        attention_dropout=0.1,
    ):
        super(MultiWaveletTransform, self).__init__()
        self.k = k
        self.c = c
        self.L = L
        self.nCZ = nCZ
        # Linear maps into and out of the c*k wavelet-coefficient space
        # (creation order kept so parameter initialisation is unchanged).
        self.Lk0 = nn.Linear(ich, c * k)
        self.Lk1 = nn.Linear(c * k, ich)
        self.ich = ich
        cz_blocks = [MWT_CZ1d(k, alpha, L, c, base) for _ in range(nCZ)]
        self.MWT_CZ = nn.ModuleList(cz_blocks)

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        shortfall = L - S
        if shortfall <= 0:
            # Enough history: keep only the first L steps.
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]
        else:
            # Too short: right-pad with zeros shaped like the queries.
            pad = torch.zeros_like(queries[:, :shortfall, :]).float()
            values = torch.cat([values, pad], dim=1)
            keys = torch.cat([keys, pad], dim=1)

        coeffs = self.Lk0(values.view(B, L, -1)).view(B, L, self.c, -1)
        final_depth = self.nCZ - 1
        for depth in range(self.nCZ):
            coeffs = self.MWT_CZ[depth](coeffs)
            if depth != final_depth:
                coeffs = F.relu(coeffs)
        merged = self.Lk1(coeffs.view(B, L, -1))
        return (merged.view(B, L, -1, D).contiguous(), None)
class MultiWaveletCross(nn.Module):
    """
    1D Multiwavelet Cross Attention layer.

    Queries, keys and values are projected to (c, k) coefficients,
    wavelet-decomposed level by level, cross-attended per level with
    FourierCrossAttentionW, and reconstructed back to the query length.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes,
        c=64,
        k=8,
        ich=512,
        L=0,
        base="legendre",
        mode_select_method="random",
        initializer=None,
        activation="tanh",
        **kwargs,
    ):
        super(MultiWaveletCross, self).__init__()
        self.c = c
        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1
        # Zero out numerically negligible reconstruction-filter entries.
        for mat in (H0r, H1r, G0r, G1r):
            mat[np.abs(mat) < 1e-8] = 0
        self.max_item = 3
        attn_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            seq_len_q=seq_len_q,
            seq_len_kv=seq_len_kv,
            modes=modes,
            activation=activation,
            mode_select_method=mode_select_method,
        )
        self.attn1 = FourierCrossAttentionW(**attn_kwargs)
        self.attn2 = FourierCrossAttentionW(**attn_kwargs)
        self.attn3 = FourierCrossAttentionW(**attn_kwargs)
        self.attn4 = FourierCrossAttentionW(**attn_kwargs)
        self.T0 = nn.Linear(k, k)
        # Decomposition (ec_*) and even/odd reconstruction (rc_*) filters.
        self.register_buffer("ec_s", torch.Tensor(np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer("ec_d", torch.Tensor(np.concatenate((G0.T, G1.T), axis=0)))
        self.register_buffer("rc_e", torch.Tensor(np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer("rc_o", torch.Tensor(np.concatenate((H1r, G1r), axis=0)))
        self.Lk = nn.Linear(ich, c * k)
        self.Lq = nn.Linear(ich, c * k)
        self.Lv = nn.Linear(ich, c * k)
        self.out = nn.Linear(c * k, ich)
        self.modes1 = modes

    def forward(self, q, k, v, mask=None):
        """Cross-attend ``q`` to ``k``/``v`` across wavelet levels.

        Args:
            q: tensor unpacked as (B, N, H, E); H * E must equal ``ich``.
            k, v: tensors unpacked as (B, S, H, E).
            mask: passed through to the per-level attention modules.

        Returns:
            ``(out, None)`` with ``out`` of shape (B, N, ich).
        """
        B, N, H, E = q.shape  # e.g. torch.Size([3, 768, 8, 2])
        _, S, _, _ = k.shape  # e.g. torch.Size([3, 96, 8, 2])
        # Project every input into the (c, k) multiwavelet coefficient space.
        q = self.Lq(q.view(B, N, -1)).view(B, N, self.c, self.k)
        k = self.Lk(k.view(k.shape[0], S, -1)).view(k.shape[0], S, self.c, self.k)
        v = self.Lv(v.view(v.shape[0], v.shape[1], -1))
        v = v.view(v.shape[0], v.shape[1], self.c, self.k)
        # Align keys/values to the query length N (zero-pad or truncate).
        if N > S:
            zeros = torch.zeros_like(q[:, : (N - S), :]).float()
            v = torch.cat([v, zeros], dim=1)
            k = torch.cat([k, zeros], dim=1)
        else:
            v = v[:, :N, :, :]
            k = k[:, :N, :, :]
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        # Extend to a power of two by wrapping the first (nl - N) steps.
        q = torch.cat([q, q[:, 0 : nl - N, :, :]], 1)
        k = torch.cat([k, k[:, 0 : nl - N, :, :]], 1)
        v = torch.cat([v, v[:, 0 : nl - N, :, :]], 1)
        levels = ns - self.L
        # Plain lists: torch.jit.annotate is a no-op in eager mode, and the
        # previous annotations needed Tuple/Tensor, which the file header
        # does not import.
        Ud_q, Us_q, q = self._decompose(q, levels)
        Ud_k, Us_k, k = self._decompose(k, levels)
        Ud_v, Us_v, v = self._decompose(v, levels)
        Ud = []
        Us = []
        for i in range(levels):
            (dq, sq), (dk, sk), (dv, sv) = Ud_q[i], Ud_k[i], Ud_v[i]
            Ud.append(
                self.attn1(dq, dk, dv, mask)[0] + self.attn2(sq, sk, sv, mask)[0]
            )
            # NOTE: Us_* hold the detail coefficients (same tensors as dq/dk/dv).
            Us.append(self.attn3(Us_q[i], Us_k[i], Us_v[i], mask)[0])
        v = self.attn4(q, k, v, mask)[0]
        # reconstruct, coarsest level first
        for i in range(levels - 1, -1, -1):
            v = v + Us[i]
            v = torch.cat((v, Ud[i]), -1)
            v = self.evenOdd(v)
        v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1))
        return (v.contiguous(), None)

    def _decompose(self, x, levels):
        # Apply `levels` wavelet steps; return the per-level (detail, coarse)
        # pairs, the per-level details, and the final coarse signal.
        pairs = []
        details = []
        for _ in range(levels):
            d, x = self.wavelet_transform(x)
            pairs.append((d, x))
            details.append(d)
        return pairs, details, x

    def wavelet_transform(self, x):
        # Stack even and odd time steps on the last axis, then apply the
        # detail (ec_d) and smooth (ec_s) filters; halves the length.
        xa = torch.cat(
            [
                x[:, ::2, :, :],
                x[:, 1::2, :, :],
            ],
            -1,
        )
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        # Inverse of wavelet_transform: (B, N, c, 2k) -> (B, 2N, c, k).
        B, N, c, ich = x.shape  # (B, N, c, k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)
        # Allocate in the input dtype so double/half inputs are not silently
        # downcast to float32.
        x = torch.zeros(B, N * 2, c, self.k, device=x.device, dtype=x_e.dtype)
        x[..., ::2, :, :] = x_e
        x[..., 1::2, :, :] = x_o
        return x
class FourierCrossAttentionW(nn.Module):
    """Parameter-free Fourier cross attention over the lowest frequency modes.

    Used per wavelet level by MultiWaveletCross. Inputs are laid out as
    (B, L, E, H) and the output keeps that layout.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes=16,
        activation="tanh",
        mode_select_method="random",
    ):
        super(FourierCrossAttentionW, self).__init__()
        print("corss fourier correlation used!")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes
        self.activation = activation

    def forward(self, q, k, v, mask):
        """Attend ``q`` to ``k`` over the lowest ``modes1`` frequency modes.

        Args:
            q: tensor unpacked as (B, L, E, H).
            k: keys in the same layout; their spectrum also serves as values.
            v: not used.
            mask: not used.

        Returns:
            ``(out, None)`` with ``out`` shaped like ``q``.

        Raises:
            NotImplementedError: for an activation other than "tanh"/"softmax".
        """
        B, L, E, H = q.shape
        xq = q.permute(0, 3, 2, 1)  # size = [B, H, E, L]
        xk = k.permute(0, 3, 2, 1)  # size = [B, H, E, S]
        # Lowest modes only (0..m-1), capped at half of each sequence length.
        # The key count is taken from the keys, since it indexes their spectrum.
        self.index_q = list(range(0, min(int(L // 2), self.modes1)))
        self.index_k_v = list(range(0, min(int(xk.shape[-1] // 2), self.modes1)))
        n_q = len(self.index_q)
        n_kv = len(self.index_k_v)
        # Contiguous low modes -> plain slices; cast like the old cfloat buffers.
        xq_ft_ = torch.fft.rfft(xq, dim=-1)[..., :n_q].to(torch.cfloat)
        xk_ft_ = torch.fft.rfft(xk, dim=-1)[..., :n_kv].to(torch.cfloat)
        xqk_ft = torch.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_)
        if self.activation == "tanh":
            xqk_ft = xqk_ft.tanh()
        elif self.activation == "softmax":
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise NotImplementedError(
                "{} activation function is not implemented".format(self.activation)
            )
        xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_)
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        out_ft[..., :n_q] = xqkv_ft
        out = torch.fft.irfft(
            out_ft / self.in_channels / self.out_channels, n=xq.size(-1)
        ).permute(0, 3, 2, 1)
        # size = [B, L, E, H]
        return (out, None)
class sparseKernelFT1d(nn.Module):
    """Learnable complex kernel applied to the lowest ``alpha`` Fourier modes.

    Input and output are (B, N, c, k); the c*k channels are mixed per mode by
    ``weights1``, and all higher modes are zeroed before the inverse FFT.
    ``nl``, ``initializer`` and extra keyword arguments are accepted and
    ignored.
    """

    def __init__(self, k, alpha, c=1, nl=1, initializer=None, **kwargs):
        super(sparseKernelFT1d, self).__init__()
        self.modes1 = alpha
        self.scale = 1 / (c * k * c * k)
        width = c * k
        # nn.Parameter already has requires_grad=True.
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(width, width, self.modes1, dtype=torch.cfloat)
        )
        self.k = k

    def compl_mul1d(self, x, weights):
        # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
        return torch.einsum("bix,iox->box", x, weights)

    def forward(self, x):
        B, N, c, k = x.shape  # (B, N, c, k)
        channels_first = x.view(B, N, -1).permute(0, 2, 1)
        spectrum = torch.fft.rfft(channels_first)
        n_freq = N // 2 + 1
        kept = min(self.modes1, n_freq)
        mixed = torch.zeros(
            B, c * k, n_freq, device=channels_first.device, dtype=torch.cfloat
        )
        mixed[:, :, :kept] = self.compl_mul1d(
            spectrum[:, :, :kept], self.weights1[:, :, :kept]
        )
        time_domain = torch.fft.irfft(mixed, n=N)
        return time_domain.permute(0, 2, 1).view(B, N, c, k)
# ##
class MWT_CZ1d(nn.Module):
    """Multiwavelet block on a (B, N, c, k) signal.

    The signal is padded to a power-of-two length and decomposed level by
    level. Each level's detail/smooth coefficients are passed through the
    spectral kernels A, B, C, the coarsest scale through T0, and the result
    is reconstructed and cropped back to length N.
    """

    def __init__(
        self, k=3, alpha=64, L=0, c=1, base="legendre", initializer=None, **kwargs
    ):
        super(MWT_CZ1d, self).__init__()
        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1
        # Zero out numerically negligible reconstruction-filter entries.
        for mat in (H0r, H1r, G0r, G1r):
            mat[np.abs(mat) < 1e-8] = 0
        self.max_item = 3
        self.A = sparseKernelFT1d(k, alpha, c)
        self.B = sparseKernelFT1d(k, alpha, c)
        self.C = sparseKernelFT1d(k, alpha, c)
        self.T0 = nn.Linear(k, k)
        # Decomposition (ec_*) and even/odd reconstruction (rc_*) filters.
        self.register_buffer("ec_s", torch.Tensor(np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer("ec_d", torch.Tensor(np.concatenate((G0.T, G1.T), axis=0)))
        self.register_buffer("rc_e", torch.Tensor(np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer("rc_o", torch.Tensor(np.concatenate((H1r, G1r), axis=0)))

    def forward(self, x):
        """Transform ``x`` (B, N, c, k) and return a tensor of the same shape."""
        B, N, c, k = x.shape  # (B, N, c, k)
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        # Extend to a power of two by wrapping the first (nl - N) steps.
        x = torch.cat([x, x[:, 0 : nl - N, :, :]], 1)
        # Plain lists: torch.jit.annotate is a no-op in eager mode, and the
        # previous List[Tensor] annotation needed a `Tensor` name that the
        # file header does not import.
        Ud = []
        Us = []
        # decompose
        for _ in range(ns - self.L):
            d, x = self.wavelet_transform(x)
            Ud.append(self.A(d) + self.B(x))
            Us.append(self.C(d))
        x = self.T0(x)  # coarsest scale transform
        # reconstruct, coarsest level first
        for i in range(ns - 1 - self.L, -1, -1):
            x = x + Us[i]
            x = torch.cat((x, Ud[i]), -1)
            x = self.evenOdd(x)
        return x[:, :N, :, :]

    def wavelet_transform(self, x):
        # Stack even and odd time steps on the last axis, then apply the
        # detail (ec_d) and smooth (ec_s) filters; halves the length.
        xa = torch.cat(
            [
                x[:, ::2, :, :],
                x[:, 1::2, :, :],
            ],
            -1,
        )
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        # Inverse of wavelet_transform: (B, N, c, 2k) -> (B, 2N, c, k).
        B, N, c, ich = x.shape  # (B, N, c, k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)
        # Allocate in the computed dtype so non-float32 inputs are not
        # silently downcast.
        x = torch.zeros(B, N * 2, c, self.k, device=x.device, dtype=x_e.dtype)
        x[..., ::2, :, :] = x_e
        x[..., 1::2, :, :] = x_o
        return x
class FullAttention(nn.Module):
    """Scaled dot-product attention over inputs laid out as (B, L, H, E).

    When ``mask_flag`` is set and no mask is passed, a TriangularCausalMask is
    built. ``factor`` is accepted but unused. The attention weights are
    returned only when ``output_attention`` is true.
    """

    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super(FullAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.output_attention = output_attention
        self.mask_flag = mask_flag
        self.scale = scale

    def forward(self, queries, keys, values, attn_mask):
        batch, q_len, _, head_dim = queries.shape
        # Unpacked only to require a 4-D values tensor, as before.
        _, _src_len, _, _val_dim = values.shape
        factor = self.scale or 1.0 / sqrt(head_dim)
        logits = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            mask = (
                attn_mask
                if attn_mask is not None
                else TriangularCausalMask(batch, q_len, device=queries.device)
            )
            logits.masked_fill_(mask.mask, -np.inf)
        weights = self.dropout(torch.softmax(factor * logits, dim=-1))
        out = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()
        return (out, weights if self.output_attention else None)
class FEDformerModel(nn.Module):
    """FEDformer encoder-decoder with a distribution-parameter output head.

    Network inputs are lagged, scaled target values concatenated with static
    and time features (``create_network_inputs``). Decoder outputs are mapped
    to distribution parameters by ``distr_output.get_args_proj``. ``forward``
    draws ``num_parallel_samples`` sample paths by greedy autoregressive
    decoding.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int = 16,  # dimension of fcn
        version: str = "Fourier",  # Fourier, Wavelets
        features: str = "M",  # options:[M, S, MS]; M:multivariate predict multivariate, 'S':univariate predict univariate, MS:multivariate predict univariate'
        modes: int = 64,
        mode_select: str = "random",
        base: str = "legendre",
        cross_activation: str = "tanh",
        L: int = 3,
        # forecasting task
        context_length: Optional[int] = None,  # seq_len : input sequence length
        label_length: Optional[int] = 48,  # start token length
        # model argument
        input_size: int = 1,  # encoder input size
        # dec_in: int = 7, #decoder input size
        c_out: int = 7,  # output size
        moving_avg: Optional[List[int]] = None,
        factor: int = 1,
        scaling: bool = True,
        activation: str = "gelu",
        dropout: float = 0.05,
        embed: str = "timeF",  # options:[timeF, fixed, learned]
        output_attention: bool = True,  # whether to output attention in encoder
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        lags_seq: Optional[List[int]] = None,
        num_parallel_samples: int = 100,
    ) -> None:
        super().__init__()
        # context_length is Optional in the signature, but every use below
        # needs an int; fall back to prediction_length instead of failing
        # with a TypeError on `None + int`.
        if context_length is None:
            context_length = prediction_length

        self.input_size = input_size
        self.target_shape = distr_output.event_shape
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None or cardinality is None
            else [min(50, (cat + 1) // 2) for cat in cardinality]
        )
        self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
        self.num_parallel_samples = num_parallel_samples
        self.history_length = context_length + max(self.lags_seq)
        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=self.embedding_dimension,
        )
        if scaling:
            self.scaler = MeanScaler(dim=1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=1, keepdim=True)

        # total feature size: one column per (lag, input dim) plus features
        d_model = self.input_size * len(self.lags_seq) + self._number_of_features

        self.context_length = context_length
        self.prediction_length = prediction_length
        self.distr_output = distr_output
        self.param_proj = distr_output.get_args_proj(d_model)
        self.moving_avg = moving_avg
        self.version = version
        self.mode_select = mode_select
        self.modes = modes
        self.label_length = label_length
        self.output_attention = output_attention
        self.dim_feedforward = dim_feedforward

        # Decomp
        kernel_size = self.moving_avg
        if isinstance(kernel_size, list):
            self.decomp = series_decomp_multi(kernel_size)
        else:
            self.decomp = series_decomp(kernel_size)

        # Embedding
        # The series-wise connection inherently contains the sequential information.
        # Thus, we can discard the position embedding of transformers.

        # NOTE(review): the decoder-side modules are sized for
        # context_length // 2 + prediction_length steps, but output_params
        # feeds the decoder only the last prediction_length steps and forward
        # feeds it k + 1 steps; confirm the selected frequency modes stay in
        # range for those lengths.
        if self.version == "Wavelets":
            encoder_self_att = MultiWaveletTransform(ich=d_model, L=L, base=base)
            decoder_self_att = MultiWaveletTransform(ich=d_model, L=L, base=base)
            decoder_cross_att = MultiWaveletCross(
                in_channels=d_model,
                out_channels=d_model,
                seq_len_q=self.context_length // 2 + self.prediction_length,
                seq_len_kv=self.context_length,
                modes=modes,
                ich=d_model,
                base=base,
                activation=cross_activation,
            )
        else:
            encoder_self_att = FourierBlock(
                n_heads=nhead,
                in_channels=d_model,
                out_channels=d_model,
                seq_len=self.context_length,
                modes=modes,
                mode_select_method=mode_select,
            )
            decoder_self_att = FourierBlock(
                n_heads=nhead,
                in_channels=d_model,
                out_channels=d_model,
                seq_len=self.context_length // 2 + self.prediction_length,
                modes=modes,
                mode_select_method=mode_select,
            )
            decoder_cross_att = FourierCrossAttention(
                n_heads=nhead,
                in_channels=d_model,
                out_channels=d_model,
                seq_len_q=self.context_length // 2 + self.prediction_length,
                seq_len_kv=self.context_length,
                modes=modes,
                mode_select_method=mode_select,
            )

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(encoder_self_att, d_model, n_heads=nhead),
                    d_model,
                    d_ff=self.dim_feedforward,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(num_encoder_layers)
            ],
            norm_layer=my_Layernorm(d_model),
        )
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AutoCorrelationLayer(decoder_self_att, d_model, n_heads=nhead),
                    AutoCorrelationLayer(decoder_cross_att, d_model, n_heads=nhead),
                    d_model,
                    c_out=c_out,
                    d_ff=self.dim_feedforward,
                    moving_avg=self.moving_avg,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(num_decoder_layers)
            ],
            norm_layer=my_Layernorm(d_model),
            projection=nn.Linear(d_model, c_out, bias=True),
        )

    @property
    def _number_of_features(self) -> int:
        # Static embeddings + dynamic/static real features + log(scale).
        return (
            sum(self.embedding_dimension)
            + self.num_feat_dynamic_real
            + self.num_feat_static_real
            + 1  # the log(scale)
        )

    @property
    def _past_length(self) -> int:
        # History needed to build every lag for a full context window.
        return self.context_length + max(self.lags_seq)

    def get_lagged_subsequences(
        self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.
        Parameters
        ----------
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        subsequences_length : int
            length of the subsequences to be extracted.
        shift: int
            shift the lags by this amount back.
        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I), where S = subsequences_length and
            I = len(indices), containing lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
        """
        sequence_length = sequence.shape[1]
        indices = [lag - shift for lag in self.lags_seq]

        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...])
        return torch.stack(lagged_values, dim=-1)

    def _check_shapes(
        self,
        prior_input: torch.Tensor,
        inputs: torch.Tensor,
        features: Optional[torch.Tensor],
    ) -> None:
        assert len(prior_input.shape) == len(inputs.shape)
        assert (
            len(prior_input.shape) == 2 and self.input_size == 1
        ) or prior_input.shape[2] == self.input_size
        assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
            -1
        ] == self.input_size
        assert (
            features is None or features.shape[2] == self._number_of_features
        ), f"{features.shape[2]}, expected {self._number_of_features}"

    def create_network_inputs(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
    ):
        """Build the per-step network inputs.

        Returns ``(transformer_inputs, scale, static_feat)``. The inputs
        cover the context window, plus the prediction window when
        ``future_target`` is given. Targets are divided by a scale computed
        from the observed context.
        """
        # time feature
        time_feat = (
            torch.cat(
                (
                    past_time_feat[:, self._past_length - self.context_length :, ...],
                    future_time_feat,
                ),
                dim=1,
            )
            if future_target is not None
            else past_time_feat[:, self._past_length - self.context_length :, ...]
        )

        # target
        context = past_target[:, -self.context_length :]
        observed_context = past_observed_values[:, -self.context_length :]
        _, scale = self.scaler(context, observed_context)

        inputs = (
            torch.cat((past_target, future_target), dim=1) / scale
            if future_target is not None
            else past_target / scale
        )

        inputs_length = (
            self._past_length + self.prediction_length
            if future_target is not None
            else self._past_length
        )
        assert inputs.shape[1] == inputs_length

        subsequences_length = (
            self.context_length + self.prediction_length
            if future_target is not None
            else self.context_length
        )

        # embeddings
        embedded_cat = self.embedder(feat_static_cat)
        static_feat = torch.cat(
            (embedded_cat, feat_static_real, scale.log()),
            dim=1,
        )
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, time_feat.shape[1], -1
        )

        features = torch.cat((expanded_static_feat, time_feat), dim=-1)

        lagged_sequence = self.get_lagged_subsequences(
            sequence=inputs,
            subsequences_length=subsequences_length,
        )

        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)

        return transformer_inputs, scale, static_feat

    def output_params(self, transformer_inputs):
        # First context_length steps feed the encoder, the rest the decoder.
        enc_input = transformer_inputs[:, : self.context_length, ...]
        dec_input = transformer_inputs[:, self.context_length :, ...]

        enc_out, _ = self.encoder(enc_input)
        # NOTE(review): assumes Decoder (defined elsewhere) returns a tensor
        # here; confirm it does not return an (x, trend) tuple.
        dec_output = self.decoder(dec_input, enc_out)

        return self.param_proj(dec_output)

    @torch.jit.ignore
    def output_distribution(
        self, params, scale=None, trailing_n=None
    ) -> torch.distributions.Distribution:
        # Optionally keep only the last `trailing_n` time steps of each param.
        sliced_params = params
        if trailing_n is not None:
            sliced_params = [p[:, -trailing_n:] for p in params]
        return self.distr_output.distribution(sliced_params, scale=scale)

    # for prediction
    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
        num_parallel_samples: Optional[int] = None,
    ) -> torch.Tensor:
        """Draw sample paths for the prediction window.

        Args:
            num_parallel_samples: samples per series; defaults to the value
                given at construction.

        Returns:
            Samples reshaped to
            ``(-1, num_parallel_samples, prediction_length) + target_shape``.
        """
        # Use the (possibly overridden) local count everywhere below;
        # previously the argument was computed but then ignored.
        if num_parallel_samples is None:
            num_parallel_samples = self.num_parallel_samples

        encoder_inputs, scale, static_feat = self.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
        )

        enc_out, _ = self.encoder(encoder_inputs)

        repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)

        repeated_past_target = (
            past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
            / repeated_scale
        )

        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, future_time_feat.shape[1], -1
        )
        features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
        repeated_features = features.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        repeated_enc_out = enc_out.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        future_samples = []

        # greedy decoding: each step appends its (scaled) sample to the
        # history so the next step's lags can see it
        for k in range(self.prediction_length):
            lagged_sequence = self.get_lagged_subsequences(
                sequence=repeated_past_target,
                subsequences_length=1 + k,
                shift=1,
            )

            lags_shape = lagged_sequence.shape
            reshaped_lagged_sequence = lagged_sequence.reshape(
                lags_shape[0], lags_shape[1], -1
            )

            decoder_input = torch.cat(
                (reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1
            )

            output = self.decoder(decoder_input, repeated_enc_out)

            params = self.param_proj(output[:, -1:])
            distr = self.output_distribution(params, scale=repeated_scale)
            next_sample = distr.sample()

            repeated_past_target = torch.cat(
                (repeated_past_target, next_sample / repeated_scale), dim=1
            )
            future_samples.append(next_sample)

        concat_future_samples = torch.cat(future_samples, dim=1)
        return concat_future_samples.reshape(
            (-1, num_parallel_samples, self.prediction_length) + self.target_shape,
        )