mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
added s4
This commit is contained in:
+120
@@ -0,0 +1,120 @@
|
||||
import torch
|
||||
from cauchy_mult import (
|
||||
cauchy_mult_bwd,
|
||||
cauchy_mult_fwd,
|
||||
cauchy_mult_sym_bwd,
|
||||
cauchy_mult_sym_fwd,
|
||||
)
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def cauchy_mult_torch(
    v: torch.Tensor, z: torch.Tensor, w: torch.Tensor, symmetric=True
) -> torch.Tensor:
    """Reference PyTorch implementation of the Cauchy kernel sum_n v_n / (z - w_n).

    v: (B, N)
    z: (L)
    w: (B, N)
    symmetric: whether to assume that v and w contain complex conjugate pairs, of the form
    [v_half, v_half.conj()] and [w_half, w_half.conj()]

    Returns a (B, L) tensor (the N axis is summed out).
    """
    if symmetric:
        n_total = v.shape[-1]
        assert n_total % 2 == 0
        half = n_total // 2
        # Only the first half of each conjugate pair is read; the contribution
        # of the conjugates is folded into the closed-form expression below.
        v_half = rearrange(v[:, :half], "b n -> b 1 n")
        z_col = rearrange(z, "l -> l 1")
        w_half = rearrange(w[:, :half], "b n -> b 1 n")
        numerator = (
            z_col * v_half.real - v_half.real * w_half.real - v_half.imag * w_half.imag
        )
        denominator = z_col * z_col - 2 * z_col * w_half.real + w_half.abs().square()
        return 2 * (numerator / denominator).sum(dim=-1)

    # General case: broadcast (B, 1, N) against (L, 1) and reduce over N.
    v_row = rearrange(v, "b n -> b 1 n")
    gap = rearrange(z, "l -> l 1") - rearrange(w, "b n -> b 1 n")
    return (v_row / gap).sum(dim=-1)
|
||||
|
||||
|
||||
def cauchy_mult_keops(v, z, w):
    """Cauchy kernel sum_n v_n / (z - w_n) evaluated with pykeops LazyTensors.

    v, w are rearranged from (b, N) and z from (L,); the N axis is reduced
    on the GPU backend and the trailing singleton dimension is squeezed away.
    """
    from pykeops.torch import LazyTensor

    numer = LazyTensor(rearrange(v, "b N -> b 1 N 1"))
    nodes = LazyTensor(rearrange(z, "L -> 1 L 1 1"))
    poles = LazyTensor(rearrange(w, "b N -> b 1 N 1"))
    # (b N L 1), for some reason it doesn't display the last dimension
    ratio = numer / (nodes - poles)
    reduced = ratio.sum(dim=2, backend="GPU")
    return reduced.squeeze(-1)
|
||||
|
||||
|
||||
def _cauchy_mult(v, z, w, symmetric=True):
    """Dispatch to the autograd-wrapped kernel for the symmetric or the general case."""
    kernel = CauchyMultiplySymmetric if symmetric else CauchyMultiply
    return kernel.apply(v, z, w)
|
||||
|
||||
|
||||
def cauchy_mult(v, z, w, symmetric=True):
    """Wrap the cuda method to deal with shapes

    v and w are broadcast against each other and flattened to (-1, N) before
    the kernel call; the leading (broadcast) dimensions are restored on the
    result, whose last dimension is the length of z.
    """
    v, w = torch.broadcast_tensors(v, w)
    leading_shape = v.shape[:-1]

    # z must collapse to a single axis once singleton dimensions are dropped.
    z = z.squeeze()
    assert len(z.shape) == 1

    v, w, z = v.contiguous(), w.contiguous(), z.contiguous()

    N = v.size(-1)
    assert w.size(-1) == N
    flat = _cauchy_mult(v.view(-1, N), z, w.view(-1, N), symmetric=symmetric)
    return flat.view(*leading_shape, z.size(-1))
|
||||
|
||||
|
||||
class CauchyMultiply(torch.autograd.Function):
    """Autograd wrapper around the general (non-symmetric) Cauchy kernels
    `cauchy_mult_fwd` / `cauchy_mult_bwd`."""

    @staticmethod
    def forward(ctx, v, z, w):
        """Validate the inputs and run the forward kernel.

        v, w: 2-D tensors whose last dimension N must be a supported size.
        z: tensor whose last dimension L must be a multiple of 32.

        Raises NotImplementedError for unsupported N or L, or when any of
        v, z, w is not a CUDA tensor.
        """
        _, N = v.shape  # also enforces that v is 2-D
        # supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
        supported_N_values = [1 << log_n for log_n in [6]]
        L = z.shape[-1]
        if N not in supported_N_values:
            raise NotImplementedError(f"Only support N values in {supported_N_values}")
        if L % 32 != 0:
            raise NotImplementedError("Only support L values that are multiples of 32")
        # All three tensors must be on the GPU. The parentheses matter:
        # `not a and b and c` would only negate `a`.
        if not (v.is_cuda and z.is_cuda and w.is_cuda):
            raise NotImplementedError("Only support CUDA tensors")
        ctx.save_for_backward(v, z, w)
        return cauchy_mult_fwd(v, z, w)

    @staticmethod
    def backward(ctx, dout):
        """Return gradients for (v, z, w); z receives no gradient."""
        v, z, w = ctx.saved_tensors
        dv, dw = cauchy_mult_bwd(v, z, w, dout)
        return dv, None, dw
|
||||
|
||||
|
||||
class CauchyMultiplySymmetric(torch.autograd.Function):
    """Autograd wrapper around the conjugate-symmetric Cauchy kernels
    `cauchy_mult_sym_fwd` / `cauchy_mult_sym_bwd`."""

    @staticmethod
    def forward(ctx, v, z, w):
        """Validate the inputs and run the symmetric forward kernel.

        v, w: 2-D tensors whose last dimension N must be a power of two in [2, 1024].
        z: tensor whose last dimension L must not exceed 32 * 1024 * 64 * 1024.

        Raises NotImplementedError for unsupported N or L, or when any of
        v, z, w is not a CUDA tensor.
        """
        _, N = v.shape  # also enforces that v is 2-D
        supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
        L = z.shape[-1]
        if N not in supported_N_values:
            raise NotImplementedError(f"Only support N values in {supported_N_values}")
        max_L_value = 32 * 1024 * 64 * 1024
        if L > max_L_value:
            raise NotImplementedError(f"Only support L values <= {max_L_value}")
        # All three tensors must be on the GPU. The parentheses matter:
        # `not a and b and c` would only negate `a`.
        if not (v.is_cuda and z.is_cuda and w.is_cuda):
            raise NotImplementedError("Only support CUDA tensors")
        ctx.save_for_backward(v, z, w)
        return cauchy_mult_sym_fwd(v, z, w)

    @staticmethod
    def backward(ctx, dout):
        """Return gradients for (v, z, w); z receives no gradient."""
        v, z, w = ctx.saved_tensors
        dv, dw = cauchy_mult_sym_bwd(v, z, w, dout)
        return dv, None, dw
