# Copyright (c) 2022, salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause from typing import Optional import torch import torch.nn as nn from torch import Tensor from models.modules.feature_transforms import GaussianFourierFeatureTransform class INRLayer(nn.Module): def __init__(self, input_size: int, output_size: int, dropout: Optional[float] = 0.1): super().__init__() self.input_size = input_size self.output_size = output_size self.linear = nn.Linear(input_size, output_size) self.dropout = nn.Dropout(dropout) self.norm = nn.LayerNorm(output_size) def forward(self, x: Tensor) -> Tensor: out = self._layer(x) return self.norm(out) def _layer(self, x: Tensor) -> Tensor: return self.dropout(torch.relu(self.linear(x))) class INR(nn.Module): def __init__(self, in_feats: int, layers: int, layer_size: int, n_fourier_feats: int, scales: float, dropout: Optional[float] = 0.1): super().__init__() self.features = nn.Linear(in_feats, layer_size) if n_fourier_feats == 0 \ else GaussianFourierFeatureTransform(in_feats, n_fourier_feats, scales) in_size = layer_size if n_fourier_feats == 0 \ else n_fourier_feats layers = [INRLayer(in_size, layer_size, dropout=dropout)] + \ [INRLayer(layer_size, layer_size, dropout=dropout) for _ in range(layers - 1)] self.layers = nn.Sequential(*layers) def forward(self, x: Tensor) -> Tensor: x = self.features(x) return self.layers(x)