mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 15:16:27 +08:00
Adding utils functions
This commit is contained in:
+903
-1
File diff suppressed because one or more lines are too long
+199
-10
@@ -12,9 +12,199 @@ from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
import random
|
||||
from matplotlib import pyplot as plt
|
||||
# from matplotlib import pyplot as plt
|
||||
from torch.nn.functional import interpolate
|
||||
|
||||
from scipy.special import eval_legendre
|
||||
from sympy import Poly, legendre, Symbol, chebyshevt
|
||||
|
||||
|
||||
def legendreDer(k, x):
|
||||
def _legendre(k, x):
|
||||
return (2*k+1) * eval_legendre(k, x)
|
||||
out = 0
|
||||
for i in np.arange(k-1,-1,-2):
|
||||
out += _legendre(i, x)
|
||||
return out
|
||||
|
||||
|
||||
def phi_(phi_c, x, lb = 0, ub = 1):
|
||||
mask = np.logical_or(x<lb, x>ub) * 1.0
|
||||
return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1-mask)
|
||||
|
||||
|
||||
def get_phi_psi(k, base):
|
||||
|
||||
x = Symbol('x')
|
||||
phi_coeff = np.zeros((k,k))
|
||||
phi_2x_coeff = np.zeros((k,k))
|
||||
if base == 'legendre':
|
||||
for ki in range(k):
|
||||
coeff_ = Poly(legendre(ki, 2*x-1), x).all_coeffs()
|
||||
phi_coeff[ki,:ki+1] = np.flip(np.sqrt(2*ki+1) * np.array(coeff_).astype(np.float64))
|
||||
coeff_ = Poly(legendre(ki, 4*x-1), x).all_coeffs()
|
||||
phi_2x_coeff[ki,:ki+1] = np.flip(np.sqrt(2) * np.sqrt(2*ki+1) * np.array(coeff_).astype(np.float64))
|
||||
|
||||
psi1_coeff = np.zeros((k, k))
|
||||
psi2_coeff = np.zeros((k, k))
|
||||
for ki in range(k):
|
||||
psi1_coeff[ki,:] = phi_2x_coeff[ki,:]
|
||||
for i in range(k):
|
||||
a = phi_2x_coeff[ki,:ki+1]
|
||||
b = phi_coeff[i, :i+1]
|
||||
prod_ = np.convolve(a, b)
|
||||
prod_[np.abs(prod_)<1e-8] = 0
|
||||
proj_ = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum()
|
||||
psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:]
|
||||
psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:]
|
||||
for j in range(ki):
|
||||
a = phi_2x_coeff[ki,:ki+1]
|
||||
b = psi1_coeff[j, :]
|
||||
prod_ = np.convolve(a, b)
|
||||
prod_[np.abs(prod_)<1e-8] = 0
|
||||
proj_ = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum()
|
||||
psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:]
|
||||
psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:]
|
||||
|
||||
a = psi1_coeff[ki,:]
|
||||
prod_ = np.convolve(a, a)
|
||||
prod_[np.abs(prod_)<1e-8] = 0
|
||||
norm1 = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum()
|
||||
|
||||
a = psi2_coeff[ki,:]
|
||||
prod_ = np.convolve(a, a)
|
||||
prod_[np.abs(prod_)<1e-8] = 0
|
||||
norm2 = (prod_ * 1/(np.arange(len(prod_))+1) * (1-np.power(0.5, 1+np.arange(len(prod_))))).sum()
|
||||
norm_ = np.sqrt(norm1 + norm2)
|
||||
psi1_coeff[ki,:] /= norm_
|
||||
psi2_coeff[ki,:] /= norm_
|
||||
psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0
|
||||
psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0
|
||||
|
||||
phi = [np.poly1d(np.flip(phi_coeff[i,:])) for i in range(k)]
|
||||
psi1 = [np.poly1d(np.flip(psi1_coeff[i,:])) for i in range(k)]
|
||||
psi2 = [np.poly1d(np.flip(psi2_coeff[i,:])) for i in range(k)]
|
||||
|
||||
elif base == 'chebyshev':
|
||||
for ki in range(k):
|
||||
if ki == 0:
|
||||
phi_coeff[ki,:ki+1] = np.sqrt(2/np.pi)
|
||||
phi_2x_coeff[ki,:ki+1] = np.sqrt(2/np.pi) * np.sqrt(2)
|
||||
else:
|
||||
coeff_ = Poly(chebyshevt(ki, 2*x-1), x).all_coeffs()
|
||||
phi_coeff[ki,:ki+1] = np.flip(2/np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
|
||||
coeff_ = Poly(chebyshevt(ki, 4*x-1), x).all_coeffs()
|
||||
phi_2x_coeff[ki,:ki+1] = np.flip(np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
|
||||
|
||||
phi = [partial(phi_, phi_coeff[i,:]) for i in range(k)]
|
||||
|
||||
x = Symbol('x')
|
||||
kUse = 2*k
|
||||
roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
|
||||
x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
|
||||
# x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
|
||||
# not needed for our purpose here, we use even k always to avoid
|
||||
wm = np.pi / kUse / 2
|
||||
|
||||
psi1_coeff = np.zeros((k, k))
|
||||
psi2_coeff = np.zeros((k, k))
|
||||
|
||||
psi1 = [[] for _ in range(k)]
|
||||
psi2 = [[] for _ in range(k)]
|
||||
|
||||
for ki in range(k):
|
||||
psi1_coeff[ki,:] = phi_2x_coeff[ki,:]
|
||||
for i in range(k):
|
||||
proj_ = (wm * phi[i](x_m) * np.sqrt(2)* phi[ki](2*x_m)).sum()
|
||||
psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:]
|
||||
psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:]
|
||||
|
||||
for j in range(ki):
|
||||
proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum()
|
||||
psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:]
|
||||
psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:]
|
||||
|
||||
psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5)
|
||||
psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5, ub = 1)
|
||||
|
||||
norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
|
||||
norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()
|
||||
|
||||
norm_ = np.sqrt(norm1 + norm2)
|
||||
psi1_coeff[ki,:] /= norm_
|
||||
psi2_coeff[ki,:] /= norm_
|
||||
psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0
|
||||
psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0
|
||||
|
||||
psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5+1e-16)
|
||||
psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5+1e-16, ub = 1)
|
||||
|
||||
return phi, psi1, psi2
|
||||
|
||||
|
||||
def get_filter(base, k):
|
||||
|
||||
def psi(psi1, psi2, i, inp):
|
||||
mask = (inp<=0.5) * 1.0
|
||||
return psi1[i](inp) * mask + psi2[i](inp) * (1-mask)
|
||||
|
||||
if base not in ['legendre', 'chebyshev']:
|
||||
raise Exception('Base not supported')
|
||||
|
||||
x = Symbol('x')
|
||||
H0 = np.zeros((k,k))
|
||||
H1 = np.zeros((k,k))
|
||||
G0 = np.zeros((k,k))
|
||||
G1 = np.zeros((k,k))
|
||||
PHI0 = np.zeros((k,k))
|
||||
PHI1 = np.zeros((k,k))
|
||||
phi, psi1, psi2 = get_phi_psi(k, base)
|
||||
if base == 'legendre':
|
||||
roots = Poly(legendre(k, 2*x-1)).all_roots()
|
||||
x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
|
||||
wm = 1/k/legendreDer(k,2*x_m-1)/eval_legendre(k-1,2*x_m-1)
|
||||
|
||||
for ki in range(k):
|
||||
for kpi in range(k):
|
||||
H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum()
|
||||
G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum()
|
||||
H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum()
|
||||
G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum()
|
||||
|
||||
PHI0 = np.eye(k)
|
||||
PHI1 = np.eye(k)
|
||||
|
||||
elif base == 'chebyshev':
|
||||
x = Symbol('x')
|
||||
kUse = 2*k
|
||||
roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
|
||||
x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
|
||||
# x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
|
||||
# not needed for our purpose here, we use even k always to avoid
|
||||
wm = np.pi / kUse / 2
|
||||
|
||||
for ki in range(k):
|
||||
for kpi in range(k):
|
||||
H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum()
|
||||
G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum()
|
||||
H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum()
|
||||
G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum()
|
||||
|
||||
PHI0[ki, kpi] = (wm * phi[ki](2*x_m) * phi[kpi](2*x_m)).sum() * 2
|
||||
PHI1[ki, kpi] = (wm * phi[ki](2*x_m-1) * phi[kpi](2*x_m-1)).sum() * 2
|
||||
|
||||
PHI0[np.abs(PHI0)<1e-8] = 0
|
||||
PHI1[np.abs(PHI1)<1e-8] = 0
|
||||
|
||||
H0[np.abs(H0)<1e-8] = 0
|
||||
H1[np.abs(H1)<1e-8] = 0
|
||||
G0[np.abs(G0)<1e-8] = 0
|
||||
G1[np.abs(G1)<1e-8] = 0
|
||||
|
||||
return H0, H1, G0, G1, PHI0, PHI1
|
||||
|
||||
|
||||
|
||||
|
||||
class TriangularCausalMask:
|
||||
def __init__(self, B, L, device="cpu"):
|
||||
@@ -888,8 +1078,6 @@ class FourierBlock(nn.Module):
|
||||
# Complex multiplication
|
||||
def compl_mul1d(self, input, weights):
|
||||
# (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
|
||||
print("input fft", input.size())
|
||||
print("weight fft", weights.size())
|
||||
return torch.einsum("bhi,hio->bho", input, weights) # hio->bho
|
||||
|
||||
def forward(self, q, k, v, mask):
|
||||
@@ -898,7 +1086,10 @@ class FourierBlock(nn.Module):
|
||||
x = q.permute(0, 2, 3, 1) # [B, H, E, L]
|
||||
# Compute Fourier coefficients
|
||||
x_ft = torch.fft.rfft(x, dim=-1)
|
||||
print(x_ft.size()) # [B, H, E, L]
|
||||
|
||||
print('x_ft size',x_ft.size()) # [B, H, E, L]
|
||||
print('weight size', self.weights1.size())
|
||||
print('index', self.index)
|
||||
# Perform Fourier neural operations
|
||||
out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
|
||||
for wi, i in enumerate(self.index):
|
||||
@@ -940,8 +1131,8 @@ class FourierCrossAttention(nn.Module):
|
||||
seq_len_kv, modes=modes, mode_select_method=mode_select_method
|
||||
)
|
||||
|
||||
print("modes_q={}, index_q={}".format(len(self.index_q), self.index_q))
|
||||
print("modes_kv={}, index_kv={}".format(len(self.index_kv), self.index_kv))
|
||||
# print("modes_q={}, index_q={}".format(len(self.index_q), self.index_q))
|
||||
# print("modes_kv={}, index_kv={}".format(len(self.index_kv), self.index_kv))
|
||||
|
||||
self.scale = 1 / (in_channels * out_channels)
|
||||
self.weights1 = nn.Parameter(
|
||||
@@ -1021,7 +1212,6 @@ class MultiWaveletTransform(nn.Module):
|
||||
attention_dropout=0.1,
|
||||
):
|
||||
super(MultiWaveletTransform, self).__init__()
|
||||
print("base", base)
|
||||
self.k = k
|
||||
self.c = c
|
||||
self.L = L
|
||||
@@ -1077,7 +1267,6 @@ class MultiWaveletCross(nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
super(MultiWaveletCross, self).__init__()
|
||||
print("base", base)
|
||||
|
||||
self.c = c
|
||||
self.k = k
|
||||
@@ -1777,7 +1966,7 @@ class FEDformerModel(nn.Module):
|
||||
def output_params(self, transformer_inputs):
|
||||
enc_input = transformer_inputs[:, : self.context_length, ...]
|
||||
dec_input = transformer_inputs[:, self.context_length :, ...]
|
||||
|
||||
print('enc_input',enc_input.shape)
|
||||
enc_out, _ = self.encoder(enc_input)
|
||||
dec_output = self.decoder(dec_input, enc_out)
|
||||
|
||||
@@ -1874,4 +2063,4 @@ class FEDformerModel(nn.Module):
|
||||
concat_future_samples = torch.cat(future_samples, dim=1)
|
||||
return concat_future_samples.reshape(
|
||||
(-1, self.num_parallel_samples, self.prediction_length) + self.target_shape,
|
||||
)
|
||||
)
|
||||
Reference in New Issue
Block a user