|
||||
@@ -0,0 +1,216 @@
|
||||
""" Utility nn components, in particular handling activations, initializations, and normalization layers """
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from opt_einsum import contract
|
||||
|
||||
|
||||
class modrelu(nn.Module):
    """modReLU activation: relu(|x| + b) * sign(x) with a learnable per-feature bias b."""

    def __init__(self, features):
        # For now we just support square layers
        super().__init__()
        self.features = features
        self.b = nn.Parameter(torch.Tensor(self.features))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the bias uniformly from [-0.01, 0.01]."""
        self.b.data.uniform_(-0.01, 0.01)

    def forward(self, inputs):
        """Shift the magnitude of `inputs` by the bias, clamp at zero, restore the sign."""
        magnitude = nn.functional.relu(torch.abs(inputs) + self.b)
        return torch.sign(inputs) * magnitude
|
||||
|
||||
|
||||
class Modrelu(modrelu):
    """Subclass of `modrelu` that overrides only `reset_parameters`."""

    def reset_parameters(self):
        # Uniform bias initialization in [-0.01, 0.01].
        self.b.data.uniform_(-0.01, 0.01)
|
||||
|
||||
|
||||
def Activation(activation=None, size=None, dim=-1):
    """Build the nn.Module for a named activation.

    activation: None / "id" / "identity" / "linear", "tanh", "relu", "gelu",
        "swish" / "silu", "glu", "sigmoid" or "modrelu".
    size: feature count, used only by "modrelu".
    dim: dimension that "glu" splits along.

    Raises NotImplementedError for any other name.
    """
    if activation in [None, "id", "identity", "linear"]:
        return nn.Identity()

    # Activations that take no constructor arguments.
    parameterless = (
        (("tanh",), nn.Tanh),
        (("relu",), nn.ReLU),
        (("gelu",), nn.GELU),
        (("swish", "silu"), nn.SiLU),
        (("sigmoid",), nn.Sigmoid),
    )
    for names, factory in parameterless:
        if activation in names:
            return factory()

    if activation == "glu":
        return nn.GLU(dim=dim)
    if activation == "modrelu":
        return Modrelu(size)
    raise NotImplementedError(
        "hidden activation '{}' is not implemented".format(activation)
    )
|
||||
|
||||
|
||||
def get_initializer(name, activation=None):
    """Return an in-place weight initializer selected by `name`.

    name: "uniform" / "normal" (Kaiming, with the gain of `activation`),
        "xavier", "zero" or "one".
    activation: activation the layer feeds into; only affects the Kaiming gain.

    Raises NotImplementedError for an unsupported activation or name. The
    activation is validated first, whatever the initializer name is.
    """
    if activation in (None, "id", "identity", "linear", "modrelu"):
        gain_mode = "linear"
    elif activation in ("relu", "tanh", "sigmoid"):
        gain_mode = activation
    elif activation in ("gelu", "swish", "silu"):
        gain_mode = "relu"  # Close to ReLU so approximate with ReLU's gain
    else:
        raise NotImplementedError(
            f"get_initializer: activation {activation} not supported"
        )

    if name == "uniform":
        return partial(torch.nn.init.kaiming_uniform_, nonlinearity=gain_mode)
    if name == "normal":
        return partial(torch.nn.init.kaiming_normal_, nonlinearity=gain_mode)
    if name == "xavier":
        return torch.nn.init.xavier_normal_
    if name == "zero":
        return partial(torch.nn.init.constant_, val=0)
    if name == "one":
        return partial(torch.nn.init.constant_, val=1)
    raise NotImplementedError(
        f"get_initializer: initializer type {name} not supported"
    )
|
||||
|
||||
|
||||
def LinearActivation(
    d_input,
    d_output,
    bias=True,
    zero_bias_init=False,
    transposed=False,
    initializer=None,
    activation=None,
    activate=False,  # Apply activation as part of this module
    weight_norm=False,
    **kwargs,
):
    """Returns a linear nn.Module with control over axes order, initialization, and activation

    The output width is doubled when `activation == "glu"`. Extra keyword
    arguments are forwarded to the linear layer's constructor. When `activate`
    is set and an activation is given, the result is an nn.Sequential of the
    linear layer followed by the activation.
    """
    # Construct core module
    layer_type = TransposedLinear if transposed else nn.Linear
    out_dim = d_output * 2 if activation == "glu" else d_output
    module = layer_type(d_input, out_dim, bias=bias, **kwargs)

    # Initialize weight
    if initializer is not None:
        init_fn = get_initializer(initializer, activation)
        init_fn(module.weight)

    # Initialize bias
    if bias and zero_bias_init:
        nn.init.zeros_(module.bias)

    # Weight norm
    if weight_norm:
        module = nn.utils.weight_norm(module)

    if activate and activation is not None:
        act = Activation(activation, out_dim, dim=-2 if transposed else -1)
        module = nn.Sequential(module, act)
    return module
|
||||
|
||||
|
||||
class TransposedLinear(nn.Module):
    """Linear module on the second-to-last dimension"""

    def __init__(self, d_input, d_output, bias=True):
        super().__init__()

        self.weight = nn.Parameter(torch.empty(d_output, d_input))
        # nn.Linear default init
        # (kaiming_uniform_ with nonlinearity='linear' should be equivalent)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if not bias:
            # A plain float keeps `forward` branch-free.
            self.bias = 0.0
            return

        self.bias = nn.Parameter(torch.empty(d_output, 1))
        limit = 1 / math.sqrt(d_input)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """Contract the weight with the second-to-last axis of x, then add the bias."""
        projected = contract("... u l, v u -> ... v l", x, self.weight)
        return projected + self.bias
|
||||
|
||||
|
||||
class TransposedLN(nn.Module):
    """LayerNorm module over second-to-last dimension

    This is slow and a dedicated CUDA/Triton implementation should provide substantial end-to-end speedup
    """

    def __init__(self, d, scalar=True):
        super().__init__()
        self.scalar = scalar
        if not self.scalar:
            self.ln = nn.LayerNorm(d)
            return
        # Scalar mode: a single learnable shift and scale shared by all features.
        self.m = nn.Parameter(torch.zeros(1))
        self.s = nn.Parameter(torch.ones(1))

    def forward(self, x):
        if not self.scalar:
            # Move the normalized axis last for nn.LayerNorm, then move it back.
            return self.ln(x.transpose(-1, -2)).transpose(-1, -2)
        std, mean = torch.std_mean(x, dim=-2, unbiased=False, keepdim=True)
        return (self.s / std) * (x - mean + self.m)
|
||||
|
||||
|
||||
class Normalization(nn.Module):
    """Wrapper selecting layer / instance / batch / no normalization by name.

    d: number of features.
    transposed: whether the length dimension is -1 (True) or -2 (False).
    _name_: "layer", "instance", "batch" or "none".
    kwargs: forwarded to the underlying normalization module.
    """

    def __init__(
        self,
        d,
        transposed=False,  # Length dimension is -1 or -2
        _name_="layer",
        **kwargs,
    ):
        super().__init__()
        self.transposed = transposed

        if _name_ == "layer":
            self.channel = True  # Normalize over channel dimension
            norm_cls = TransposedLN if self.transposed else nn.LayerNorm
            self.norm = norm_cls(d, **kwargs)
        elif _name_ == "instance":
            self.channel = False
            # (True, True) performs very poorly
            options = {"affine": False, "track_running_stats": False, **kwargs}
            self.norm = nn.InstanceNorm1d(d, **options)
        elif _name_ == "batch":
            self.channel = False
            options = {"affine": True, "track_running_stats": True, **kwargs}
            self.norm = nn.BatchNorm1d(d, **options)
        elif _name_ == "none":
            self.channel = True
            self.norm = nn.Identity()
        else:
            raise NotImplementedError

    def forward(self, x):
        # The cases of LayerNorm / no normalization are automatically handled in all cases
        # Instance/Batch Norm work automatically with transposed axes
        if self.channel or self.transposed:
            return self.norm(x)
        # Instance/Batch norm on non-transposed input: swap the last two axes around the call.
        return self.norm(x.transpose(-1, -2)).transpose(-1, -2)
|
||||
+393
@@ -0,0 +1,393 @@
|
||||
""" Definitions of A and B matrices for various HiPPO operators. """
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from opt_einsum import contract
|
||||
from scipy import special as ss
|
||||
|
||||
|
||||
def embed_c2r(A):
    """Expand each entry a of a 2-D array into the 2x2 block a * I.

    An (M, N) input becomes a (2M, 2N) array: A is reshaped to (M, 1, N, 1),
    zero-padded to (M, 2, N, 2) once at the [0, 0] and once at the [1, 1]
    position of the new axes, and the two paddings are added and flattened.
    """
    expanded = rearrange(A, "... m n -> ... m () n ()")
    upper_left = np.pad(expanded, ((0, 0), (0, 1), (0, 0), (0, 1)))
    lower_right = np.pad(expanded, ((0, 0), (1, 0), (0, 0), (1, 0)))
    return rearrange(upper_left + lower_right, "m x n y -> (m x) (n y)")
|
||||
|
||||
|
||||
# TODO take in 'torch' option to return torch instead of numpy, which converts the shape of B from (N, 1) to (N)
|
||||
# TODO remove tlagt
|
||||
def transition(measure, N, **measure_args):
    """A, B transition matrices for different measures

    measure: the type of measure
      legt - Legendre (translated)
      legs - Legendre (scaled)
      glagt - generalized Laguerre (translated)
      lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization
      lmu - equivalent to LegT up to normalization
      legsd, fourier_old, fourier_diag, fourier, fourier_decay, fourier2,
      random, diagonal - further variants handled below
    N: state size; A is built as (N, N) and B as an (N, 1) column.
    measure_args: optional "alpha" / "beta" parameters read by the Laguerre measures.

    Returns numpy arrays (A, B). Raises NotImplementedError for an unknown measure.
    """
    # Laguerre (translated)
    if measure == "lagt":
        b = measure_args.get("beta", 1.0)
        A = np.eye(N) / 2 - np.tril(np.ones((N, N)))
        B = b * np.ones((N, 1))
    elif measure == "tlagt":
        # beta = 1 corresponds to no tilt
        b = measure_args.get("beta", 1.0)
        A = (1.0 - b) / 2 * np.eye(N) - np.tril(np.ones((N, N)))
        B = b * np.ones((N, 1))
    # Generalized Laguerre
    # alpha 0, beta small is most stable (limits to the 'lagt' measure)
    # alpha 0, beta 1 has transition matrix A = [lower triangular 1]
    elif measure == "glagt":
        alpha = measure_args.get("alpha", 0.0)
        beta = measure_args.get("beta", 0.01)
        A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1)
        B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None]

        # Diagonal rescaling, computed in log space via gammaln.
        L = np.exp(
            0.5 * (ss.gammaln(np.arange(N) + alpha + 1) - ss.gammaln(np.arange(N) + 1))
        )
        A = (1.0 / L[:, None]) * A * L[None, :]
        B = (
            (1.0 / L[:, None])
            * B
            * np.exp(-0.5 * ss.gammaln(1 - alpha))
            * beta ** ((1 - alpha) / 2)
        )
    # Legendre (translated)
    elif measure == "legt":
        Q = np.arange(N, dtype=np.float64)
        R = (2 * Q + 1) ** 0.5
        j, i = np.meshgrid(Q, Q)
        A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :]
        B = R[:, None]
        A = -A

        # Halve again for timescale correctness
        # A, B = A/2, B/2
        A *= 0.5
        B *= 0.5
    # LMU: equivalent to LegT up to normalization
    elif measure == "lmu":
        Q = np.arange(N, dtype=np.float64)
        R = (2 * Q + 1)[:, None]  # / theta
        j, i = np.meshgrid(Q, Q)
        A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R
        B = (-1.0) ** Q[:, None] * R
    # Legendre (scaled)
    elif measure == "legs":
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        B = (
            B.copy()
        )  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
    elif measure == "legsd":
        # Same construction as "legs", followed by a rank-1 update of A and a halved B.
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        B = (
            B.copy()
        )  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
        A += 0.5 * B * B[None, :, 0]
        B = B / 2.0
    elif measure == "fourier_old":
        freqs = np.arange(N // 2)
        d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1]
        A = 2 * np.pi * (np.diag(d, 1) - np.diag(d, -1))
        A = A - embed_c2r(np.ones((N // 2, N // 2)))
        B = embed_c2r(np.ones((N // 2, 1)))[..., :1]
    elif measure == "fourier_diag":
        freqs = np.arange(N // 2)
        d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1]
        A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        # A = A - 0.5*embed_c2r(np.ones((N//2, N//2)))
        A = A - 0.5 * np.eye(N)
        B = embed_c2r(np.ones((N // 2, 1)))[..., :1]
    elif measure == "fourier":
        freqs = np.arange(N // 2)
        d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
        A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :]
        B = B[:, None]
    elif measure == "fourier_decay":
        freqs = np.arange(N // 2)
        d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
        A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - 0.5 * B[:, None] * B[None, :]
        B = 0.5 * B[:, None]
    elif measure == "fourier2":  # Double everything: orthonormal on [0, 1]
        freqs = 2 * np.arange(N // 2)
        d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
        A = np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        B = np.zeros(N)
        B[0::2] = 2**0.5
        B[0] = 1

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :] * 2
        B = B[:, None] * 2

    elif measure == "random":
        # Unseeded draws from numpy's global RNG.
        A = np.random.randn(N, N) / N
        B = np.random.randn(N, 1)
    elif measure == "diagonal":
        # Negative diagonal with log-normal magnitudes.
        A = -np.diag(np.exp(np.random.randn(N)))
        B = np.random.randn(N, 1)
    else:
        raise NotImplementedError

    return A, B
|
||||
|
||||
|
||||
def rank_correction(measure, N, rank=1, dtype=torch.float):
    """Return low-rank matrix L such that A + L is normal

    measure: name of the HiPPO measure (see the branches below).
    N: state size.
    rank: number of rows requested; the result is zero-padded up to `rank`
        rows when the measure's own correction has fewer.
    dtype: dtype of the returned tensor, honoured by every branch.

    Returns P with shape (max(rank, d), N), where d is the measure's native
    rank. Raises NotImplementedError for an unknown measure.
    """
    if measure == "legs":
        assert rank >= 1
        P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(0)  # (1 N)
    elif measure == "legt":
        assert rank >= 2
        P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype))  # (N)
        P0 = P.clone()
        P0[0::2] = 0.0
        P1 = P.clone()
        P1[1::2] = 0.0
        P = torch.stack([P0, P1], dim=0)  # (2 N)
        # Halve the rank correct just like the original matrix was halved
        P *= 2 ** (-0.5)
    elif measure == "lagt":
        assert rank >= 1
        P = 0.5**0.5 * torch.ones(1, N, dtype=dtype)
    elif measure in ("fourier", "fourier_decay", "fourier2"):
        # Built in the requested dtype so the zero padding below matches.
        P = torch.zeros(N, dtype=dtype)
        P[0::2] = 2**0.5
        P[0] = 1
        P = P.unsqueeze(0)
        if measure == "fourier_decay":
            P = P / 2**0.5
        elif measure == "fourier2":
            P = 2**0.5 * P
    elif measure in ("fourier_old", "fourier_diag", "legsd"):
        # These measures use no correction.
        P = torch.zeros(1, N, dtype=dtype)
    else:
        raise NotImplementedError

    d = P.size(0)
    if rank > d:
        P = torch.cat([P, torch.zeros(rank - d, N, dtype=dtype)], dim=0)  # (rank N)
    return P
|
||||
|
||||
|
||||
def initial_C(measure, N, dtype=torch.float):
    """Return C that captures the other endpoint in the HiPPO approximation

    measure: "legt", "fourier_old" and "fourier" have dedicated values; any
        other measure yields zeros.
    N: length of the returned vector.
    dtype: dtype used to build the vector in every branch.
    """
    if measure == "legt":
        C = (torch.arange(N, dtype=dtype) * 2 + 1) ** 0.5 * (-1) ** torch.arange(N)
    elif measure == "fourier_old":
        C = torch.ones(N, dtype=dtype)  # (N)
    elif measure == "fourier":
        # Built in the requested dtype, like the other branches.
        C = torch.zeros(N, dtype=dtype)
        C[0::2] = 2**0.5
        C[0] = 1
    else:
        C = torch.zeros(N, dtype=dtype)  # (N)

    return C
|
||||
|
||||
|
||||
def nplr(measure, N, rank=1, dtype=torch.float):
    """Return w, P, B, C, V such that
    (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V
    i.e. A = V[w - p q^*]V^*, B = V B

    measure, N: forwarded to `transition`, `rank_correction` and `initial_C`.
    rank: rank of the low-rank correction P.
    dtype: dtype the HiPPO matrices are converted to before diagonalization.

    Only N // 2 eigenvalues / eigenvectors are returned (one of each
    conjugate pair); B, C and P are expressed in that eigenbasis.
    """
    # NOTE(review): this assertion is always true (`or torch.cfloat` is a truthy
    # operand, not a comparison). The intended set of accepted dtypes is not
    # visible from here - confirm against callers before tightening it.
    assert dtype == torch.float or torch.cfloat

    A, B = transition(measure, N)
    A = torch.as_tensor(A, dtype=dtype)  # (N, N)
    B = torch.as_tensor(B, dtype=dtype)[:, 0]  # (N,)

    P = rank_correction(measure, N, rank=rank, dtype=dtype)  # (r N)
    # AP = A + sum_r P_r P_r^T (sum of outer products over the rank axis)
    AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3)
    w, V = torch.linalg.eig(AP)  # (..., N) (..., N, N)
    # V w V^{-1} = A

    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    # We require AP to be nearly skew-symmetric
    # (AP + AP^T should be a multiple of the identity; only a warning is printed otherwise)
    _A = AP + AP.transpose(-1, -2)
    if (
        err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N
    ) > 1e-5:  # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):
        print("WARNING: HiPPO matrix not skew symmetric", err)

    # Only keep half of each conjugate pair
    # w = w[..., 0::2].contiguous()
    # V = V[..., 0::2].contiguous()
    # Sort by imaginary part and keep the first N // 2 entries.
    _, idx = torch.sort(w.imag)
    w_sorted = w[idx]
    V_sorted = V[:, idx]

    # There is an edge case when eigenvalues can be 0, which requires some machinery to handle
    # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)
    V = V_sorted[:, : N // 2]
    w = w_sorted[: N // 2]
    assert w[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A"
    if w[-1].abs() < 1e-4:
        # Overwrite the eigenvector of the (near-)zero eigenvalue with a fixed vector.
        V[:, -1] = 0.0
        V[0, -1] = 2**-0.5
        V[1, -1] = 2**-0.5 * 1j

    # Sanity check: the kept half should reproduce AP as 2 * Re(V diag(w) V^*).
    _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)
    # assert torch.allclose(2*_AP.real, AP, atol=1e-5)
    if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5:
        print(
            "Warning: Diagonalization of A matrix not numerically precise - error", err
        )
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    V_inv = V.conj().transpose(-1, -2)

    C = initial_C(measure, N, dtype=dtype)
    B = contract("ij, j -> i", V_inv, B.to(V))  # V^* B
    C = contract("ij, j -> i", V_inv, C.to(V))  # V^* C
    P = contract("ij, ...j -> ...i", V_inv, P.to(V))  # V^* P

    return w, P, B, C, V
|
||||
|
||||
|
||||
def random_dplr(
    N,
    rank=1,
    H=1,
    dtype=torch.float,
    real_scale=1.0,
    imag_scale=1.0,
    scaling="inverse",
    random_real=False,
    random_imag=False,
    normalize=True,
):
    """Build randomly initialized diagonal-plus-low-rank parameters.

    N: full state size; every returned factor uses N // 2 entries.
    rank: leading dimension of the random low-rank term P.
    H: number of independent copies (leading dimension of w, B, C).
    dtype: real dtype; torch.float selects complex64 for P, B, C, V and
        anything else selects complex128.
    real_scale, imag_scale: multipliers for the real / imaginary parts of w.
    scaling: "random", "linear", "inverse", "inverse2" or "quadratic"; how the
        imaginary parts of w are derived. Any other value raises NotImplementedError.
    random_real, random_imag: draw the base real / imaginary parts at random
        instead of using 0.5 and 0..N//2-1 respectively.
    normalize: rescale B by zeta ** 0.5 (see below).

    Returns (w, P, B, C, V).
    """
    # NOTE(review): this assertion is always true (`or torch.double` is a truthy
    # operand, not a comparison) - confirm the intended dtypes against callers.
    assert dtype == torch.float or torch.double
    # batch_shape = (H, N//2) if H is not None else (N//2,)
    dtype = torch.cfloat if dtype == torch.float else torch.cdouble

    pi = torch.tensor(np.pi)
    if random_real:
        real_part = torch.rand(H, N // 2)
    else:
        real_part = 0.5 * torch.ones(H, N // 2)
    if random_imag:
        imag_part = N // 2 * torch.rand(H, N // 2)
    else:
        imag_part = repeat(torch.arange(N // 2), "n -> h n", h=H)

    real_part = real_scale * real_part
    if scaling == "random":
        # Replaces whatever imag_part was chosen above.
        imag_part = torch.randn(H, N // 2)
    elif scaling == "linear":
        imag_part = pi * imag_part
    elif scaling == "inverse":  # Based on asymptotics of the default HiPPO matrix
        imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1)
    elif scaling == "inverse2":  # Based on asymptotics of the default HiPPO matrix
        imag_part = 1 / pi * N * (N / (1 + imag_part) - 1)
    elif scaling == "quadratic":
        imag_part = 1 / pi * (1 + 2 * imag_part) ** 2
    else:
        raise NotImplementedError
    imag_part = imag_scale * imag_part
    # The real part is negated here.
    w = -real_part + 1j * imag_part

    P = torch.randn(rank, H, N // 2, dtype=dtype)
    B = torch.randn(H, N // 2, dtype=dtype)
    C = torch.randn(H, N // 2, dtype=dtype)
    V = torch.eye(N, dtype=dtype)[..., : N // 2]  # Only used in testing

    if normalize:  # TODO can normalize the full matrix with rank correction too
        norm = (
            -B / w
        )  # (H, N) # Result if you integrate the kernel with constant 1 function
        zeta = 2 * torch.sum(
            torch.abs(norm) ** 2, dim=-1, keepdim=True
        )  # Variance with a random C vector
        B = B / zeta**0.5

    return w, P, B, C, V
|
||||
|
||||
|
||||
def test_nplr():
    """Smoke test: rebuild A and B from the NPLR factors and print their real parts."""
    N = 4
    measure = "fourier_decay"
    w, P, B, C, V = nplr(measure, N, rank=1)

    def with_conjugates(t):
        # Re-attach the conjugate half that nplr dropped.
        return torch.cat([t, t.conj()], dim=-1)

    w = with_conjugates(w)
    V = with_conjugates(V)
    B = with_conjugates(B)
    P = with_conjugates(P)
    Q = P
    low_rank = contract("... r p, ... r q -> ... p q", P, Q.conj())
    A = torch.diag_embed(w) - low_rank

    # Ap^{-1} = V @ w^{-1} @ V^T
    V_adjoint = V.conj().transpose(-1, -2)
    A = contract("ij, jk, kl -> ... il", V, A, V_adjoint)
    B = contract("ij, ... j -> ... i", V, B)
    print(A.real)
    print(B.real)
|
||||
+1143
File diff suppressed because it is too large
Load Diff
+214
@@ -0,0 +1,214 @@
|
||||
""" Compute a Krylov function efficiently. (S3 renames the Krylov function to a "state space kernel")
|
||||
|
||||
A : (N, N)
|
||||
b : (N,)
|
||||
c : (N,)
|
||||
Return: [c^T A^i b for i in [L]]
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from toeplitz import causal_convolution
|
||||
|
||||
|
||||
def krylov_sequential(L, A, b, c=None):
    """Constant matrix A

    A : (..., N, N)
    b : (..., N)
    c : (..., N)

    Returns
    if c:
    x : (..., L)
    x[i, l] = c[i] @ A^l @ b[i]

    else:
    x : (..., N, L)
    x[i, l] = A^l @ b[i]
    """
    # Check which of dim b and c is smaller to save memory
    if c is not None and c.numel() < b.numel():
        return krylov_sequential(L, A.transpose(-1, -2), c, b)

    state = b  # holds A^l @ b at step l
    terms = []
    for _ in range(L):
        # (...) # could be faster with matmul or einsum?
        terms.append(state if c is None else torch.sum(c * state, dim=-1))
        state = (A @ state.unsqueeze(-1)).squeeze(-1)

    return torch.stack(terms, dim=-1)
