Files
DeepTime/models/DeepTIMe2.py
T
2022-11-20 14:29:25 +08:00

64 lines
2.5 KiB
Python

# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
from typing import Optional
import gin
import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange, repeat, reduce
from models.modules.inrplus2 import INRPlus2
from models.modules.regressors import RidgeRegressor
@gin.configurable()
def deeptime2(datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float,
              dropout: float = 0.3):
    """Build a ``DeepTIMe2`` model; registered with gin so every argument can be bound from config.

    Args:
        datetime_feats: number of datetime covariate features fed to the model.
        layer_size: hidden width of the model's INR.
        inr_layers: number of INR layers.
        n_fourier_feats: number of Fourier features used by the INR.
        scales: Fourier-feature scale setting forwarded to the INR.
        dropout: dropout rate forwarded to the model. Exposed here so gin configs can
            override it; the default (0.3) matches ``DeepTIMe2.__init__``, so existing
            configs behave exactly as before.

    Returns:
        A freshly constructed ``DeepTIMe2`` module.
    """
    return DeepTIMe2(datetime_feats, layer_size, inr_layers, n_fourier_feats, scales, dropout=dropout)
class DeepTIMe2(nn.Module):
    """Time-index forecaster built from an INR plus a per-batch closed-form linear head.

    An implicit neural representation (``INRPlus2``) maps normalized time coordinates,
    optionally concatenated with datetime features, to representations. ``RidgeRegressor``
    fits a linear map ``(w, b)`` from the lookback-window representations to the observed
    values ``x``. That map is then applied to the horizon representations to produce the
    forecast.
    """

    def __init__(self, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int,
                 scales: float, dropout: float = 0.3):
        """
        Args:
            datetime_feats: number of datetime covariate features.
            layer_size: hidden width of the INR.
            inr_layers: number of INR layers.
            n_fourier_feats: number of Fourier features used by the INR.
            scales: Fourier-feature scale setting forwarded to the INR.
            dropout: dropout rate forwarded to the INR.
        """
        super().__init__()
        in_feats = datetime_feats
        # NOTE(review): the time coordinate is counted as an input feature only when
        # Fourier features are enabled, yet forward() always concatenates it.
        # Presumably INRPlus2 treats the n_fourier_feats == 0 case differently;
        # confirm against INRPlus2.
        if n_fourier_feats:
            in_feats += 1
        self.inr = INRPlus2(in_feats=in_feats, layers=inr_layers, layer_size=layer_size,
                            n_fourier_feats=n_fourier_feats, scales=scales, dropout=dropout)
        self.adaptive_weights = RidgeRegressor()
        self.datetime_feats = datetime_feats
        self.inr_layers = inr_layers
        self.layer_size = layer_size
        self.n_fourier_feats = n_fourier_feats
        self.scales = scales

    def forward(self, x: Tensor, x_time: Tensor, y_time: Tensor) -> Tensor:
        """Forecast the horizon following the lookback window ``x``.

        Args:
            x: lookback values, unpacked below as (batch, lookback_len, _).
            x_time: datetime features for the lookback steps.
            y_time: datetime features for the horizon steps. Its dim 1 is the horizon
                length, and a zero-sized last dim means no datetime features are used.

        Returns:
            The horizon predictions produced by ``forecast``.
        """
        tgt_horizon_len = y_time.shape[1]
        batch_size, lookback_len, _ = x.shape
        coords = self.get_coords(lookback_len, tgt_horizon_len).to(x.device)

        if y_time.shape[-1] != 0:
            # Datetime features differ per sample: expand the shared coordinate grid
            # to the batch and append the features along the last axis.
            time = torch.cat([x_time, y_time], dim=1)
            coords = repeat(coords, '1 t 1 -> b t 1', b=time.shape[0])
            coords = torch.cat([coords, time], dim=-1)
            time_reprs = self.inr(coords)
        else:
            # Coordinates are identical across the batch: run the INR once and broadcast.
            time_reprs = repeat(self.inr(coords), '1 t d -> b t d', b=batch_size)

        # Split at lookback_len instead of -tgt_horizon_len. With a zero-length horizon,
        # `:-0` would be empty and `-0:` would be the whole sequence. For any positive
        # horizon the two are identical, provided the INR preserves the time axis.
        lookback_reprs = time_reprs[:, :lookback_len]
        horizon_reprs = time_reprs[:, lookback_len:]

        w, b = self.adaptive_weights(lookback_reprs, x)
        return self.forecast(horizon_reprs, w, b)

    def forecast(self, inp: Tensor, w: Tensor, b: Tensor) -> Tensor:
        """Apply the fitted linear map: contract ``inp``'s last axis with ``w``'s ``d`` axis, then add ``b``.

        Equivalent to ``inp @ w + b`` over any shared leading dimensions.
        """
        return torch.einsum('... d o, ... t d -> ... t o', [w, inp]) + b

    def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor:
        """Return ``lookback_len + horizon_len`` evenly spaced points in [0, 1], shaped (1, t, 1)."""
        coords = torch.linspace(0, 1, lookback_len + horizon_len)
        return rearrange(coords, 't -> 1 t 1')