This commit is contained in:
Kashif Rasul
2022-05-10 10:52:14 +02:00
parent 8e36564b62
commit b2e37ef867
11 changed files with 4304 additions and 0 deletions
View File
+120
View File
@@ -0,0 +1,120 @@
import torch
from cauchy_mult import (
cauchy_mult_bwd,
cauchy_mult_fwd,
cauchy_mult_sym_bwd,
cauchy_mult_sym_fwd,
)
from einops import rearrange
def cauchy_mult_torch(
v: torch.Tensor, z: torch.Tensor, w: torch.Tensor, symmetric=True
) -> torch.Tensor:
"""
v: (B, N)
z: (L)
w: (B, N)
symmetric: whether to assume that v and w contain complex conjugate pairs, of the form
[v_half, v_half.conj()] and [w_half, w_half.conj()]
"""
if not symmetric:
return (
rearrange(v, "b n -> b 1 n")
/ (rearrange(z, "l -> l 1") - rearrange(w, "b n -> b 1 n"))
).sum(dim=-1)
else:
N = v.shape[-1]
assert N % 2 == 0
vv = rearrange(v[:, : N // 2], "b n -> b 1 n")
zz = rearrange(z, "l -> l 1")
ww = rearrange(w[:, : N // 2], "b n -> b 1 n")
return 2 * (
(zz * vv.real - vv.real * ww.real - vv.imag * ww.imag)
/ (zz * zz - 2 * zz * ww.real + ww.abs().square())
).sum(dim=-1)
def cauchy_mult_keops(v, z, w):
from pykeops.torch import LazyTensor
v_l = LazyTensor(rearrange(v, "b N -> b 1 N 1"))
z_l = LazyTensor(rearrange(z, "L -> 1 L 1 1"))
w_l = LazyTensor(rearrange(w, "b N -> b 1 N 1"))
sub = z_l - w_l # (b N L 1), for some reason it doesn't display the last dimension
div = v_l / sub
s = div.sum(dim=2, backend="GPU")
return s.squeeze(-1)
def _cauchy_mult(v, z, w, symmetric=True):
if not symmetric:
return CauchyMultiply.apply(v, z, w)
else:
return CauchyMultiplySymmetric.apply(v, z, w)
def cauchy_mult(v, z, w, symmetric=True):
"""Wrap the cuda method to deal with shapes"""
v, w = torch.broadcast_tensors(v, w)
shape = v.shape
# z_shape = z.shape
z = z.squeeze()
assert len(z.shape) == 1
v = v.contiguous()
w = w.contiguous()
z = z.contiguous()
N = v.size(-1)
assert w.size(-1) == N
y = _cauchy_mult(v.view(-1, N), z, w.view(-1, N), symmetric=symmetric)
y = y.view(*shape[:-1], z.size(-1))
# y = z.new_zeros(*shape[:-1], z.size(-1))
return y
class CauchyMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, v, z, w):
batch, N = v.shape
# supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
supported_N_values = [1 << log_n for log_n in [6]]
L = z.shape[-1]
if not N in supported_N_values:
raise NotImplementedError(f"Only support N values in {supported_N_values}")
if L % 32 != 0:
raise NotImplementedError(f"Only support L values that are multiples of 32")
if not v.is_cuda and z.is_cuda and w.is_cuda:
raise NotImplementedError(f"Only support CUDA tensors")
ctx.save_for_backward(v, z, w)
return cauchy_mult_fwd(v, z, w)
@staticmethod
def backward(ctx, dout):
v, z, w = ctx.saved_tensors
dv, dw = cauchy_mult_bwd(v, z, w, dout)
return dv, None, dw
class CauchyMultiplySymmetric(torch.autograd.Function):
@staticmethod
def forward(ctx, v, z, w):
batch, N = v.shape
supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
L = z.shape[-1]
if not N in supported_N_values:
raise NotImplementedError(f"Only support N values in {supported_N_values}")
max_L_value = 32 * 1024 * 64 * 1024
if L > max_L_value:
raise NotImplementedError(f"Only support L values <= {max_L_value}")
if not v.is_cuda and z.is_cuda and w.is_cuda:
raise NotImplementedError(f"Only support CUDA tensors")
ctx.save_for_backward(v, z, w)
return cauchy_mult_sym_fwd(v, z, w)
@staticmethod
def backward(ctx, dout):
v, z, w = ctx.saved_tensors
dv, dw = cauchy_mult_sym_bwd(v, z, w, dout)
return dv, None, dw
+216
View File
@@ -0,0 +1,216 @@
""" Utility nn components, in particular handling activations, initializations, and normalization layers """
import math
from functools import partial
import torch
import torch.nn as nn
from opt_einsum import contract
class modrelu(nn.Module):
def __init__(self, features):
# For now we just support square layers
super(modrelu, self).__init__()
self.features = features
self.b = nn.Parameter(torch.Tensor(self.features))
self.reset_parameters()
def reset_parameters(self):
self.b.data.uniform_(-0.01, 0.01)
def forward(self, inputs):
norm = torch.abs(inputs)
biased_norm = norm + self.b
magnitude = nn.functional.relu(biased_norm)
phase = torch.sign(inputs)
return phase * magnitude
class Modrelu(modrelu):
def reset_parameters(self):
self.b.data.uniform_(-0.01, 0.01)
def Activation(activation=None, size=None, dim=-1):
if activation in [None, "id", "identity", "linear"]:
return nn.Identity()
elif activation == "tanh":
return nn.Tanh()
elif activation == "relu":
return nn.ReLU()
elif activation == "gelu":
return nn.GELU()
elif activation in ["swish", "silu"]:
return nn.SiLU()
elif activation == "glu":
return nn.GLU(dim=dim)
elif activation == "sigmoid":
return nn.Sigmoid()
elif activation == "modrelu":
return Modrelu(size)
else:
raise NotImplementedError(
"hidden activation '{}' is not implemented".format(activation)
)
def get_initializer(name, activation=None):
if activation in [None, "id", "identity", "linear", "modrelu"]:
nonlinearity = "linear"
elif activation in ["relu", "tanh", "sigmoid"]:
nonlinearity = activation
elif activation in ["gelu", "swish", "silu"]:
nonlinearity = "relu" # Close to ReLU so approximate with ReLU's gain
else:
raise NotImplementedError(
f"get_initializer: activation {activation} not supported"
)
if name == "uniform":
initializer = partial(torch.nn.init.kaiming_uniform_, nonlinearity=nonlinearity)
elif name == "normal":
initializer = partial(torch.nn.init.kaiming_normal_, nonlinearity=nonlinearity)
elif name == "xavier":
initializer = torch.nn.init.xavier_normal_
elif name == "zero":
initializer = partial(torch.nn.init.constant_, val=0)
elif name == "one":
initializer = partial(torch.nn.init.constant_, val=1)
else:
raise NotImplementedError(
f"get_initializer: initializer type {name} not supported"
)
return initializer
def LinearActivation(
d_input,
d_output,
bias=True,
zero_bias_init=False,
transposed=False,
initializer=None,
activation=None,
activate=False, # Apply activation as part of this module
weight_norm=False,
**kwargs,
):
"""Returns a linear nn.Module with control over axes order, initialization, and activation"""
# Construct core module
linear_cls = TransposedLinear if transposed else nn.Linear
if activation == "glu":
d_output *= 2
linear = linear_cls(d_input, d_output, bias=bias, **kwargs)
# Initialize weight
if initializer is not None:
get_initializer(initializer, activation)(linear.weight)
# Initialize bias
if bias and zero_bias_init:
nn.init.zeros_(linear.bias)
# Weight norm
if weight_norm:
linear = nn.utils.weight_norm(linear)
if activate and activation is not None:
activation = Activation(activation, d_output, dim=-2 if transposed else -1)
linear = nn.Sequential(linear, activation)
return linear
class TransposedLinear(nn.Module):
"""Linear module on the second-to-last dimension"""
def __init__(self, d_input, d_output, bias=True):
super().__init__()
self.weight = nn.Parameter(torch.empty(d_output, d_input))
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # nn.Linear default init
# nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') # should be equivalent
if bias:
self.bias = nn.Parameter(torch.empty(d_output, 1))
bound = 1 / math.sqrt(d_input)
nn.init.uniform_(self.bias, -bound, bound)
else:
self.bias = 0.0
def forward(self, x):
return contract("... u l, v u -> ... v l", x, self.weight) + self.bias
class TransposedLN(nn.Module):
"""LayerNorm module over second-to-last dimension
This is slow and a dedicated CUDA/Triton implementation shuld provide substantial end-to-end speedup
"""
def __init__(self, d, scalar=True):
super().__init__()
self.scalar = scalar
if self.scalar:
self.m = nn.Parameter(torch.zeros(1))
self.s = nn.Parameter(torch.ones(1))
else:
self.ln = nn.LayerNorm(d)
def forward(self, x):
if self.scalar:
s, m = torch.std_mean(x, dim=-2, unbiased=False, keepdim=True)
y = (self.s / s) * (x - m + self.m)
else:
y = self.ln(x.transpose(-1, -2)).transpose(-1, -2)
return y
class Normalization(nn.Module):
def __init__(
self,
d,
transposed=False, # Length dimension is -1 or -2
_name_="layer",
**kwargs,
):
super().__init__()
self.transposed = transposed
if _name_ == "layer":
self.channel = True # Normalize over channel dimension
if self.transposed:
self.norm = TransposedLN(d, **kwargs)
else:
self.norm = nn.LayerNorm(d, **kwargs)
elif _name_ == "instance":
self.channel = False
norm_args = {"affine": False, "track_running_stats": False}
norm_args.update(kwargs)
self.norm = nn.InstanceNorm1d(
d, **norm_args
) # (True, True) performs very poorly
elif _name_ == "batch":
self.channel = False
norm_args = {"affine": True, "track_running_stats": True}
norm_args.update(kwargs)
self.norm = nn.BatchNorm1d(d, **norm_args)
elif _name_ == "none":
self.channel = True
self.norm = nn.Identity()
else:
raise NotImplementedError
def forward(self, x):
# The cases of LayerNorm / no normalization are automatically handled in all cases
# Instance/Batch Norm work automatically with transposed axes
if self.channel or self.transposed:
return self.norm(x)
else:
x = x.transpose(-1, -2)
x = self.norm(x)
x = x.transpose(-1, -2)
return x
+393
View File
@@ -0,0 +1,393 @@
""" Definitions of A and B matrices for various HiPPO operators. """
import numpy as np
import torch
from einops import rearrange, repeat
from opt_einsum import contract
from scipy import special as ss
def embed_c2r(A):
A = rearrange(A, "... m n -> ... m () n ()")
A = np.pad(A, ((0, 0), (0, 1), (0, 0), (0, 1))) + np.pad(
A, ((0, 0), (1, 0), (0, 0), (1, 0))
)
return rearrange(A, "m x n y -> (m x) (n y)")
# TODO take in 'torch' option to return torch instead of numpy, which converts the shape of B from (N, 1) to (N)
# TODO remove tlagt
def transition(measure, N, **measure_args):
"""A, B transition matrices for different measures
measure: the type of measure
legt - Legendre (translated)
legs - Legendre (scaled)
glagt - generalized Laguerre (translated)
lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization
"""
# Laguerre (translated)
if measure == "lagt":
b = measure_args.get("beta", 1.0)
A = np.eye(N) / 2 - np.tril(np.ones((N, N)))
B = b * np.ones((N, 1))
elif measure == "tlagt":
# beta = 1 corresponds to no tilt
b = measure_args.get("beta", 1.0)
A = (1.0 - b) / 2 * np.eye(N) - np.tril(np.ones((N, N)))
B = b * np.ones((N, 1))
# Generalized Laguerre
# alpha 0, beta small is most stable (limits to the 'lagt' measure)
# alpha 0, beta 1 has transition matrix A = [lower triangular 1]
elif measure == "glagt":
alpha = measure_args.get("alpha", 0.0)
beta = measure_args.get("beta", 0.01)
A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1)
B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None]
L = np.exp(
0.5 * (ss.gammaln(np.arange(N) + alpha + 1) - ss.gammaln(np.arange(N) + 1))
)
A = (1.0 / L[:, None]) * A * L[None, :]
B = (
(1.0 / L[:, None])
* B
* np.exp(-0.5 * ss.gammaln(1 - alpha))
* beta ** ((1 - alpha) / 2)
)
# Legendre (translated)
elif measure == "legt":
Q = np.arange(N, dtype=np.float64)
R = (2 * Q + 1) ** 0.5
j, i = np.meshgrid(Q, Q)
A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :]
B = R[:, None]
A = -A
# Halve again for timescale correctness
# A, B = A/2, B/2
A *= 0.5
B *= 0.5
# LMU: equivalent to LegT up to normalization
elif measure == "lmu":
Q = np.arange(N, dtype=np.float64)
R = (2 * Q + 1)[:, None] # / theta
j, i = np.meshgrid(Q, Q)
A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R
B = (-1.0) ** Q[:, None] * R
# Legendre (scaled)
elif measure == "legs":
q = np.arange(N, dtype=np.float64)
col, row = np.meshgrid(q, q)
r = 2 * q + 1
M = -(np.where(row >= col, r, 0) - np.diag(q))
T = np.sqrt(np.diag(2 * q + 1))
A = T @ M @ np.linalg.inv(T)
B = np.diag(T)[:, None]
B = (
B.copy()
) # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
elif measure == "legsd":
q = np.arange(N, dtype=np.float64)
col, row = np.meshgrid(q, q)
r = 2 * q + 1
M = -(np.where(row >= col, r, 0) - np.diag(q))
T = np.sqrt(np.diag(2 * q + 1))
A = T @ M @ np.linalg.inv(T)
B = np.diag(T)[:, None]
B = (
B.copy()
) # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
A += 0.5 * B * B[None, :, 0]
B = B / 2.0
elif measure == "fourier_old":
freqs = np.arange(N // 2)
d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1]
A = 2 * np.pi * (np.diag(d, 1) - np.diag(d, -1))
A = A - embed_c2r(np.ones((N // 2, N // 2)))
B = embed_c2r(np.ones((N // 2, 1)))[..., :1]
elif measure == "fourier_diag":
freqs = np.arange(N // 2)
d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1]
A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1))
# A = A - 0.5*embed_c2r(np.ones((N//2, N//2)))
A = A - 0.5 * np.eye(N)
B = embed_c2r(np.ones((N // 2, 1)))[..., :1]
elif measure == "fourier":
freqs = np.arange(N // 2)
d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
B = np.zeros(N)
B[0::2] = 2**0.5
B[0] = 1
# Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
A = A - B[:, None] * B[None, :]
B = B[:, None]
elif measure == "fourier_decay":
freqs = np.arange(N // 2)
d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
B = np.zeros(N)
B[0::2] = 2**0.5
B[0] = 1
# Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
A = A - 0.5 * B[:, None] * B[None, :]
B = 0.5 * B[:, None]
elif measure == "fourier2": # Double everything: orthonormal on [0, 1]
freqs = 2 * np.arange(N // 2)
d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
B = np.zeros(N)
B[0::2] = 2**0.5
B[0] = 1
# Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
A = A - B[:, None] * B[None, :] * 2
B = B[:, None] * 2
elif measure == "random":
A = np.random.randn(N, N) / N
B = np.random.randn(N, 1)
elif measure == "diagonal":
A = -np.diag(np.exp(np.random.randn(N)))
B = np.random.randn(N, 1)
else:
raise NotImplementedError
return A, B
def rank_correction(measure, N, rank=1, dtype=torch.float):
"""Return low-rank matrix L such that A + L is normal"""
if measure == "legs":
assert rank >= 1
P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N)
elif measure == "legt":
assert rank >= 2
P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype)) # (N)
P0 = P.clone()
P0[0::2] = 0.0
P1 = P.clone()
P1[1::2] = 0.0
P = torch.stack([P0, P1], dim=0) # (2 N)
P *= 2 ** (
-0.5
) # Halve the rank correct just like the original matrix was halved
elif measure == "lagt":
assert rank >= 1
P = 0.5**0.5 * torch.ones(1, N, dtype=dtype)
elif measure == "fourier_old":
P = torch.ones(N, dtype=dtype) # (N)
P0 = P.clone()
P0[0::2] = 0.0
P1 = P.clone()
P1[1::2] = 0.0
P = torch.stack([P0, P1], dim=0) # (2 N)
P = torch.zeros(1, N, dtype=dtype)
elif measure == "fourier":
P = torch.zeros(N)
P[0::2] = 2**0.5
P[0] = 1
P = P.unsqueeze(0)
elif measure == "fourier_decay":
P = torch.zeros(N)
P[0::2] = 2**0.5
P[0] = 1
P = P.unsqueeze(0)
P = P / 2**0.5
elif measure == "fourier2":
P = torch.zeros(N)
P[0::2] = 2**0.5
P[0] = 1
P = 2**0.5 * P.unsqueeze(0)
elif measure in ["fourier_diag", "legsd"]:
P = torch.zeros(1, N, dtype=dtype)
else:
raise NotImplementedError
d = P.size(0)
if rank > d:
P = torch.cat([P, torch.zeros(rank - d, N, dtype=dtype)], dim=0) # (rank N)
return P
def initial_C(measure, N, dtype=torch.float):
"""Return C that captures the other endpoint in the HiPPO approximation"""
if measure == "legt":
C = (torch.arange(N, dtype=dtype) * 2 + 1) ** 0.5 * (-1) ** torch.arange(N)
elif measure == "fourier_old":
C = torch.ones(N, dtype=dtype) # (N)
elif measure == "fourier":
C = torch.zeros(N)
C[0::2] = 2**0.5
C[0] = 1
else:
C = torch.zeros(N, dtype=dtype) # (N)
return C
def nplr(measure, N, rank=1, dtype=torch.float):
"""Return w, p, q, V, B such that
(w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V
i.e. A = V[w - p q^*]V^*, B = V B
"""
assert dtype == torch.float or torch.cfloat
A, B = transition(measure, N)
A = torch.as_tensor(A, dtype=dtype) # (N, N)
B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,)
P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N)
AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3)
w, V = torch.linalg.eig(AP) # (..., N) (..., N, N)
# V w V^{-1} = A
# print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))
# We require AP to be nearly skew-symmetric
_A = AP + AP.transpose(-1, -2)
if (
err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N
) > 1e-5: # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):
print("WARNING: HiPPO matrix not skew symmetric", err)
# Only keep half of each conjugate pair
# w = w[..., 0::2].contiguous()
# V = V[..., 0::2].contiguous()
_, idx = torch.sort(w.imag)
w_sorted = w[idx]
V_sorted = V[:, idx]
# There is an edge case when eigenvalues can be 0, which requires some machinery to handle
# We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)
V = V_sorted[:, : N // 2]
w = w_sorted[: N // 2]
assert w[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A"
if w[-1].abs() < 1e-4:
V[:, -1] = 0.0
V[0, -1] = 2**-0.5
V[1, -1] = 2**-0.5 * 1j
_AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)
# assert torch.allclose(2*_AP.real, AP, atol=1e-5)
if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5:
print(
"Warning: Diagonalization of A matrix not numerically precise - error", err
)
# print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))
# # Override eigenvectors for 0 eigenvalues, to make them conjugate pairs
# breakpoint()
# rotate = torch.tensor([[1, 1], [1j, -1j]]) / 2**.5
# # rotate = torch.tensor([[1, -1j], [1, 1j]]) / 2**.5
# V_rot = (V.view(N, N//2, 2) @ rotate).view(N, N) # rotate every pair of eigenvectors
# V = torch.where(w.repeat(N, 1) == 0, V_rot, V)
V_inv = V.conj().transpose(-1, -2)
C = initial_C(measure, N, dtype=dtype)
B = contract("ij, j -> i", V_inv, B.to(V)) # V^* B
C = contract("ij, j -> i", V_inv, C.to(V)) # V^* C
P = contract("ij, ...j -> ...i", V_inv, P.to(V)) # V^* P
return w, P, B, C, V
def random_dplr(
N,
rank=1,
H=1,
dtype=torch.float,
real_scale=1.0,
imag_scale=1.0,
scaling="inverse",
random_real=False,
random_imag=False,
normalize=True,
):
assert dtype == torch.float or torch.double
# batch_shape = (H, N//2) if H is not None else (N//2,)
dtype = torch.cfloat if dtype == torch.float else torch.cdouble
# w = -torch.exp(torch.randn(N//2)) + 1j*torch.randn(N//2)
# w = -torch.exp(torch.randn(N//2)) + 1j*2*torch.tensor(np.pi)*N*torch.rand(N//2) # try larger eigenvalue spread
pi = torch.tensor(np.pi)
if random_real:
real_part = torch.rand(H, N // 2)
else:
real_part = 0.5 * torch.ones(H, N // 2)
if random_imag:
imag_part = N // 2 * torch.rand(H, N // 2)
else:
imag_part = repeat(torch.arange(N // 2), "n -> h n", h=H)
real_part = real_scale * real_part
if scaling == "random":
imag_part = torch.randn(H, N // 2)
elif scaling == "linear":
imag_part = pi * imag_part
elif scaling == "inverse": # Based on asymptotics of the default HiPPO matrix
# intercept = torch.log(N//2)/torch.log(2) * 2./3.
# log_imag_part = intercept + 2. * torch.atanh((1+imag_part*2)/N*2-1)
# imag_part = torch.exp(log_imag_part)
# intercept = torch.log(N//2) - .5
# imag_part = torch.exp(2. * torch.atanh((1+imag_part*2)/N*2-1))
imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1)
elif scaling == "inverse2": # Based on asymptotics of the default HiPPO matrix
# intercept = torch.log(N//2)/torch.log(2) * 2./3.
# log_imag_part = intercept + 2. * torch.atanh((1+imag_part*2)/N*2-1)
# imag_part = torch.exp(log_imag_part)
# intercept = torch.log(N//2) - .5
# imag_part = torch.exp(2. * torch.atanh((1+imag_part*2)/N*2-1))
imag_part = 1 / pi * N * (N / (1 + imag_part) - 1)
elif scaling == "quadratic":
imag_part = 1 / pi * (1 + 2 * imag_part) ** 2
else:
raise NotImplementedError
imag_part = imag_scale * imag_part
w = -real_part + 1j * imag_part
# w = -torch.rand(N//2) + 1j*2*torch.tensor(np.pi)*N*torch.rand(N//2) # try larger eigenvalue spread
# w = -1 + torch.arange(N//2) * 1j * 2 * torch.tensor(np.pi)
P = torch.randn(rank, H, N // 2, dtype=dtype)
# p = torch.zeros(rank, N//2, dtype=dtype)
B = torch.randn(H, N // 2, dtype=dtype)
# B = torch.ones(N//2, dtype=dtype)
C = torch.randn(H, N // 2, dtype=dtype)
V = torch.eye(N, dtype=dtype)[..., : N // 2] # Only used in testing
if normalize: # TODO can normalize the full matrix with rank correction too
norm = (
-B / w
) # (H, N) # Result if you integrate the kernel with constant 1 function
zeta = 2 * torch.sum(
torch.abs(norm) ** 2, dim=-1, keepdim=True
) # Variance with a random C vector
B = B / zeta**0.5
return w, P, B, C, V
def test_nplr():
N = 4
measure = "fourier_decay"
w, P, B, C, V = nplr(measure, N, rank=1)
w = torch.cat([w, w.conj()], dim=-1)
V = torch.cat([V, V.conj()], dim=-1)
B = torch.cat([B, B.conj()], dim=-1)
P = torch.cat([P, P.conj()], dim=-1)
Q = P
# q = torch.cat([q, q.conj()], dim=-1)
A = torch.diag_embed(w) - contract("... r p, ... r q -> ... p q", P, Q.conj())
A = contract(
"ij, jk, kl -> ... il", V, A, V.conj().transpose(-1, -2)
) # Ap^{-1} = V @ w^{-1} @ V^T
B = contract("ij, ... j -> ... i", V, B)
print(A.real)
print(B.real)
+1143
View File
File diff suppressed because it is too large Load Diff
+214
View File
@@ -0,0 +1,214 @@
""" Compute a Krylov function efficiently. (S3 renames the Krylov function to a "state space kernel")
A : (N, N)
b : (N,)
c : (N,)
Return: [c^T A^i b for i in [L]]
"""
import torch
import torch.nn.functional as F
from einops import rearrange
from toeplitz import causal_convolution
def krylov_sequential(L, A, b, c=None):
"""Constant matrix A
A : (..., N, N)
b : (..., N)
c : (..., N)
Returns
if c:
x : (..., L)
x[i, l] = c[i] @ A^l @ b[i]
else:
x : (..., N, L)
x[i, l] = A^l @ b[i]
"""
# Check which of dim b and c is smaller to save memory
if c is not None and c.numel() < b.numel():
return krylov_sequential(L, A.transpose(-1, -2), c, b)
b_ = b
x = []
for _ in range(L):
if c is not None:
x_ = torch.sum(
c * b_, dim=-1
) # (...) # could be faster with matmul or einsum?
else:
x_ = b_
x.append(x_)
b_ = (A @ b_.unsqueeze(-1)).squeeze(-1)
x = torch.stack(x, dim=-1)
return x
def krylov(L, A, b, c=None, return_power=False):
"""
Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick.
If return_power=True, return A^{L-1} as well
"""
# TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises
x = b.unsqueeze(-1) # (..., N, 1)
A_ = A
AL = None
if return_power:
AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
_L = L - 1
done = L == 1
# loop invariant: _L represents how many indices left to compute
while not done:
if return_power:
if _L % 2 == 1:
AL = A_ @ AL
_L //= 2
# Save memory on last iteration
l = x.shape[-1]
if L - l <= l:
done = True
_x = x[..., : L - l]
else:
_x = x
_x = A_ @ _x
x = torch.cat(
[x, _x], dim=-1
) # there might be a more efficient way of ordering axes
if not done:
A_ = A_ @ A_
assert x.shape[-1] == L
if c is not None:
x = torch.einsum("...nl, ...n -> ...l", x, c)
x = x.contiguous() # WOW!!
if return_power:
return x, AL
else:
return x
def power(L, A, v=None):
"""Compute A^L and the scan sum_i A^i v_i
A: (..., N, N)
v: (..., N, L)
"""
I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device)
powers = [A]
l = 1
while True:
if L % 2 == 1:
I = powers[-1] @ I
L //= 2
if L == 0:
break
l *= 2
powers.append(powers[-1] @ powers[-1])
if v is None:
return I
# Invariants:
# powers[-1] := A^l
# l := largest po2 at most L
# Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A
# We do this reverse divide-and-conquer for efficiency reasons:
# 1) it involves fewer padding steps for non-po2 L
# 2) it involves more contiguous arrays
# Take care of edge case for non-po2 arrays
# Note that this initial step is a no-op for the case of power of 2 (l == L)
k = v.size(-1) - l
v_ = powers.pop() @ v[..., l:]
v = v[..., :l]
v[..., :k] = v[..., :k] + v_
# Handle reduction for power of 2
while v.size(-1) > 1:
v = rearrange(v, "... (z l) -> ... z l", z=2)
v = v[..., 0, :] + powers.pop() @ v[..., 1, :]
return I, v.squeeze(-1)
def krylov_toeplitz(L, A, b, c=None):
"""Specializes to lower triangular Toeplitz matrix A represented by its diagonals
A : (..., N)
b : (..., N)
c : (..., N)
Returns
x : (..., N, L)
x[i, l] = A^l @ b[i]
"""
x = b.unsqueeze(0) # (1, ..., N)
A_ = A
while x.shape[0] < L:
xx = causal_convolution(A_, x)
x = torch.cat(
[x, xx], dim=0
) # there might be a more efficient way of ordering axes
A_ = causal_convolution(A_, A_)
x = x[:L, ...] # (L, ..., N)
if c is not None:
x = torch.einsum("l...n, ...n -> ...l", x, c)
else:
x = rearrange(x, "l ... n -> ... n l")
x = x.contiguous()
return x
def krylov_toeplitz_(L, A, b, c=None):
"""Padded version of krylov_toeplitz that saves some fft's
TODO currently not faster than original version, not sure why
"""
N = A.shape[-1]
x = b.unsqueeze(0) # (1, ..., N)
x = F.pad(x, (0, N))
A = F.pad(A, (0, N))
done = L == 1
while not done:
l = x.shape[0]
# Save memory on last iteration
if L - l <= l:
done = True
_x = x[: L - l]
else:
_x = x
Af = torch.fft.rfft(A, n=2 * N, dim=-1)
xf = torch.fft.rfft(_x, n=2 * N, dim=-1)
xf_ = Af * xf
x_ = torch.fft.irfft(xf_, n=2 * N, dim=-1)
x_[..., N:] = 0
x = torch.cat(
[x, x_], dim=0
) # there might be a more efficient way of ordering axes
if not done:
A = torch.fft.irfft(Af * Af, n=2 * N, dim=-1)
A[..., N:] = 0
x = x[:L, ..., :N] # (L, ..., N)
if c is not None:
x = torch.einsum("l...n, ...n -> ...l", x, c)
else:
x = rearrange(x, "l ... n -> ... n l")
x = x.contiguous()
return x
+25
View File
@@ -0,0 +1,25 @@
import logging
from pytorch_lightning.utilities import rank_zero_only
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
"""Initializes multi-GPU-friendly python logger."""
logger = logging.getLogger(name)
logger.setLevel(level)
# this ensures all logging levels get marked with the rank zero decorator
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
for level in (
"debug",
"info",
"warning",
"error",
"exception",
"fatal",
"critical",
):
setattr(logger, level, rank_zero_only(getattr(logger, level)))
return logger
+237
View File
@@ -0,0 +1,237 @@
""" pykeops implementations of the core Cauchy kernel used in the S3 algorithm.
"""
import torch
from einops import rearrange
from pykeops.torch import Genred, LazyTensor
_conj = lambda x: torch.cat([x, x.conj()], dim=-1)
def _broadcast_dims(*tensors):
max_dim = max([len(tensor.shape) for tensor in tensors])
tensors = [
tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)
for tensor in tensors
]
return tensors
def _c2r(x):
return torch.view_as_real(x)
def _r2c(x):
return torch.view_as_complex(x)
def cauchy_slow(v, z, w, conj=True):
"""
v: (..., N)
z: (..., L)
w: (..., N)
returns: (..., L) \sum v/(z-w)
"""
if conj:
v = _conj(v)
w = _conj(w)
cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1)) # (... N L)
return torch.sum(cauchy_matrix, dim=-2)
def cauchy_lazy(v, z, w, conj=True):
if conj:
v = _conj(v)
w = _conj(w)
v, z, w = _broadcast_dims(v, z, w)
v_l = LazyTensor(rearrange(v, "... N -> ... N 1 1"))
w_l = LazyTensor(rearrange(w, "... N -> ... N 1 1"))
z_l = LazyTensor(rearrange(z, "... L -> ... 1 L 1"))
sub = z_l - w_l # (b N L 1), for some reason it doesn't display the last dimension
div = v_l / sub
s = div.sum(dim=len(v_l.shape) - 2)
return s.squeeze(-1)
def cauchy(v, z, w, conj=False):
expr = "ComplexDivide(v, z-w)"
cauchy_mult = Genred(
expr,
[
"v = Vj(2)",
"z = Vi(2)",
"w = Vj(2)",
],
reduction_op="Sum",
axis=1,
)
if conj:
v = _conj(v)
w = _conj(w)
v, z, w = _broadcast_dims(v, z, w)
v = _c2r(v)
z = _c2r(z)
w = _c2r(w)
r = cauchy_mult(v, z, w, backend="GPU")
return _r2c(r)
def cauchy_real(v, z, w):
expr = "v / (z - w)"
cauchy_mult = Genred(
expr,
[
"v = Vj(1)",
"z = Vi(1)",
"w = Vj(1)",
],
reduction_op="Sum",
axis=1,
)
v, z, w = _broadcast_dims(v, z, w)
v = v.unsqueeze(-1)
z = z.unsqueeze(-1)
w = w.unsqueeze(-1)
r = cauchy_mult(v, z, w, backend="GPU")
return r
def cauchy_conj(v, z, w, num=2, denom=2):
if num == 1:
expr_num = "z * ComplexReal(v) - Real2Complex(ComplexReal(v)*ComplexReal(w) + ComplexImag(v)*ComplexImag(w))"
elif num == 2:
expr_num = "z * ComplexReal(v) - Real2Complex(Sum(v * w))"
else:
raise NotImplementedError
if denom == 1:
expr_denom = "ComplexMult(z-Real2Complex(ComplexReal(w)), z-Real2Complex(ComplexReal(w))) + Real2Complex(Square(ComplexImag(w)))"
elif denom == 2:
expr_denom = "ComplexMult(z-w, z-Conj(w))"
else:
raise NotImplementedError
cauchy_mult = Genred(
f"ComplexDivide({expr_num}, {expr_denom})",
[
"v = Vj(2)",
"z = Vi(2)",
"w = Vj(2)",
],
reduction_op="Sum",
axis=1,
)
v, z, w = _broadcast_dims(v, z, w)
v = _c2r(v)
z = _c2r(z)
w = _c2r(w)
r = 2 * cauchy_mult(v, z, w, backend="GPU")
return _r2c(r)
def cauchy_conj_components(v, z, w):
"""Assumes z is pure imaginary (as in S4 with bilinear)"""
expr_num = "Imag2Complex(zi*vr) - Real2Complex(vr*wr + vi*wi)"
expr_denom = (
"Real2Complex(Square(wr)+Square(wi)-Square(zi)) - Imag2Complex(IntCst(2)*zi*wr)"
)
cauchy_mult = Genred(
f"ComplexDivide({expr_num}, {expr_denom})",
[
"vr = Vj(1)",
"vi = Vj(1)",
"wr = Vj(1)",
"wi = Vj(1)",
"zi = Vi(1)",
],
reduction_op="Sum",
axis=1,
)
v, z, w = _broadcast_dims(v, z, w)
v = v.unsqueeze(-1)
z = z.unsqueeze(-1)
w = w.unsqueeze(-1)
v_r, v_i = v.real.contiguous(), v.imag.contiguous()
w_r, w_i = w.real.contiguous(), w.imag.contiguous()
z_i = z.imag.contiguous()
r = 2 * cauchy_mult(v_r, v_i, w_r, w_i, z_i, backend="GPU")
return _r2c(r)
def cauchy_conj_components_lazy(v, z, w, type=1):
v, z, w = _broadcast_dims(v, z, w)
v_r, v_i = v.real.contiguous(), v.imag.contiguous()
w_r, w_i = w.real.contiguous(), w.imag.contiguous()
z_i = z.imag.contiguous()
v_r = LazyTensor(rearrange(v_r, "... N -> ... 1 N 1"))
v_i = LazyTensor(rearrange(v_i, "... N -> ... 1 N 1"))
w_r = LazyTensor(rearrange(w_r, "... N -> ... 1 N 1"))
w_i = LazyTensor(rearrange(w_i, "... N -> ... 1 N 1"))
z_i = LazyTensor(rearrange(z_i, "... L -> ... L 1 1"))
if type == 1:
num = -v_r * w_r - v_i * w_i + 1j * z_i * v_r
denom = w_r**2 + w_i**2 - z_i**2 - 2j * w_r * z_i
else:
# z = torch.complex(-w_r, z_i) # Not supported
z = -w_r + 1j * z_i
num = v_r * z - v_i * w_i
denom = z * z + w_i**2 # z**2 is bugged for complex
r = num / denom
r = 2 * r.sum(dim=len(z_i.shape) - 1)
return r.squeeze(-1)
def cauchy_conj2(v, z, w):
expr = "ComplexDivide(v, z-w) + ComplexDivide(Conj(v), z-Conj(w))"
# expr = 'ComplexDivide(v, z-w)'
cauchy_mult = Genred(
expr,
[
"v = Vj(2)",
"z = Vi(2)",
"w = Vj(2)",
],
reduction_op="Sum",
axis=1,
)
v, z, w = _broadcast_dims(v, z, w)
if complex:
v = _c2r(v)
z = _c2r(z)
w = _c2r(w)
r = cauchy_mult(v, z, w, backend="GPU")
return _r2c(r)
def trigger_compilation():
"""Small function to trigger the compilation of a pykeops kernel
Used in scenarios where we must manually control compilation, e.g. the multi-gpu case (https://github.com/getkeops/keops/issues/168)"""
B = 2
N = 4
L = 16
w = torch.randn(B, N // 2, dtype=torch.cfloat, device="cuda")
v = torch.randn(B, N // 2, dtype=torch.cfloat, device="cuda")
z = torch.randn(B, L, dtype=torch.cfloat, device="cuda")
w.requires_grad = True
v.requires_grad = True
cauchy_conj(v, z, w)
+1438
View File
File diff suppressed because one or more lines are too long
+218
View File
@@ -0,0 +1,218 @@
import logging
import opt_einsum as oe
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from components import Activation, LinearActivation, Normalization
from kernel import HippoSSKernel
optimized = True
if optimized:
contract = oe.contract
else:
contract = torch.einsum
class S4(nn.Module):
requires_length = True
def __init__(
self,
d_model,
d_state=64,
l_max=1, # Maximum length of sequence. Fine if not provided: the kernel will keep doubling in length until longer than sequence. However, this can be marginally slower if the true length is not a power of 2
channels=1, # maps 1-dim to C-dim
bidirectional=False,
# Arguments for FF
activation="gelu", # activation in between SS and FF
ln=False, # Extra normalization
postact=None, # activation after FF
initializer=None, # initializer on FF
weight_norm=False, # weight normalization on FF
hyper_act=None, # Use a "hypernetwork" multiplication
dropout=0.0,
transposed=True, # axis ordering (B, L, D) or (B, D, L)
verbose=False,
shift=False,
linear=False,
# SSM Kernel arguments
**kernel_args,
):
"""
d_state: the dimension of the state, also denoted by N
l_max: the maximum sequence length, also denoted by L
if this is not known at model creation, set l_max=1
channels: can be interpreted as a number of "heads"
bidirectional: bidirectional
dropout: standard dropout argument
transposed: choose backbone axis ordering of (B, L, H) or (B, H, L) [B=batch size, L=sequence length, H=hidden dimension]
Other options are all experimental and should not need to be configured
"""
super().__init__()
if verbose:
from logger import get_logger
log = get_logger(__name__)
log.info(f"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})")
self.h = d_model
self.n = d_state
self.bidirectional = bidirectional
self.ln = ln
self.channels = channels
self.transposed = transposed
self.shift = shift
self.linear = linear
# optional multiplicative modulation GLU-style
# https://arxiv.org/abs/2002.05202
self.hyper = hyper_act is not None
if self.hyper:
channels *= 2
self.hyper_activation = Activation(hyper_act)
self.D = nn.Parameter(torch.randn(channels, self.h))
if self.bidirectional:
channels *= 2
# SSM Kernel
self.kernel = HippoSSKernel(
self.h, N=self.n, L=l_max, channels=channels, verbose=verbose, **kernel_args
)
# Pointwise
if not self.linear:
self.activation = Activation(activation)
dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout
self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()
if self.ln:
self.norm = Normalization(self.h * self.channels, transposed=transposed)
else:
self.norm = nn.Identity()
# position-wise output transform to mix features
if not self.linear:
self.output_linear = LinearActivation(
self.h * self.channels,
self.h,
transposed=self.transposed,
initializer=initializer,
activation=postact,
activate=True,
weight_norm=weight_norm,
)
def forward(
self, u, state=None, **kwargs
): # absorbs return_output and transformer src mask
"""
u: (B H L) if self.transposed else (B L H)
state: (H N) never needed unless you know what you're doing
Returns: same shape as u
"""
if not self.transposed:
u = u.transpose(-1, -2)
L = u.size(-1)
# Compute SS Kernel
k, k_state = self.kernel(L=L, state=state) # (C H L) (B C H L)
# Convolution
if self.bidirectional:
k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2)
k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0))
if self.shift:
# Try flip and pad to correct for potential off-by-one
k_f = torch.fft.rfft(F.pad(k.flip(-1), (L, 0)), n=2 * L) # (C H L)
u_f = torch.fft.rfft(F.pad(u.flip(-1), (L, 0)), n=2 * L) # (B H L)
y_f = contract(
"bhl,chl->bchl", u_f, k_f
) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L)
y = torch.fft.irfft(y_f, n=2 * L)[..., L:].flip(-1) # (B C H L)
else:
k_f = torch.fft.rfft(k, n=2 * L) # (C H L)
u_f = torch.fft.rfft(u, n=2 * L) # (B H L)
y_f = contract(
"bhl,chl->bchl", u_f, k_f
) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L)
y = torch.fft.irfft(y_f, n=2 * L)[..., :L] # (B C H L)
# Compute D term in state space equation - essentially a skip connection
y = y + contract(
"bhl,ch->bchl", u, self.D
) # u.unsqueeze(-3) * self.D.unsqueeze(-1)
# Compute state update
if state is not None:
assert (
not self.bidirectional
), "Bidirectional not supported with state forwarding"
y = y + k_state
next_state = self.kernel.forward_state(u, state)
else:
next_state = None
# Optional hyper-network multiplication
if self.hyper:
y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2)
y = self.hyper_activation(yh) * y
# Reshape to flatten channels
y = rearrange(y, "... c h l -> ... (c h) l")
if not self.linear:
y = self.dropout(self.activation(y))
if not self.transposed:
y = y.transpose(-1, -2)
if not self.linear:
y = self.norm(y)
y = self.output_linear(y)
return y, next_state
def setup_step(self):
self.kernel.setup_step()
def step(self, u, state):
"""Step one time step as a recurrent model. Intended to be used during validation.
u: (B H)
state: (B H N)
Returns: output (B H), state (B H N)
"""
assert not self.training
y, next_state = self.kernel.step(u, state) # (B C H)
y = y + u.unsqueeze(-2) * self.D
y = rearrange(y, "... c h -> ... (c h)")
y = self.activation(y)
if self.transposed:
y = self.output_linear(y.unsqueeze(-1)).squeeze(-1)
else:
y = self.output_linear(y)
return y, next_state
def default_state(self, *batch_shape, device=None):
return self.kernel.default_state(*batch_shape)
@property
def d_state(self):
return self.h * self.n
@property
def d_output(self):
return self.h
@property
def state_to_tensor(self):
return lambda state: rearrange("... h n -> ... (h n)", state)
+300
View File
@@ -0,0 +1,300 @@
""" Utilities for computing convolutions.
There are 3 equivalent views:
1. causal convolution
2. multiplication of (lower) triangular Toeplitz matrices
3. polynomial multiplication (mod x^N)
"""
import torch
# import torch.nn as nn
import torch.nn.functional as F
# from model.complex import complex_mul
# from pytorch_memlab import profile
def construct_toeplitz(v, f=0.0):
"""Explicit construction of Krylov matrix [v A @ v A^2 @ v ... A^{n-1} @ v]
where A = Z_f. This uses vectorized indexing and cumprod so it's much
faster than using the Krylov function.
Parameters:
v: the starting vector of size n or (rank, n).
f: real number
Returns:
K: Krylov matrix of size (n, n) or (rank, n, n).
"""
n = v.shape[-1]
a = torch.arange(n, device=v.device)
b = -a
indices = a[:, None] + b[None]
K = v[..., indices]
K[..., indices < 0] *= f
return K
def triangular_toeplitz_multiply_(u, v, sum=None):
n = u.shape[-1]
u_expand = F.pad(u, (0, n))
v_expand = F.pad(v, (0, n))
u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1)
v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1)
uv_f = u_f * v_f
if sum is not None:
uv_f = uv_f.sum(dim=sum)
output = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n]
return output
def triangular_toeplitz_multiply_padded_(u, v):
"""Same as triangular_toeplitz_multiply but inputs and output assume to be 0-padded already."""
n = u.shape[-1]
assert n % 2 == 0
u_f = torch.fft.rfft(u, n=n, dim=-1)
v_f = torch.fft.rfft(v, n=n, dim=-1)
uv_f = u_f * v_f
output = torch.fft.irfft(uv_f, n=n, dim=-1)
output[..., n:] = 0
return output
class TriangularToeplitzMult(torch.autograd.Function):
@staticmethod
def forward(ctx, u, v):
ctx.save_for_backward(u, v)
return triangular_toeplitz_multiply_(u, v)
@staticmethod
def backward(ctx, grad):
u, v = ctx.saved_tensors
d_u = triangular_toeplitz_multiply_(grad.flip(-1), v).flip(-1)
d_v = triangular_toeplitz_multiply_(grad.flip(-1), u).flip(-1)
return d_u, d_v
class TriangularToeplitzMultFast(torch.autograd.Function):
@staticmethod
def forward(ctx, u, v):
n = u.shape[-1]
u_expand = F.pad(u, (0, n))
v_expand = F.pad(v, (0, n))
u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1)
v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1)
ctx.save_for_backward(u_f, v_f)
uv_f = u_f * v_f
output = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n]
return output
@staticmethod
def backward(ctx, grad):
u_f, v_f = ctx.saved_tensors
n = grad.shape[-1]
g_expand = F.pad(grad.flip(-1), (0, n))
g_f = torch.fft.rfft(g_expand, n=2 * n, dim=-1)
gu_f = g_f * u_f
gv_f = g_f * v_f
d_u = torch.fft.irfft(gv_f, n=2 * n, dim=-1)[..., :n]
d_v = torch.fft.irfft(gu_f, n=2 * n, dim=-1)[..., :n]
d_u = d_u.flip(-1)
d_v = d_v.flip(-1)
return d_u, d_v
class TriangularToeplitzMultPadded(torch.autograd.Function):
@staticmethod
def forward(ctx, u, v):
ctx.save_for_backward(u, v)
output = triangular_toeplitz_multiply_(u, v)
return output
@staticmethod
def backward(ctx, grad):
u, v = ctx.saved_tensors
d_u = triangular_toeplitz_multiply_padded_(grad.flip(-1), v).flip(-1)
d_v = triangular_toeplitz_multiply_padded_(grad.flip(-1), u).flip(-1)
return d_u, d_v
class TriangularToeplitzMultPaddedFast(torch.autograd.Function):
"""Trade off speed (20-25% faster) for more memory (20-25%)"""
@staticmethod
def forward(ctx, u, v):
n = u.shape[-1]
u_f = torch.fft.rfft(u, n=n, dim=-1)
v_f = torch.fft.rfft(v, n=n, dim=-1)
ctx.save_for_backward(u_f, v_f)
uv_f = u_f * v_f
output = torch.fft.irfft(uv_f, n=n, dim=-1)
output[..., n // 2 :].zero_()
return output
@staticmethod
def backward(ctx, grad):
u_f, v_f = ctx.saved_tensors
n = grad.shape[-1]
g_expand = F.pad(grad[..., : n // 2].flip(-1), (0, n // 2))
g_f = torch.fft.rfft(g_expand, n=n, dim=-1)
gu_f = g_f * u_f
gv_f = g_f * v_f
d_u = torch.fft.irfft(gv_f, n=n, dim=-1)
d_v = torch.fft.irfft(gu_f, n=n, dim=-1)
d_u[..., n // 2 :].zero_()
d_v[..., n // 2 :].zero_()
d_u[..., : n // 2] = d_u[..., : n // 2].flip(-1) # TODO
d_v[..., : n // 2] = d_v[..., : n // 2].flip(-1) # TODO
return d_u, d_v
# triangular_toeplitz_multiply = triangular_toeplitz_multiply_
triangular_toeplitz_multiply = TriangularToeplitzMult.apply
triangular_toeplitz_multiply_fast = TriangularToeplitzMultFast.apply
triangular_toeplitz_multiply_padded = TriangularToeplitzMultPadded.apply
triangular_toeplitz_multiply_padded_fast = TriangularToeplitzMultPaddedFast.apply
def causal_convolution(u, v, fast=True, pad=False):
if not pad and not fast:
return triangular_toeplitz_multiply(u, v)
if not pad and fast:
return triangular_toeplitz_multiply_fast(u, v)
if pad and not fast:
return triangular_toeplitz_multiply_padded(u, v)
if pad and fast:
return triangular_toeplitz_multiply_padded_fast(u, v)
def _fft(x, N):
return torch.fft.rfft(F.pad(x, (0, 2 * N - x.shape[-1])), n=2 * N, dim=-1)
def _ifft(x, N):
return torch.fft.irfft(x, n=2 * N, dim=-1)[..., :N]
def causal_convolution_inverse(u):
"""Invert the causal convolution/polynomial/triangular Toeplitz matrix represented by u.
This is easiest in the polynomial view:
https://www.csa.iisc.ac.in/~chandan/courses/CNT/notes/lec5.pdf
The idea is that
h = g^{-1} (mod x^m) => 2h - gh^2 = g^{-1} (mod x^{2m})
# TODO this can be numerically unstable if input is "poorly conditioned",
# for example if u[0] is magnitudes different from the rest of u
"""
N = u.shape[-1]
v = u[..., :1].reciprocal()
while v.shape[-1] < N:
M = v.shape[-1]
v_f = _fft(v, 2 * M)
u_f = _fft(u[..., : 2 * M], 2 * M)
_v = -_ifft(u_f * v_f**2, 2 * M)
_v[..., :M] = _v[..., :M] + 2 * v
v = _v
# TODO contiguous?
v = v[..., :N]
return v
""" Below are experimental functions for improving the stability of LSSL/S3 algorithm. Currently not used anywhere. """
def causal_convolution_inverse_wrong(u, v):
"""Solve u * x = v. Initial attempt by inverting the multiplication algorithm, which I think doesn't work."""
n = u.shape[-1]
u_expand = F.pad(u, (0, n))
v_expand = F.pad(v, (0, n))
u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1)
v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1)
uv_f = v_f / u_f
x = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n]
return x
def construct_toeplitz_log(v):
n = v.shape[-1]
a = torch.arange(n, device=v.device)
b = -a
indices = a[:, None] + b[None]
K = v[..., indices]
K[..., indices < 0] = -100.0
return K
def _logsumexp(x, dim=-1):
"""logsumexp for complex"""
m = torch.max(torch.real(x), dim=dim, keepdim=True)[0]
x = x - m
x = torch.log(torch.sum(torch.exp(x), dim=dim))
x = x + m.squeeze(dim)
return x
def causal_convolution_inverse_log(u, N=-1):
"""Invert the causal convolution/polynomial/triangular Toeplitz matrix represented by u.
This is easiest in the polynomial view:
https://www.csa.iisc.ac.in/~chandan/courses/CNT/notes/lec5.pdf
The idea is that
h = g^{-1} (mod x^m) => 2h - gh^2 = g^{-1} (mod x^{2m})
# TODO this can be numerically unstable if input is "poorly conditioned",
# for example if u[0] is magnitudes different from the rest of u
"""
if N < 0:
N = u.shape[-1]
v = -u[..., :1]
while v.shape[-1] < N:
M = v.shape[-1]
_v = F.pad(v, (0, M), value=-100.0)
_v_ = construct_toeplitz_log(_v)
u_ = (
u[..., : 2 * M]
if u.shape[-1] >= 2 * M
else F.pad(u, (0, 2 * M - u.shape[-1]), value=-100.0)
)
_u = _logsumexp(_v_ + u_, dim=-1)
_u = _logsumexp(_v_ + _u, dim=-1)
_u = _u + torch.log(-torch.ones_like(_u))
_v = _v + torch.log(2.0 * torch.ones_like(_u))
v = _logsumexp(torch.stack([_v, _u], dim=-1), dim=-1)
# TODO contiguous?
v = v[..., :N]
check = _logsumexp(
construct_toeplitz_log(v) + F.pad(u, (0, N - u.shape[-1]), value=-100.0)
)
print("check", check, torch.exp(check))
return v
if __name__ == "__main__":
a = torch.tensor([1.0, 2, 3, 4], requires_grad=True)
b = torch.tensor([5.0, 6, 7, 8], requires_grad=True)
a.retain_grad()
b.retain_grad()
x = triangular_toeplitz_multiply_padded(F.pad(a, (0, 4)), F.pad(b, (0, 4)))[:4]
print(x) # [5 16 34 60]
x = x.sum()
x.backward()
print(x, a.grad, b.grad) # [26 18 11 5] [10 6 3 1]
if __name__ == "__main__":
N = 4
a = torch.randn(N)
construct_toeplitz(a)
print(a)
b = causal_convolution_inverse(a)
print("inverse", b)
print("check", causal_convolution(a, b))
i = torch.zeros(N)
i[0] = 1.0
b = causal_convolution_inverse_wrong(a, i)
print(b)
print(causal_convolution(a, b))