|
||||
|
||||
|
||||
def krylov(L, A, b, c=None, return_power=False):
    """
    Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick.

    If return_power=True, return A^{L-1} as well

    Without c the result is (..., N, L); with c it is contracted over N to (..., L).
    """
    # TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises

    x = b.unsqueeze(-1)  # (..., N, 1)
    A_pow = A  # A^(2^k) at the start of iteration k

    AL = None
    if return_power:
        AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
        exponent_bits = L - 1  # consumed one binary digit per iteration

    finished = L == 1
    while not finished:
        if return_power:
            if exponent_bits % 2 == 1:
                AL = A_pow @ AL
            exponent_bits //= 2

        have = x.shape[-1]
        missing = L - have
        finished = missing <= have
        # Save memory on last iteration
        tail = x[..., :missing] if finished else x

        # there might be a more efficient way of ordering axes
        x = torch.cat([x, A_pow @ tail], dim=-1)
        if not finished:
            A_pow = A_pow @ A_pow

    assert x.shape[-1] == L

    if c is not None:
        x = torch.einsum("...nl, ...n -> ...l", x, c)
    x = x.contiguous()  # WOW!!
    return (x, AL) if return_power else x
|
||||
|
||||
|
||||
def power(L, A, v=None):
    """Compute A^L and the scan sum_i A^i v_i

    A: (..., N, N)
    v: (..., N, L)

    Returns A^L when v is None, otherwise the tuple (A^L, sum_i A^i v[..., i]).
    The caller's `v` is never modified.
    """

    I = torch.eye(A.shape[-1]).to(A)  # , dtype=A.dtype, device=A.device)

    # Binary exponentiation, caching A^(2^j) for the reduction below.
    powers = [A]
    l = 1
    while True:
        if L % 2 == 1:
            I = powers[-1] @ I
        L //= 2
        if L == 0:
            break
        l *= 2
        powers.append(powers[-1] @ powers[-1])

    if v is None:
        return I

    # Invariants:
    # powers[-1] := A^l
    # l := largest po2 at most L

    # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A
    # We do this reverse divide-and-conquer for efficiency reasons:
    # 1) it involves fewer padding steps for non-po2 L
    # 2) it involves more contiguous arrays

    # Take care of edge case for non-po2 arrays
    # Note that this initial step is a no-op for the case of power of 2 (l == L)
    k = v.size(-1) - l
    overflow = powers.pop() @ v[..., l:]
    if k > 0:
        # Fold the columns beyond the power-of-two prefix back onto the first
        # k columns. A new tensor is built instead of assigning into a slice,
        # because that slice would be a view of the caller's `v`.
        v = torch.cat([v[..., :k] + overflow, v[..., k:l]], dim=-1)
    else:
        v = v[..., :l]

    # Handle reduction for power of 2
    while v.size(-1) > 1:
        half = v.size(-1) // 2
        v = v[..., :half] + powers.pop() @ v[..., half:]
    return I, v.squeeze(-1)
