# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

from typing import Optional

import gin
import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange, repeat, reduce

from models.modules.inr import INR
from models.modules.regressors import RidgeRegressor


@gin.configurable()
def deeptime(datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float):
    """Gin-configurable factory that builds a :class:`DeepTIMe` model.

    All arguments are forwarded positionally to ``DeepTIMe.__init__``.
    """
    return DeepTIMe(datetime_feats, layer_size, inr_layers, n_fourier_feats, scales)


class DeepTIMe(nn.Module):
    """Time-index model: an INR over time coordinates plus a ridge-regression head.

    ``forward`` embeds normalised time coordinates (optionally concatenated
    with datetime features) with an INR. It fits ``RidgeRegressor`` on the
    lookback-window embeddings against the observed values. It then applies
    the resulting ``(w, b)`` to the horizon-window embeddings.
    """

    def __init__(self, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float):
        """
        Args:
            datetime_feats: number of datetime features appended to the time
                coordinate (presumably ``x_time.shape[-1]`` / ``y_time.shape[-1]``
                -- confirm against callers).
            layer_size: hidden width passed to ``INR``.
            inr_layers: number of layers passed to ``INR``.
            n_fourier_feats: passed to ``INR`` as ``n_fourier_feats``.
            scales: passed to ``INR`` as ``scales``.
        """
        super().__init__()
        # +1 input feature for the scalar time coordinate produced by get_coords;
        # forward() concatenates the datetime features after it.
        self.inr = INR(in_feats=datetime_feats + 1, layers=inr_layers, layer_size=layer_size,
                       n_fourier_feats=n_fourier_feats, scales=scales)
        self.adaptive_weights = RidgeRegressor()

        # Hyperparameters kept for introspection.
        self.datetime_feats = datetime_feats
        self.inr_layers = inr_layers
        self.layer_size = layer_size
        self.n_fourier_feats = n_fourier_feats
        self.scales = scales

    def forward(self, x: Tensor, x_time: Tensor, y_time: Tensor) -> Tensor:
        """Forecast the horizon that follows the lookback window ``x``.

        Args:
            x: observed lookback values; must be rank 3 and is unpacked as
                ``(batch, lookback_len, _)``. The last axis is presumably the
                number of target series -- confirm against callers.
            x_time: datetime features for the lookback window. It is
                concatenated with ``y_time`` along dim 1, so (when ``y_time``
                has features) its dim 1 must equal ``lookback_len``.
            y_time: datetime features for the horizon. ``y_time.shape[1]``
                sets the horizon length. If ``y_time.shape[-1] == 0``, only
                the normalised time coordinate is fed to the INR.

        Returns:
            The output of :meth:`forecast` applied to the horizon embeddings.
            This is one prediction per horizon step (an empty time axis when
            the horizon length is 0).
        """
        tgt_horizon_len = y_time.shape[1]
        batch_size, lookback_len, _ = x.shape
        coords = self.get_coords(lookback_len, tgt_horizon_len).to(x.device)

        if y_time.shape[-1] != 0:
            # Per-sample datetime features: tile the shared time coordinate over
            # the batch and append the features along the last axis.
            time = torch.cat([x_time, y_time], dim=1)
            coords = repeat(coords, '1 t 1 -> b t 1', b=time.shape[0])
            coords = torch.cat([coords, time], dim=-1)
            time_reprs = self.inr(coords)
        else:
            # Coordinates have batch dim 1 here, so the INR runs once and its
            # output is broadcast across the batch.
            time_reprs = repeat(self.inr(coords), '1 t d -> b t d', b=batch_size)

        # Split at lookback_len instead of at -tgt_horizon_len. With a
        # zero-length horizon, `[:, :-0]` is empty and `[:, -0:]` is the whole
        # sequence, which would swap the two windows.
        lookback_reprs = time_reprs[:, :lookback_len]
        horizon_reprs = time_reprs[:, lookback_len:]
        # Fit the closed-form head on the lookback embeddings vs. observed values.
        w, b = self.adaptive_weights(lookback_reprs, x)
        preds = self.forecast(horizon_reprs, w, b)
        return preds

    def forecast(self, inp: Tensor, w: Tensor, b: Tensor) -> Tensor:
        """Apply the linear map ``w`` and bias ``b`` to ``inp``.

        The einsum contracts the shared ``d`` axis: ``inp`` ends in ``(t, d)``,
        ``w`` ends in ``(d, o)``, and the result ends in ``(t, o)`` with any
        leading dims matched. ``b`` is added with ordinary broadcasting (its
        exact shape comes from ``RidgeRegressor`` -- not visible here).
        """
        return torch.einsum('... d o, ... t d -> ... t o', [w, inp]) + b

    def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor:
        """Return ``lookback_len + horizon_len`` evenly spaced points in ``[0, 1]``.

        The result has shape ``(1, lookback_len + horizon_len, 1)`` and is created
        on the default device; callers move it with ``.to(...)``.
        """
        coords = torch.linspace(0, 1, lookback_len + horizon_len)
        return rearrange(coords, 't -> 1 t 1')