Files
DeepTime/models/modules/inrplus2.py
T
2022-11-23 12:02:22 +08:00

46 lines
2.0 KiB
Python

# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from models.modules.feature_transforms import GaussianFourierFeatureTransform
from tsai.models.InceptionTimePlus import InceptionTimePlus
from .causalinception import CausalInceptionTimePlus, CausalConv1d, Conv
def custom_head(head_nf, c_out, seq_len):
    """Build the output head: an ``nn.Sequential`` holding a single ``Conv``.

    Args:
        head_nf: Forwarded to ``Conv`` as its first positional argument
            (presumably the number of channels entering the head -- confirm
            against ``causalinception.Conv``).
        c_out: Forwarded to ``Conv`` as its second positional argument
            (presumably the number of output channels).
        seq_len: Accepted so the function matches the ``custom_head`` call
            signature used by the model; it is not used here.

    Returns:
        ``nn.Sequential`` containing one ``Conv`` built with a third
        positional argument of ``1``, ``bias=False`` and ``norm="Spectral"``.
    """
    # The original author left a CausalConv1d variant with the same arguments
    # disabled in favour of the plain Conv used below.
    projection = Conv(head_nf, c_out, 1, bias=False, norm="Spectral")
    return nn.Sequential(projection)
class INRPlus2(nn.Module):
    """Feature transform followed by a ``CausalInceptionTimePlus`` network.

    The input is first passed through either a linear layer (when
    ``n_fourier_feats == 0``) or a ``GaussianFourierFeatureTransform``; in the
    latter case the raw input is concatenated onto the transformed features
    along the last axis before entering the convolutional stack.
    """

    def __init__(self, in_feats: int, out_feats: int, layers: int, layer_size: int,
                 n_fourier_feats: int, scales: float, dropout: Optional[float] = 0.5,
                 bn: bool = False, *args, **kwargs):
        """Create the feature transform and the convolutional stack.

        Args:
            in_feats: Size of the last axis of the input tensor.
            out_feats: Passed to ``CausalInceptionTimePlus`` as its second
                positional argument (the number of outputs it produces).
            layers: Passed to ``CausalInceptionTimePlus`` as ``depth``.
            layer_size: Passed to ``CausalInceptionTimePlus`` as ``nf``.
            n_fourier_feats: ``0`` selects a plain ``nn.Linear`` transform;
                any other value selects ``GaussianFourierFeatureTransform``.
            scales: Forwarded to ``GaussianFourierFeatureTransform``; ignored
                when ``n_fourier_feats == 0``.
            dropout: Used as ``fc_dropout``; one quarter of it is used as
                ``conv_dropout``. ``None`` disables both (treated as ``0.0``).
            bn: Forwarded to ``CausalInceptionTimePlus`` as ``bn``.
            *args: Extra positional arguments for ``CausalInceptionTimePlus``,
                placed after ``in_size`` and ``out_feats``.
            **kwargs: Extra keyword arguments for ``CausalInceptionTimePlus``.
        """
        super().__init__()
        self.n_fourier_feats = n_fourier_feats
        if n_fourier_feats == 0:
            self.features = nn.Linear(in_feats, in_feats)
            in_size = in_feats
        else:
            self.features = GaussianFourierFeatureTransform(in_feats, n_fourier_feats, scales)
            # forward() concatenates the raw input onto the transformed
            # features; assumes the transform emits n_fourier_feats values per
            # position -- TODO confirm against feature_transforms.
            in_size = n_fourier_feats + in_feats

        fc_dropout, conv_dropout = self._resolve_dropout(dropout)
        self.layers = CausalInceptionTimePlus(
            in_size, out_feats, *args,
            seq_len=None, nf=layer_size, depth=layers, flatten=False, concat_pool=False,
            fc_dropout=fc_dropout, conv_dropout=conv_dropout, bn=bn, y_range=None,
            custom_head=custom_head, ks=[139, 19, 3], dilation=2, **kwargs
        )

    @staticmethod
    def _resolve_dropout(dropout: Optional[float]):
        """Return ``(fc_dropout, conv_dropout)`` for a possibly-``None`` rate.

        ``dropout`` is annotated ``Optional``; dividing ``None`` by four would
        raise ``TypeError``, so ``None`` is mapped to ``0.0`` for both rates.
        """
        if dropout is None:
            return 0.0, 0.0
        return dropout, dropout / 4

    def forward(self, x: Tensor) -> Tensor:
        """Apply the feature transform and the convolutional stack to ``x``.

        Args:
            x: 3-D tensor whose last axis is the feature axis (presumably
                ``(batch, time, in_feats)`` -- confirm against callers).

        Returns:
            The output of ``self.layers`` with its last two axes swapped back,
            so the layout matches the input's axis order.
        """
        f = self.features(x)
        if self.n_fourier_feats > 0:
            # Keep the raw input alongside the Fourier features.
            f = torch.cat([f, x], -1)
        # self.layers is fed with axes 1 and 2 swapped; undo the swap afterwards.
        return self.layers(f.permute((0, 2, 1))).permute((0, 2, 1))