|
||||
|
||||
|
||||
def krylov_toeplitz(L, A, b, c=None):
    """Specializes to lower triangular Toeplitz matrix A represented by its diagonals

    A : (..., N)
    b : (..., N)
    c : (..., N)

    Returns
    x : (..., N, L)
    x[i, l] = A^l @ b[i]

    When c is given the N axis is contracted with it and the result is (..., L).
    """
    x = b.unsqueeze(0)  # (1, ..., N)
    A_pow = A
    while x.shape[0] < L:
        # Double the number of computed terms, then square the operator.
        # there might be a more efficient way of ordering axes
        x = torch.cat([x, causal_convolution(A_pow, x)], dim=0)
        A_pow = causal_convolution(A_pow, A_pow)
    x = x[:L, ...]  # (L, ..., N)
    if c is None:
        x = rearrange(x, "l ... n -> ... n l")
    else:
        x = torch.einsum("l...n, ...n -> ...l", x, c)
    return x.contiguous()
|
||||
|
||||
|
||||
def krylov_toeplitz_(L, A, b, c=None):
    """Padded version of krylov_toeplitz that saves some fft's

    TODO currently not faster than original version, not sure why
    """
    N = A.shape[-1]
    fft_len = 2 * N

    # Zero-pad to length 2N; the upper half is zeroed again after each product.
    x = F.pad(b.unsqueeze(0), (0, N))  # (1, ..., 2N)
    A = F.pad(A, (0, N))
    finished = L == 1
    while not finished:
        have = x.shape[0]
        missing = L - have
        finished = missing <= have
        # Save memory on last iteration
        block = x[:missing] if finished else x
        A_freq = torch.fft.rfft(A, n=fft_len, dim=-1)
        block_freq = torch.fft.rfft(block, n=fft_len, dim=-1)
        product = torch.fft.irfft(A_freq * block_freq, n=fft_len, dim=-1)
        product[..., N:] = 0
        # there might be a more efficient way of ordering axes
        x = torch.cat([x, product], dim=0)
        if not finished:
            A = torch.fft.irfft(A_freq * A_freq, n=fft_len, dim=-1)
            A[..., N:] = 0
    x = x[:L, ..., :N]  # (L, ..., N)
    if c is None:
        x = rearrange(x, "l ... n -> ... n l")
    else:
        x = torch.einsum("l...n, ...n -> ...l", x, c)
    return x.contiguous()
