# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

from typing import Optional

import gin
import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange, repeat, reduce

from models.modules.inrplus2 import INRPlus2
from models.modules.regressors import RidgeRegressor


@gin.configurable()
def deeptime2(datetime_feats: int, layer_size: int, inr_layers: int,
              n_fourier_feats: int, scales: float):
    """Gin-configurable factory that builds a :class:`DeepTIMe2` model.

    All arguments are forwarded positionally to ``DeepTIMe2.__init__``.
    """
    return DeepTIMe2(datetime_feats, layer_size, inr_layers, n_fourier_feats, scales)


class DeepTIMe2(nn.Module):
    """Forecaster built from an INR over time coordinates plus a ridge read-out.

    Every timestep of the lookback window and horizon gets a coordinate: a
    normalised time index in ``[0, 1]``, optionally extended with datetime
    features. The INR maps these coordinates to representations. The
    lookback representations are passed with ``x`` to a ``RidgeRegressor``
    to obtain weights ``w`` and bias ``b``. These are then applied to the
    horizon representations to produce the forecast.
    """

    def __init__(self, datetime_feats: int, layer_size: int, inr_layers: int,
                 n_fourier_feats: int, scales: float):
        """Build the INR and the ridge read-out.

        Args:
            datetime_feats: number of datetime features appended to the time
                index; the INR input width is ``datetime_feats + 1``.
            layer_size: hidden width forwarded to ``INRPlus2``.
            inr_layers: number of INR layers forwarded to ``INRPlus2``.
            n_fourier_feats: forwarded to ``INRPlus2``.
            scales: forwarded to ``INRPlus2``.
        """
        super().__init__()
        # +1 input feature for the normalised time index built by get_coords.
        self.inr = INRPlus2(in_feats=datetime_feats + 1, layers=inr_layers,
                            layer_size=layer_size, n_fourier_feats=n_fourier_feats,
                            scales=scales)
        self.adaptive_weights = RidgeRegressor()

        # Hyper-parameters kept for introspection.
        self.datetime_feats = datetime_feats
        self.inr_layers = inr_layers
        self.layer_size = layer_size
        self.n_fourier_feats = n_fourier_feats
        self.scales = scales

    def forward(self, x: Tensor, x_time: Tensor, y_time: Tensor) -> Tensor:
        """Forecast the horizon for a batch of lookback windows.

        Args:
            x: lookback values, unpacked as ``(batch, lookback_len, _)``.
                The last dim is presumably the number of series; TODO confirm
                against ``RidgeRegressor``.
            x_time: datetime features for the lookback window. It is
                concatenated with ``y_time`` on dim 1 and with the coordinates
                on the last dim, so presumably
                ``(batch, lookback_len, datetime_feats)``.
            y_time: datetime features for the horizon. ``y_time.shape[1]`` is
                taken as the horizon length. A last dim of size 0 means no
                datetime features are used.

        Returns:
            The result of :meth:`forecast` on the horizon representations.
        """
        tgt_horizon_len = y_time.shape[1]
        batch_size, lookback_len, _ = x.shape
        coords = self.get_coords(lookback_len, tgt_horizon_len).to(x.device)

        if y_time.shape[-1] != 0:
            # Append the datetime features to the time index, so each timestep
            # gets a (1 + datetime_feats)-wide input to the INR.
            time = torch.cat([x_time, y_time], dim=1)
            coords = repeat(coords, '1 t 1 -> b t 1', b=time.shape[0])
            coords = torch.cat([coords, time], dim=-1)
            time_reprs = self.inr(coords)
        else:
            # Without datetime features the INR input is the same for every
            # sample: evaluate it once and repeat the result over the batch.
            time_reprs = repeat(self.inr(coords), '1 t d -> b t d', b=batch_size)

        # Split at lookback_len instead of at -tgt_horizon_len. With a zero
        # horizon, ``-0 == 0`` would make the lookback slice empty and the
        # horizon slice the whole sequence.
        lookback_reprs = time_reprs[:, :lookback_len]
        horizon_reprs = time_reprs[:, lookback_len:]

        w, b = self.adaptive_weights(lookback_reprs, x)
        return self.forecast(horizon_reprs, w, b)

    def forecast(self, inp: Tensor, w: Tensor, b: Tensor) -> Tensor:
        """Apply the linear read-out ``inp @ w + b``.

        ``inp`` is ``(..., t, d)`` and ``w`` is ``(..., d, o)``, per the
        einsum below. The result is ``(..., t, o)`` with ``b`` added by
        broadcasting.
        """
        return torch.einsum('... d o, ... t d -> ... t o', [w, inp]) + b

    def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor:
        """Return ``lookback_len + horizon_len`` evenly spaced points in ``[0, 1]``.

        The points span lookback and horizon jointly and are shaped
        ``(1, lookback_len + horizon_len, 1)``.
        """
        coords = torch.linspace(0, 1, lookback_len + horizon_len)
        return rearrange(coords, 't -> 1 t 1')