mirror of
https://github.com/wassname/ETSformer.git
synced 2026-06-28 16:10:50 +08:00
69 lines
2.0 KiB
Python
69 lines
2.0 KiB
Python
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.fft as fft
|
|
|
|
from einops import rearrange, reduce, repeat
|
|
from scipy.fftpack import next_fast_len
|
|
|
|
|
|
def conv1d_fft(f, g, dim=-1):
    """Causal FFT correlation of ``f`` with kernel ``g`` along ``dim``.

    With ``N = f.size(dim)`` the output has length ``N`` along ``dim`` and

        out[i] = sum_{n=0}^{i} f[n] * g[n + N - 1 - i]

    where ``g`` is treated as zero outside its own length. For a kernel of the
    same length as ``f`` this means ``g[-1]`` weights the current step,
    ``g[-2]`` the previous one, and so on; no later sample of ``f`` ever
    contributes to ``out[i]``. All other dimensions broadcast as in an
    elementwise product of the two spectra.

    Args:
        f: real signal tensor.
        g: real kernel tensor, broadcastable against ``f`` outside ``dim``.
        dim: axis along which to correlate.

    Returns:
        Real tensor whose size along ``dim`` is ``N``.
    """
    N = f.size(dim)
    M = g.size(dim)

    # The result reads lags -(N-1)..0 of the circular correlation, while the
    # linear correlation is non-zero on lags -(M-1)..N-1. The FFT length must
    # span both ranges, otherwise later samples of f alias into earlier
    # outputs when M < N. For M >= N this is the usual N + M - 1.
    fast_len = next_fast_len(N + max(N, M) - 1)

    F_f = fft.rfft(f, fast_len, dim=dim)
    F_g = fft.rfft(g, fast_len, dim=dim)

    # Multiplying by the conjugate spectrum yields cross-correlation.
    out = fft.irfft(F_f * F_g.conj(), fast_len, dim=dim)
    # After rolling by -1 the last N entries hold lags -(N-1), ..., -1, 0.
    out = out.roll((-1,), dims=(dim,))
    # Build the index directly on the output's device (no host-side copy).
    idx = torch.arange(fast_len - N, fast_len, device=out.device)
    return out.index_select(dim, idx)
|
|
|
|
|
|
class ExponentialSmoothing(nn.Module):
    """Learnable per-head exponential smoothing along the time axis.

    Each head has its own smoothing factor ``alpha = sigmoid(logit)`` (see
    ``weight``) and there is a learnable initial state ``v0``. For an input
    sequence ``x`` the output at every step ``t`` is

        s_t = alpha^(t+1) * v0 + sum_{j=0}^{t} (1 - alpha) * alpha^(t-j) * x_j

    computed for all steps at once with an FFT correlation (``conv1d_fft``).

    Args:
        dim: size of the last axis of ``v0``.
        nhead: number of heads, each with its own smoothing factor.
        dropout: dropout probability applied to the inputs before smoothing.
        aux: if True, also create the dropout used for ``aux_values`` in
            ``forward``; ``aux_values`` may only be passed when this is set.
    """

    def __init__(self, dim, nhead, dropout=0.1, aux=False):
        super().__init__()
        # Unconstrained logits; the ``weight`` property squashes them via sigmoid.
        self._smoothing_weight = nn.Parameter(torch.randn(nhead, 1))
        # Learnable initial state, broadcast over the batch and time axes.
        self.v0 = nn.Parameter(torch.randn(1, 1, nhead, dim))
        self.dropout = nn.Dropout(dropout)
        if aux:
            self.aux_dropout = nn.Dropout(dropout)

    def forward(self, values, aux_values=None):
        """Smooth ``values`` (and optionally ``aux_values``) along axis 1.

        Args:
            values: 4-D tensor ``(b, t, h, d)``; ``h`` and ``d`` must broadcast
                against ``v0`` of shape ``(1, 1, nhead, dim)``.
            aux_values: optional tensor, presumably shaped like ``values``
                (TODO confirm against callers). Its contribution at step ``t``
                is ``sum_{j<=t} alpha^(t-j+1) * aux_j``, added to the output.
                Requires the module to have been built with ``aux=True``.

        Returns:
            The smoothed tensor, broadcast shape of ``values`` and ``v0``.
        """
        # Unpacking also enforces that ``values`` is 4-D.
        _, t, _, _ = values.shape

        init_weight, weight = self.get_exponential_weight(t)
        output = conv1d_fft(self.dropout(values), weight, dim=1)
        output = init_weight * self.v0 + output

        if aux_values is not None:
            # alpha^(T - t) for t = 0..T-1, i.e. init_weight reversed in time.
            # Mathematically equal to weight / (1 - alpha) * alpha, but avoids
            # the 0/0 that form produces once the sigmoid saturates to 1.0.
            aux_weight = init_weight.flip(1)
            # Smooth along the time axis like the main branch; without dim=1
            # the correlation ran over the last (feature) axis instead.
            aux_output = conv1d_fft(self.aux_dropout(aux_values), aux_weight, dim=1)
            output = output + aux_output

        return output

    def get_exponential_weight(self, T):
        """Return ``(init_weight, weight)`` for a sequence of length ``T``.

        Both are shaped ``(1, T, nhead, 1)`` so they broadcast over the batch
        and feature axes used in ``forward``:

        * ``init_weight[:, t] = alpha^(t+1)`` for ``t = 0..T-1`` (decay of v0);
        * ``weight[:, t] = (1 - alpha) * alpha^(T-1-t)``, the smoothing kernel
          in the orientation ``conv1d_fft`` expects (last entry = current step).
        """
        alpha = self.weight  # (nhead, 1)

        # Generate array [0, 1, ..., T-1]
        powers = torch.arange(T, dtype=torch.float, device=alpha.device)

        # (1 - alpha) * alpha^t for t = T-1, T-2, ..., 0
        weight = (1 - alpha) * (alpha ** torch.flip(powers, dims=(0,)))

        # alpha^t for t = 1, 2, ..., T
        init_weight = alpha ** (powers + 1)

        # (nhead, T) -> (1, T, nhead, 1)
        return (init_weight.t().unsqueeze(0).unsqueeze(-1),
                weight.t().unsqueeze(0).unsqueeze(-1))

    @property
    def weight(self):
        """Per-head smoothing factor: sigmoid of the logits, shape (nhead, 1)."""
        return torch.sigmoid(self._smoothing_weight)