|
||||
@@ -0,0 +1,25 @@
|
||||
import logging
|
||||
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
|
||||
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Fetches the logger called ``name``, sets its level to ``level`` and
    replaces each of its logging methods with a ``rank_zero_only``-wrapped
    version.
    """

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    method_names = (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    )
    # A dedicated loop name keeps the ``level`` argument from being rebound.
    for method_name in method_names:
        setattr(logger, method_name, rank_zero_only(getattr(logger, method_name)))

    return logger
|
||||
@@ -0,0 +1,237 @@
|
||||
""" pykeops implementations of the core Cauchy kernel used in the S3 algorithm.
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from pykeops.torch import Genred, LazyTensor
|
||||
|
||||
|
||||
_conj = lambda x: torch.cat([x, x.conj()], dim=-1)
|
||||
|
||||
|
||||
def _broadcast_dims(*tensors):
|
||||
max_dim = max([len(tensor.shape) for tensor in tensors])
|
||||
tensors = [
|
||||
tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)
|
||||
for tensor in tensors
|
||||
]
|
||||
return tensors
|
||||
|
||||
|
||||
def _c2r(x):
|
||||
return torch.view_as_real(x)
|
||||
|
||||
|
||||
def _r2c(x):
|
||||
return torch.view_as_complex(x)
|
||||
|
||||
|
||||
def cauchy_slow(v, z, w, conj=True):
    r"""Reference Cauchy kernel that materialises the full (N, L) matrix.

    v: (..., N)
    z: (..., L)
    w: (..., N)
    conj: when true, ``v`` and ``w`` are first extended with their complex
        conjugates along the last axis (via ``_conj``).
    returns: (..., L) \sum v/(z-w)
    """
    if conj:
        v = _conj(v)
        w = _conj(w)
    cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1))  # (... N L)
    # Reduce over the N axis.
    return torch.sum(cauchy_matrix, dim=-2)
|
||||
|
||||
|
||||
def cauchy_lazy(v, z, w, conj=True):
    """Cauchy kernel ``sum v / (z - w)`` expressed with pykeops LazyTensors.

    ``v`` and ``w`` are optionally extended with their conjugates, all three
    inputs are brought to a common rank, and the quotient is summed over the
    axis that holds N after the rearranges below.
    """
    if conj:
        v = _conj(v)
        w = _conj(w)
    v, z, w = _broadcast_dims(v, z, w)
    v_l = LazyTensor(rearrange(v, "... N -> ... N 1 1"))
    w_l = LazyTensor(rearrange(w, "... N -> ... N 1 1"))
    z_l = LazyTensor(rearrange(z, "... L -> ... 1 L 1"))
    sub = z_l - w_l  # (b N L 1), for some reason it doesn't display the last dimension
    div = v_l / sub
    s = div.sum(dim=len(v_l.shape) - 2)
    return s.squeeze(-1)
|
||||
|
||||
|
||||
def cauchy(v, z, w, conj=False):
    """Cauchy kernel evaluated by a pykeops ``Genred`` reduction.

    The formula is ``ComplexDivide(v, z-w)`` with ``v``/``w`` declared as
    ``Vj`` variables and ``z`` as a ``Vi`` variable, reduced with ``Sum``.
    Inputs are complex; they are passed to KeOps as real views with a
    trailing axis of size 2 and the result is viewed back as complex.
    The reduction is run with ``backend="GPU"``.
    """
    expr = "ComplexDivide(v, z-w)"
    cauchy_mult = Genred(
        expr,
        [
            "v = Vj(2)",
            "z = Vi(2)",
            "w = Vj(2)",
        ],
        reduction_op="Sum",
        axis=1,
    )

    if conj:
        v = _conj(v)
        w = _conj(w)
    v, z, w = _broadcast_dims(v, z, w)
    v = _c2r(v)
    z = _c2r(z)
    w = _c2r(w)

    r = cauchy_mult(v, z, w, backend="GPU")
    return _r2c(r)
|
||||
|
||||
|
||||
def cauchy_real(v, z, w):
    """Real-valued Cauchy kernel ``sum v / (z - w)`` via a pykeops ``Genred``.

    Each input is brought to a common rank and given a trailing axis of
    size 1 to match the ``Vj(1)`` / ``Vi(1)`` variable declarations. The
    raw reduction result is returned without squeezing that axis.
    """
    expr = "v / (z - w)"
    cauchy_mult = Genred(
        expr,
        [
            "v = Vj(1)",
            "z = Vi(1)",
            "w = Vj(1)",
        ],
        reduction_op="Sum",
        axis=1,
    )
    v, z, w = _broadcast_dims(v, z, w)
    v = v.unsqueeze(-1)
    z = z.unsqueeze(-1)
    w = w.unsqueeze(-1)

    r = cauchy_mult(v, z, w, backend="GPU")
    return r
|
||||
|
||||
|
||||
def cauchy_conj(v, z, w, num=2, denom=2):
    """Cauchy kernel for conjugate-pair inputs via a pykeops ``Genred``.

    Only one half of each conjugate pair is passed in; the formula combines
    each term with its conjugate analytically and the reduction result is
    multiplied by 2.

    num, denom: select between two spellings (1 or 2) of the numerator and
        denominator KeOps expressions.

    Raises:
        NotImplementedError: if ``num`` or ``denom`` is neither 1 nor 2.
    """
    if num == 1:
        expr_num = "z * ComplexReal(v) - Real2Complex(ComplexReal(v)*ComplexReal(w) + ComplexImag(v)*ComplexImag(w))"
    elif num == 2:
        expr_num = "z * ComplexReal(v) - Real2Complex(Sum(v * w))"
    else:
        raise NotImplementedError

    if denom == 1:
        expr_denom = "ComplexMult(z-Real2Complex(ComplexReal(w)), z-Real2Complex(ComplexReal(w))) + Real2Complex(Square(ComplexImag(w)))"
    elif denom == 2:
        expr_denom = "ComplexMult(z-w, z-Conj(w))"
    else:
        raise NotImplementedError

    cauchy_mult = Genred(
        f"ComplexDivide({expr_num}, {expr_denom})",
        [
            "v = Vj(2)",
            "z = Vi(2)",
            "w = Vj(2)",
        ],
        reduction_op="Sum",
        axis=1,
    )

    v, z, w = _broadcast_dims(v, z, w)
    # KeOps receives complex data as real views with a trailing axis of 2.
    v = _c2r(v)
    z = _c2r(z)
    w = _c2r(w)

    r = 2 * cauchy_mult(v, z, w, backend="GPU")
    return _r2c(r)
|
||||
|
||||
|
||||
def cauchy_conj_components(v, z, w):
    """Assumes z is pure imaginary (as in S4 with bilinear)

    Conjugate-pair Cauchy kernel in which the real and imaginary parts of
    ``v`` and ``w`` and the imaginary part of ``z`` are handed to KeOps as
    five separate real variables. The real part of ``z`` is never read.
    """

    expr_num = "Imag2Complex(zi*vr) - Real2Complex(vr*wr + vi*wi)"
    expr_denom = (
        "Real2Complex(Square(wr)+Square(wi)-Square(zi)) - Imag2Complex(IntCst(2)*zi*wr)"
    )
    cauchy_mult = Genred(
        f"ComplexDivide({expr_num}, {expr_denom})",
        [
            "vr = Vj(1)",
            "vi = Vj(1)",
            "wr = Vj(1)",
            "wi = Vj(1)",
            "zi = Vi(1)",
        ],
        reduction_op="Sum",
        axis=1,
    )

    v, z, w = _broadcast_dims(v, z, w)
    v = v.unsqueeze(-1)
    z = z.unsqueeze(-1)
    w = w.unsqueeze(-1)

    # .real / .imag are made contiguous before being passed to KeOps.
    v_r, v_i = v.real.contiguous(), v.imag.contiguous()
    w_r, w_i = w.real.contiguous(), w.imag.contiguous()
    z_i = z.imag.contiguous()

    r = 2 * cauchy_mult(v_r, v_i, w_r, w_i, z_i, backend="GPU")
    return _r2c(r)
