mirror of
https://github.com/wassname/DeepTime.git
synced 2026-06-27 21:20:13 +08:00
37 lines
1.4 KiB
Python
37 lines
1.4 KiB
Python
from typing import List, Optional
|
|
|
|
import math
|
|
import torch
|
|
from torch import nn
|
|
from torch import Tensor
|
|
|
|
|
|
class GaussianFourierFeatureTransform(nn.Module):
    """
    Random Gaussian Fourier feature projection.

    Adapted from https://github.com/ndahlquist/pytorch-fourier-feature-networks

    At construction time a fixed projection matrix ``B`` of size
    ``[input_dim, n_fourier_feats // 2]`` is sampled. It consists of
    ``len(scales)`` column groups of equal width (in ``scales`` order); the
    entries of the group for ``scale`` are ``torch.randn(...) * scale``, i.e.
    drawn from N(0, scale**2). ``B`` is registered as a buffer rather than a
    parameter, so it is not trained.

    Given an input of size ``[..., time, input_dim]``, ``forward`` returns a
    tensor of size ``[..., time, n_fourier_feats]`` whose last dimension holds
    ``sin(2*pi*x@B)`` followed by ``cos(2*pi*x@B)``.

    Args:
        input_dim: size of the last dimension of the input.
        n_fourier_feats: size of the last dimension of the output; must be
            divisible by ``2 * len(scales)``.
        scales: scale factors (standard deviations) for the column groups of
            ``B``; floats are accepted, not only integers.

    Raises:
        ValueError: if ``scales`` is empty.
        AssertionError: if ``n_fourier_feats`` is not divisible by
            ``2 * len(scales)``.
    """

    def __init__(self, input_dim: int, n_fourier_feats: int, scales: List[float]):
        super().__init__()
        self.input_dim = input_dim
        self.n_fourier_feats = n_fourier_feats
        self.scales = scales

        if len(scales) == 0:
            # Without this guard the division below fails with an opaque
            # ZeroDivisionError.
            raise ValueError("scales must contain at least one value")

        # Each scale contributes n_scale_feats projection columns, and every
        # column yields one sin and one cos feature -- hence the factor of 2.
        n_scale_feats = n_fourier_feats // (2 * len(scales))
        assert n_scale_feats * 2 * len(scales) == n_fourier_feats, \
            f"n_fourier_feats: {n_fourier_feats} must be divisible by 2 * len(scales) = {2 * len(scales)}"

        B_size = (input_dim, n_scale_feats)
        B = torch.cat([torch.randn(B_size) * scale for scale in scales], dim=1)
        self.register_buffer('B', B)

    def forward(self, x: Tensor) -> Tensor:
        """
        Project ``x`` onto the random Fourier basis.

        Args:
            x: tensor of size ``[..., time, input_dim]`` with at least 2 dims.

        Returns:
            Tensor of size ``[..., time, n_fourier_feats]``: the sines of the
            scaled projections followed by their cosines along the last dim.
        """
        assert x.dim() >= 2, f"Expected 2 or more dimensional input (got {x.dim()}D input)"
        dim = x.shape[-1]
        assert dim == self.input_dim, \
            f"Expected input to have {self.input_dim} channels (got {dim} channels)"

        # [..., time, input_dim] @ [input_dim, n_fourier_feats // 2]
        #   -> [..., time, n_fourier_feats // 2]; matmul broadcasts leading dims.
        x = 2 * math.pi * torch.matmul(x, self.B)
        return torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
|