mirror of
https://github.com/wassname/DeepTime.git
synced 2026-06-27 19:29:06 +08:00
43 lines
1.5 KiB
Python
43 lines
1.5 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
|
|
from models.modules.feature_transforms import GaussianFourierFeatureTransform
|
|
|
|
|
|
class INRLayer(nn.Module):
|
|
def __init__(self, input_size: int, output_size: int,
|
|
dropout: Optional[float] = 0.1):
|
|
super().__init__()
|
|
self.input_size = input_size
|
|
self.output_size = output_size
|
|
self.linear = nn.Linear(input_size, output_size)
|
|
self.dropout = nn.Dropout(dropout)
|
|
self.norm = nn.LayerNorm(output_size)
|
|
|
|
def forward(self, x: Tensor) -> Tensor:
|
|
out = self._layer(x)
|
|
return self.norm(out)
|
|
|
|
def _layer(self, x: Tensor) -> Tensor:
|
|
return self.dropout(torch.relu(self.linear(x)))
|
|
|
|
|
|
class INR(nn.Module):
|
|
def __init__(self, in_feats: int, layers: int, layer_size: int, n_fourier_feats: int, scales: float,
|
|
dropout: Optional[float] = 0.1):
|
|
super().__init__()
|
|
self.features = nn.Linear(in_feats, layer_size) if n_fourier_feats == 0 \
|
|
else GaussianFourierFeatureTransform(in_feats, n_fourier_feats, scales)
|
|
in_size = layer_size if n_fourier_feats == 0 \
|
|
else n_fourier_feats
|
|
layers = [INRLayer(in_size, layer_size, dropout=dropout)] + \
|
|
[INRLayer(layer_size, layer_size, dropout=dropout) for _ in range(layers - 1)]
|
|
self.layers = nn.Sequential(*layers)
|
|
|
|
def forward(self, x: Tensor) -> Tensor:
|
|
x = self.features(x)
|
|
return self.layers(x)
|