|
||||
|
||||
|
||||
def cauchy_conj_components_lazy(v, z, w, type=1):
    """LazyTensor variant of the component-wise conjugate Cauchy kernel.

    Only the imaginary part of ``z`` is used. ``type`` picks between two
    spellings of the numerator/denominator: 1 uses the expanded real and
    imaginary terms, any other value uses the shifted variable
    ``-w_r + 1j * z_i``. The summed result is multiplied by 2.
    """
    v, z, w = _broadcast_dims(v, z, w)

    v_r, v_i = v.real.contiguous(), v.imag.contiguous()
    w_r, w_i = w.real.contiguous(), w.imag.contiguous()
    z_i = z.imag.contiguous()

    v_r = LazyTensor(rearrange(v_r, "... N -> ... 1 N 1"))
    v_i = LazyTensor(rearrange(v_i, "... N -> ... 1 N 1"))
    w_r = LazyTensor(rearrange(w_r, "... N -> ... 1 N 1"))
    w_i = LazyTensor(rearrange(w_i, "... N -> ... 1 N 1"))
    z_i = LazyTensor(rearrange(z_i, "... L -> ... L 1 1"))

    if type == 1:
        num = -v_r * w_r - v_i * w_i + 1j * z_i * v_r
        denom = w_r**2 + w_i**2 - z_i**2 - 2j * w_r * z_i
    else:
        # z = torch.complex(-w_r, z_i) # Not supported
        z = -w_r + 1j * z_i
        num = v_r * z - v_i * w_i
        denom = z * z + w_i**2  # z**2 is bugged for complex

    r = num / denom
    r = 2 * r.sum(dim=len(z_i.shape) - 1)
    return r.squeeze(-1)
|
||||
|
||||
|
||||
def cauchy_conj2(v, z, w):
    """Conjugate-pair Cauchy kernel written as two explicit KeOps terms.

    The formula adds the term for ``(v, w)`` and the term for their
    conjugates, so only one half of each pair is passed in.
    """
    expr = "ComplexDivide(v, z-w) + ComplexDivide(Conj(v), z-Conj(w))"
    # expr = 'ComplexDivide(v, z-w)'
    cauchy_mult = Genred(
        expr,
        [
            "v = Vj(2)",
            "z = Vi(2)",
            "w = Vj(2)",
        ],
        reduction_op="Sum",
        axis=1,
    )

    v, z, w = _broadcast_dims(v, z, w)
    # The original guarded this with `if complex:`, which tested the builtin
    # type and was therefore always true; the conversion is unconditional.
    v = _c2r(v)
    z = _c2r(z)
    w = _c2r(w)

    r = cauchy_mult(v, z, w, backend="GPU")
    return _r2c(r)
|
||||
|
||||
|
||||
def trigger_compilation():
    """Small function to trigger the compilation of a pykeops kernel

    Used in scenarios where we must manually control compilation, e.g. the multi-gpu case (https://github.com/getkeops/keops/issues/168)

    Builds small random complex tensors on the "cuda" device and runs
    ``cauchy_conj`` on them once; the result is discarded.
    """
    B = 2
    N = 4
    L = 16

    w = torch.randn(B, N // 2, dtype=torch.cfloat, device="cuda")
    v = torch.randn(B, N // 2, dtype=torch.cfloat, device="cuda")
    z = torch.randn(B, L, dtype=torch.cfloat, device="cuda")
    w.requires_grad = True
    v.requires_grad = True

    cauchy_conj(v, z, w)
|
||||
+1438
File diff suppressed because one or more lines are too long
@@ -0,0 +1,218 @@
|
||||
import logging
|
||||
|
||||
import opt_einsum as oe
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from components import Activation, LinearActivation, Normalization
|
||||
from kernel import HippoSSKernel
|
||||
|
||||
# Module-level switch: when true, tensor contractions go through
# opt_einsum's ``contract``; otherwise plain ``torch.einsum`` is used.
optimized = True

if optimized:
    contract = oe.contract
else:
    contract = torch.einsum
|
||||
|
||||
|
||||
class S4(nn.Module):
    """State-space sequence layer: SSM convolution kernel plus pointwise output.

    The layer asks ``HippoSSKernel`` for a convolution kernel, applies it to
    the input with an FFT, adds a learned skip term ``D`` and, unless
    ``linear`` is set, applies activation, dropout, optional normalisation
    and a position-wise output projection.
    """

    requires_length = True

    def __init__(
        self,
        d_model,
        d_state=64,
        l_max=1,  # Maximum length of sequence. Fine if not provided: the kernel will keep doubling in length until longer than sequence. However, this can be marginally slower if the true length is not a power of 2
        channels=1,  # maps 1-dim to C-dim
        bidirectional=False,
        # Arguments for FF
        activation="gelu",  # activation in between SS and FF
        ln=False,  # Extra normalization
        postact=None,  # activation after FF
        initializer=None,  # initializer on FF
        weight_norm=False,  # weight normalization on FF
        hyper_act=None,  # Use a "hypernetwork" multiplication
        dropout=0.0,
        transposed=True,  # axis ordering (B, L, D) or (B, D, L)
        verbose=False,
        shift=False,
        linear=False,
        # SSM Kernel arguments
        **kernel_args,
    ):
        """
        d_state: the dimension of the state, also denoted by N
        l_max: the maximum sequence length, also denoted by L
          if this is not known at model creation, set l_max=1
        channels: can be interpreted as a number of "heads"
        bidirectional: bidirectional
        dropout: standard dropout argument
        transposed: choose backbone axis ordering of (B, L, H) or (B, H, L) [B=batch size, L=sequence length, H=hidden dimension]

        Other options are all experimental and should not need to be configured
        """

        super().__init__()
        if verbose:
            from logger import get_logger

            log = get_logger(__name__)
            log.info(f"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})")

        self.h = d_model
        self.n = d_state
        self.bidirectional = bidirectional
        self.ln = ln
        self.channels = channels
        self.transposed = transposed
        self.shift = shift
        self.linear = linear

        # optional multiplicative modulation GLU-style
        # https://arxiv.org/abs/2002.05202
        self.hyper = hyper_act is not None
        if self.hyper:
            # The local channel count is doubled; self.channels keeps the
            # caller's value.
            channels *= 2
            self.hyper_activation = Activation(hyper_act)

        self.D = nn.Parameter(torch.randn(channels, self.h))

        if self.bidirectional:
            # One kernel per direction; D is not doubled.
            channels *= 2

        # SSM Kernel
        self.kernel = HippoSSKernel(
            self.h, N=self.n, L=l_max, channels=channels, verbose=verbose, **kernel_args
        )

        # Pointwise
        if not self.linear:
            self.activation = Activation(activation)
            dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout
            self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()
            if self.ln:
                self.norm = Normalization(self.h * self.channels, transposed=transposed)
            else:
                self.norm = nn.Identity()

        # position-wise output transform to mix features
        if not self.linear:
            self.output_linear = LinearActivation(
                self.h * self.channels,
                self.h,
                transposed=self.transposed,
                initializer=initializer,
                activation=postact,
                activate=True,
                weight_norm=weight_norm,
            )

    def forward(
        self, u, state=None, **kwargs
    ):  # absorbs return_output and transformer src mask
        """
        u: (B H L) if self.transposed else (B L H)
        state: (H N) never needed unless you know what you're doing

        Returns: a pair ``(y, next_state)``; ``y`` has the same shape as u and
        ``next_state`` is None when no ``state`` was passed in.
        """
        if not self.transposed:
            u = u.transpose(-1, -2)
        L = u.size(-1)

        # Compute SS Kernel
        k, k_state = self.kernel(L=L, state=state)  # (C H L) (B C H L)

        # Convolution
        if self.bidirectional:
            # First half of the channels is the forward kernel, second half
            # the backward one; combine them into one length-2L kernel.
            k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2)
            k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0))
        if self.shift:
            # Try flip and pad to correct for potential off-by-one
            k_f = torch.fft.rfft(F.pad(k.flip(-1), (L, 0)), n=2 * L)  # (C H L)
            u_f = torch.fft.rfft(F.pad(u.flip(-1), (L, 0)), n=2 * L)  # (B H L)
            y_f = contract(
                "bhl,chl->bchl", u_f, k_f
            )  # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L)
            y = torch.fft.irfft(y_f, n=2 * L)[..., L:].flip(-1)  # (B C H L)
        else:
            k_f = torch.fft.rfft(k, n=2 * L)  # (C H L)
            u_f = torch.fft.rfft(u, n=2 * L)  # (B H L)
            y_f = contract(
                "bhl,chl->bchl", u_f, k_f
            )  # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L)
            y = torch.fft.irfft(y_f, n=2 * L)[..., :L]  # (B C H L)

        # Compute D term in state space equation - essentially a skip connection
        y = y + contract(
            "bhl,ch->bchl", u, self.D
        )  # u.unsqueeze(-3) * self.D.unsqueeze(-1)

        # Compute state update
        if state is not None:
            assert (
                not self.bidirectional
            ), "Bidirectional not supported with state forwarding"
            y = y + k_state
            next_state = self.kernel.forward_state(u, state)
        else:
            next_state = None

        # Optional hyper-network multiplication
        if self.hyper:
            y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2)
            y = self.hyper_activation(yh) * y

        # Reshape to flatten channels
        y = rearrange(y, "... c h l -> ... (c h) l")

        if not self.linear:
            y = self.dropout(self.activation(y))

        if not self.transposed:
            y = y.transpose(-1, -2)

        if not self.linear:
            y = self.norm(y)
            y = self.output_linear(y)

        return y, next_state

    def setup_step(self):
        """Delegate step-mode preparation to the kernel."""
        self.kernel.setup_step()

    def step(self, u, state):
        """Step one time step as a recurrent model. Intended to be used during validation.

        u: (B H)
        state: (B H N)
        Returns: output (B H), state (B H N)
        """
        assert not self.training

        y, next_state = self.kernel.step(u, state)  # (B C H)
        y = y + u.unsqueeze(-2) * self.D
        y = rearrange(y, "... c h -> ... (c h)")
        # activation and output_linear are only created when linear is False
        # (see __init__), so skip them here just as forward() does.
        # NOTE(review): unlike forward(), this path applies neither self.norm
        # nor the hyper multiplication - confirm that this is intended.
        if not self.linear:
            y = self.activation(y)
            if self.transposed:
                # output_linear is built for a trailing length axis here, so
                # add one and remove it again.
                y = self.output_linear(y.unsqueeze(-1)).squeeze(-1)
            else:
                y = self.output_linear(y)
        return y, next_state

    def default_state(self, *batch_shape, device=None):
        """Return the kernel's default state for ``batch_shape``.

        NOTE(review): ``device`` is accepted but not forwarded to the kernel.
        """
        return self.kernel.default_state(*batch_shape)

    @property
    def d_state(self):
        """Flattened state size, ``h * n``."""
        return self.h * self.n

    @property
    def d_output(self):
        """Output feature size, equal to ``h``."""
        return self.h

    @property
    def state_to_tensor(self):
        """Callable that flattens the trailing (h, n) axes of a state tensor."""
        # einops.rearrange takes the tensor first and the pattern second; the
        # previous version passed them the other way round.
        return lambda state: rearrange(state, "... h n -> ... (h n)")
