from typing import Optional import gin import torch import torch.nn as nn from torch import Tensor from einops import rearrange, repeat, reduce from models.modules.inr import INR from models.modules.regressors import RidgeRegressor @gin.configurable() def deeptime(datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float): return DeepTIMe(datetime_feats, layer_size, inr_layers, n_fourier_feats, scales) class DeepTIMe(nn.Module): def __init__(self, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float): super().__init__() self.inr = INR(in_feats=datetime_feats + 1, layers=inr_layers, layer_size=layer_size, n_fourier_feats=n_fourier_feats, scales=scales) self.adaptive_weights = RidgeRegressor() self.datetime_feats = datetime_feats self.inr_layers = inr_layers self.layer_size = layer_size self.n_fourier_feats = n_fourier_feats self.scales = scales def forward(self, x: Tensor, x_time: Tensor, y_time: Tensor) -> Tensor: tgt_horizon_len = y_time.shape[1] batch_size, lookback_len, _ = x.shape coords = self.get_coords(lookback_len, tgt_horizon_len).to(x.device) if y_time.shape[-1] != 0: time = torch.cat([x_time, y_time], dim=1) coords = repeat(coords, '1 t 1 -> b t 1', b=time.shape[0]) coords = torch.cat([coords, time], dim=-1) time_reprs = self.inr(coords) else: time_reprs = repeat(self.inr(coords), '1 t d -> b t d', b=batch_size) lookback_reprs = time_reprs[:, :-tgt_horizon_len] horizon_reprs = time_reprs[:, -tgt_horizon_len:] w, b = self.adaptive_weights(lookback_reprs, x) preds = self.forecast(horizon_reprs, w, b) return preds def forecast(self, inp: Tensor, w: Tensor, b: Tensor) -> Tensor: return torch.einsum('... d o, ... t d -> ... t o', [w, inp]) + b def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor: coords = torch.linspace(0, 1, lookback_len + horizon_len) return rearrange(coords, 't -> 1 t 1')