|
||||
+300
@@ -0,0 +1,300 @@
|
||||
""" Utilities for computing convolutions.
|
||||
|
||||
There are 3 equivalent views:
|
||||
1. causal convolution
|
||||
2. multiplication of (lower) triangular Toeplitz matrices
|
||||
3. polynomial multiplication (mod x^N)
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
# import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# from model.complex import complex_mul
|
||||
# from pytorch_memlab import profile
|
||||
|
||||
|
||||
def construct_toeplitz(v, f=0.0):
    """Explicit construction of Krylov matrix [v  A @ v  A^2 @ v  ...  A^{n-1} @ v]
    where A = Z_f. This uses vectorized indexing and cumprod so it's much
    faster than using the Krylov function.
    Parameters:
        v: the starting vector of size n or (rank, n).
        f: real number
    Returns:
        K: Krylov matrix of size (n, n) or (rank, n, n).
    """
    n = v.shape[-1]
    a = torch.arange(n, device=v.device)
    b = -a
    # indices[i, j] = i - j; negative entries wrap around when indexing v.
    indices = a[:, None] + b[None]
    K = v[..., indices]
    # Entries above the diagonal (the wrapped ones) are scaled by f.
    K[..., indices < 0] *= f
    return K
|
||||
|
||||
|
||||
def triangular_toeplitz_multiply_(u, v, sum=None):
    """Causal convolution of ``u`` and ``v`` along the last axis via FFT.

    Both inputs are zero-padded from n to 2n, multiplied in the frequency
    domain and the first n output samples are returned.

    sum: optional dimension over which the frequency-domain product is
        summed before the inverse transform.
    """
    n = u.shape[-1]
    u_expand = F.pad(u, (0, n))
    v_expand = F.pad(v, (0, n))
    u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1)
    v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1)
    uv_f = u_f * v_f
    if sum is not None:
        uv_f = uv_f.sum(dim=sum)
    output = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n]
    return output
|
||||
|
||||
|
||||
def triangular_toeplitz_multiply_padded_(u, v):
    """Same as triangular_toeplitz_multiply but inputs and output assume to be 0-padded already.

    The last axis (length n, required to be even) is used directly as the
    FFT length, so the product is a circular convolution of length n.

    NOTE(review): ``output[..., n:]`` is an empty slice because ``n`` is the
    full length, so nothing is zeroed and the upper half of the result is
    returned as computed. TriangularToeplitzMultPadded.backward flips the
    whole result and relies on that upper half, so the behaviour is kept
    as-is here.
    """
    n = u.shape[-1]
    assert n % 2 == 0
    u_f = torch.fft.rfft(u, n=n, dim=-1)
    v_f = torch.fft.rfft(v, n=n, dim=-1)
    uv_f = u_f * v_f
    output = torch.fft.irfft(uv_f, n=n, dim=-1)
    output[..., n:] = 0
    return output
|
||||
|
||||
|
||||
class TriangularToeplitzMult(torch.autograd.Function):
    """Autograd wrapper around triangular_toeplitz_multiply_ with a manual backward."""

    @staticmethod
    def forward(ctx, u, v):
        ctx.save_for_backward(u, v)
        return triangular_toeplitz_multiply_(u, v)

    @staticmethod
    def backward(ctx, grad):
        u, v = ctx.saved_tensors
        # Each gradient is a causal convolution of the reversed upstream
        # gradient with the other operand, reversed back afterwards.
        d_u = triangular_toeplitz_multiply_(grad.flip(-1), v).flip(-1)
        d_v = triangular_toeplitz_multiply_(grad.flip(-1), u).flip(-1)
        return d_u, d_v
|
||||
|
||||
|
||||
class TriangularToeplitzMultFast(torch.autograd.Function):
    """Causal convolution that saves the operands' FFTs for the backward pass.

    Same forward result as triangular_toeplitz_multiply_; the backward pass
    reuses the saved transforms instead of recomputing them.
    """

    @staticmethod
    def forward(ctx, u, v):
        n = u.shape[-1]
        u_expand = F.pad(u, (0, n))
        v_expand = F.pad(v, (0, n))
        u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1)
        v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1)

        ctx.save_for_backward(u_f, v_f)

        uv_f = u_f * v_f
        output = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n]
        return output

    @staticmethod
    def backward(ctx, grad):
        u_f, v_f = ctx.saved_tensors
        n = grad.shape[-1]
        # Reverse the upstream gradient, pad it and transform it once; the
        # same transform is reused for both operand gradients.
        g_expand = F.pad(grad.flip(-1), (0, n))
        g_f = torch.fft.rfft(g_expand, n=2 * n, dim=-1)
        gu_f = g_f * u_f
        gv_f = g_f * v_f
        d_u = torch.fft.irfft(gv_f, n=2 * n, dim=-1)[..., :n]
        d_v = torch.fft.irfft(gu_f, n=2 * n, dim=-1)[..., :n]
        d_u = d_u.flip(-1)
        d_v = d_v.flip(-1)
        return d_u, d_v
|
||||
|
||||
|
||||
class TriangularToeplitzMultPadded(torch.autograd.Function):
    """Autograd wrapper intended for inputs that are already zero-padded.

    NOTE(review): forward calls the non-padded triangular_toeplitz_multiply_
    (which pads again), while backward uses the padded variant - confirm
    this asymmetry is intended.
    """

    @staticmethod
    def forward(ctx, u, v):
        ctx.save_for_backward(u, v)
        output = triangular_toeplitz_multiply_(u, v)
        return output

    @staticmethod
    def backward(ctx, grad):
        u, v = ctx.saved_tensors
        # Reverse the upstream gradient, convolve with the other operand,
        # and reverse the result back.
        d_u = triangular_toeplitz_multiply_padded_(grad.flip(-1), v).flip(-1)
        d_v = triangular_toeplitz_multiply_padded_(grad.flip(-1), u).flip(-1)
        return d_u, d_v
|
||||
|
||||
|
||||
class TriangularToeplitzMultPaddedFast(torch.autograd.Function):
    """Trade off speed (20-25% faster) for more memory (20-25%)

    Operates on inputs whose last axis (length n) is already zero-padded in
    its upper half. The operands' FFTs are saved for the backward pass, and
    the upper half of the output and of both gradients is zeroed.
    """

    @staticmethod
    def forward(ctx, u, v):
        n = u.shape[-1]
        u_f = torch.fft.rfft(u, n=n, dim=-1)
        v_f = torch.fft.rfft(v, n=n, dim=-1)

        ctx.save_for_backward(u_f, v_f)

        uv_f = u_f * v_f
        output = torch.fft.irfft(uv_f, n=n, dim=-1)
        output[..., n // 2 :].zero_()
        return output

    @staticmethod
    def backward(ctx, grad):
        u_f, v_f = ctx.saved_tensors
        n = grad.shape[-1]
        # Only the lower half of the upstream gradient is used; it is
        # reversed and re-padded to length n.
        g_expand = F.pad(grad[..., : n // 2].flip(-1), (0, n // 2))
        g_f = torch.fft.rfft(g_expand, n=n, dim=-1)
        gu_f = g_f * u_f
        gv_f = g_f * v_f
        d_u = torch.fft.irfft(gv_f, n=n, dim=-1)
        d_v = torch.fft.irfft(gu_f, n=n, dim=-1)
        d_u[..., n // 2 :].zero_()
        d_v[..., n // 2 :].zero_()
        # Reverse the lower half back into place.
        d_u[..., : n // 2] = d_u[..., : n // 2].flip(-1)  # TODO
        d_v[..., : n // 2] = d_v[..., : n // 2].flip(-1)  # TODO
        return d_u, d_v
|
||||
|
||||
|
||||
# Public entry points: each name is the ``apply`` of the matching autograd
# Function defined above.
# triangular_toeplitz_multiply = triangular_toeplitz_multiply_
triangular_toeplitz_multiply = TriangularToeplitzMult.apply
triangular_toeplitz_multiply_fast = TriangularToeplitzMultFast.apply
triangular_toeplitz_multiply_padded = TriangularToeplitzMultPadded.apply
triangular_toeplitz_multiply_padded_fast = TriangularToeplitzMultPaddedFast.apply
|
||||
|
||||
|
||||
def causal_convolution(u, v, fast=True, pad=False):
    """Dispatch to one of the four triangular-Toeplitz multiply variants.

    fast: use the variant that saves FFTs for the backward pass.
    pad: use the variant for inputs that are already zero-padded.
    """
    if pad:
        if fast:
            return triangular_toeplitz_multiply_padded_fast(u, v)
        return triangular_toeplitz_multiply_padded(u, v)
    if fast:
        return triangular_toeplitz_multiply_fast(u, v)
    return triangular_toeplitz_multiply(u, v)
|
||||
|
||||
|
||||
def _fft(x, N):
|
||||
return torch.fft.rfft(F.pad(x, (0, 2 * N - x.shape[-1])), n=2 * N, dim=-1)
|
||||
|
||||
|
||||
def _ifft(x, N):
|
||||
return torch.fft.irfft(x, n=2 * N, dim=-1)[..., :N]
|
||||
|
||||
|
||||
def causal_convolution_inverse(u):
    """Invert the causal convolution/polynomial/triangular Toeplitz matrix represented by u.

    This is easiest in the polynomial view:
    https://www.csa.iisc.ac.in/~chandan/courses/CNT/notes/lec5.pdf
    The idea is that
    h = g^{-1} (mod x^m) => 2h - gh^2 = g^{-1} (mod x^{2m})

    The estimate starts from the reciprocal of the first coefficient and its
    length is doubled on every pass until it covers N entries.

    # TODO this can be numerically unstable if input is "poorly conditioned",
    # for example if u[0] is magnitudes different from the rest of u
    """
    N = u.shape[-1]
    v = u[..., :1].reciprocal()
    while v.shape[-1] < N:
        M = v.shape[-1]
        v_f = _fft(v, 2 * M)
        u_f = _fft(u[..., : 2 * M], 2 * M)
        # -g h^2, truncated to 2M coefficients ...
        _v = -_ifft(u_f * v_f**2, 2 * M)
        # ... plus 2h, which only occupies the first M coefficients.
        _v[..., :M] = _v[..., :M] + 2 * v
        v = _v
    # TODO contiguous?
    v = v[..., :N]
    return v
|
||||
|
||||
|
||||
""" Below are experimental functions for improving the stability of LSSL/S3 algorithm. Currently not used anywhere. """
|
||||
|
||||
|
||||
def causal_convolution_inverse_wrong(u, v):
    """Solve u * x = v. Initial attempt by inverting the multiplication algorithm, which I think doesn't work.

    Both inputs are zero-padded to length 2n, divided in the frequency
    domain, and the first n samples of the inverse transform are returned.
    """
    n = u.shape[-1]
    u_expand = F.pad(u, (0, n))
    v_expand = F.pad(v, (0, n))
    u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1)
    v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1)
    uv_f = v_f / u_f
    x = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n]
    return x
|
||||
|
||||
|
||||
def construct_toeplitz_log(v):
    """Build a lower-triangular Toeplitz matrix from ``v`` for log-space use.

    Entry (i, j) is ``v[i - j]`` on and below the diagonal; entries above
    the diagonal are set to -100.0.
    """
    n = v.shape[-1]
    a = torch.arange(n, device=v.device)
    b = -a
    # indices[i, j] = i - j; negative entries mark the upper triangle.
    indices = a[:, None] + b[None]
    K = v[..., indices]
    K[..., indices < 0] = -100.0
    return K
|
||||
|
||||
|
||||
def _logsumexp(x, dim=-1):
|
||||
"""logsumexp for complex"""
|
||||
m = torch.max(torch.real(x), dim=dim, keepdim=True)[0]
|
||||
x = x - m
|
||||
x = torch.log(torch.sum(torch.exp(x), dim=dim))
|
||||
x = x + m.squeeze(dim)
|
||||
return x
|
||||
|
||||
|
||||
def causal_convolution_inverse_log(u, N=-1):
    """Invert the causal convolution/polynomial/triangular Toeplitz matrix represented by u.

    Experimental log-space variant of causal_convolution_inverse: products
    become sums of Toeplitz matrices built by construct_toeplitz_log and
    sums become ``_logsumexp``. A self-check is printed before returning.

    N: number of output coefficients; a negative value means ``u.shape[-1]``.

    This is easiest in the polynomial view:
    https://www.csa.iisc.ac.in/~chandan/courses/CNT/notes/lec5.pdf
    The idea is that
    h = g^{-1} (mod x^m) => 2h - gh^2 = g^{-1} (mod x^{2m})

    # TODO this can be numerically unstable if input is "poorly conditioned",
    # for example if u[0] is magnitudes different from the rest of u
    """
    if N < 0:
        N = u.shape[-1]
    v = -u[..., :1]
    while v.shape[-1] < N:
        M = v.shape[-1]
        _v = F.pad(v, (0, M), value=-100.0)
        _v_ = construct_toeplitz_log(_v)
        u_ = (
            u[..., : 2 * M]
            if u.shape[-1] >= 2 * M
            else F.pad(u, (0, 2 * M - u.shape[-1]), value=-100.0)
        )
        _u = _logsumexp(_v_ + u_, dim=-1)
        _u = _logsumexp(_v_ + _u, dim=-1)
        # NOTE(review): log of a negative number - presumably the inputs are
        # expected to be complex here; confirm against callers.
        _u = _u + torch.log(-torch.ones_like(_u))
        _v = _v + torch.log(2.0 * torch.ones_like(_u))
        v = _logsumexp(torch.stack([_v, _u], dim=-1), dim=-1)
    # TODO contiguous?
    v = v[..., :N]

    check = _logsumexp(
        construct_toeplitz_log(v) + F.pad(u, (0, N - u.shape[-1]), value=-100.0)
    )
    print("check", check, torch.exp(check))
    return v
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: padded multiply and its gradients on a small example.
    a = torch.tensor([1.0, 2, 3, 4], requires_grad=True)
    b = torch.tensor([5.0, 6, 7, 8], requires_grad=True)
    a.retain_grad()
    b.retain_grad()
    x = triangular_toeplitz_multiply_padded(F.pad(a, (0, 4)), F.pad(b, (0, 4)))[:4]
    print(x)  # [5 16 34 60]
    x = x.sum()
    x.backward()
    print(x, a.grad, b.grad)  # [26 18 11 5] [10 6 3 1]

if __name__ == "__main__":
    # Smoke test: inverse of a random causal convolution.
    N = 4
    a = torch.randn(N)
    construct_toeplitz(a)
    print(a)
    b = causal_convolution_inverse(a)
    print("inverse", b)
    print("check", causal_convolution(a, b))
    i = torch.zeros(N)
    i[0] = 1.0
    b = causal_convolution_inverse_wrong(a, i)
    print(b)
    print(causal_convolution(a, b))
|
||||
Reference in New Issue
Block a user