Files
DeepTime/models/DeepTIMe.py
wassname 00ae1d8b8f misc
2022-11-22 15:07:42 +08:00

66 lines
2.6 KiB
Python

# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
from typing import Optional
import gin
import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange, repeat, reduce
from models.modules.inr import INR
from models.modules.regressors import RidgeRegressor
@gin.configurable()
def deeptime(datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float):
    """Gin-configurable factory that builds a ``DeepTIMe`` model.

    Every argument is forwarded unchanged to ``DeepTIMe.__init__``.
    """
    model = DeepTIMe(
        datetime_feats=datetime_feats,
        layer_size=layer_size,
        inr_layers=inr_layers,
        n_fourier_feats=n_fourier_feats,
        scales=scales,
    )
    return model
class DeepTIMe(nn.Module):
    """Time-index forecaster: an INR over time coordinates plus a ridge-regression head.

    The forward pass evaluates ``self.inr`` on the coordinates of the lookback
    and horizon steps, fits ``self.adaptive_weights`` on the lookback
    representations against the observed values ``x``, and applies the fitted
    weights to the horizon representations.
    """

    def __init__(self, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float):
        """Build the INR and the regressor, and keep the hyper-parameters as attributes.

        Args:
            datetime_feats: number of datetime features concatenated to the
                relative time coordinate of each step.
            layer_size: forwarded to ``INR`` as ``layer_size``.
            inr_layers: forwarded to ``INR`` as ``layers``.
            n_fourier_feats: forwarded to ``INR`` as ``n_fourier_feats``.
            scales: forwarded to ``INR`` as ``scales``.
                NOTE(review): annotated as ``float`` but may actually be a
                sequence of scales -- confirm against ``INR``.
        """
        super().__init__()
        # +1 input feature: the relative time coordinate produced by get_coords().
        self.inr = INR(in_feats=datetime_feats + 1, layers=inr_layers, layer_size=layer_size,
                       n_fourier_feats=n_fourier_feats, scales=scales)
        self.adaptive_weights = RidgeRegressor()

        self.datetime_feats = datetime_feats
        self.inr_layers = inr_layers
        self.layer_size = layer_size
        self.n_fourier_feats = n_fourier_feats
        self.scales = scales

    def forward(self, x: Tensor, x_time: Tensor, y_time: Tensor) -> Tensor:
        """Forecast the horizon values from the lookback window.

        Args:
            x: observed lookback values, 3-D with ``(batch, lookback_len, ...)``.
            x_time: datetime features of the lookback steps,
                ``(batch, lookback_len, feats)``.
            y_time: datetime features of the horizon steps,
                ``(batch, horizon_len, feats)``; ``feats`` may be 0, in which
                case only the relative time coordinate is fed to the INR.

        Returns:
            The output of ``forecast`` for the horizon steps.
        """
        tgt_horizon_len = y_time.shape[1]
        batch_size, lookback_len, _ = x.shape
        # relative coordinates are the same for each batch, so we make them once and repeat them
        coords = self.get_coords(lookback_len, tgt_horizon_len).to(x.device)

        if y_time.shape[-1] != 0:  # datetime features are present
            time = torch.cat([x_time, y_time], dim=1)
            coords = repeat(coords, '1 t 1 -> b t 1', b=time.shape[0])
            coords = torch.cat([coords, time], dim=-1)
            time_reprs = self.inr(coords)
        else:
            # No datetime features: run the INR once on the shared coordinates
            # and broadcast the result over the batch.
            time_reprs = repeat(self.inr(coords), '1 t d -> b t d', b=batch_size)

        # Split at lookback_len rather than with negative horizon indices:
        # ``[:, :-0]`` would be empty and ``[:, -0:]`` would be the whole
        # sequence when the horizon length is 0.
        lookback_reprs = time_reprs[:, :lookback_len]
        horizon_reprs = time_reprs[:, lookback_len:]

        w, b = self.adaptive_weights(lookback_reprs, x)
        return self.forecast(horizon_reprs, w, b)

    def forecast(self, inp: Tensor, w: Tensor, b: Tensor) -> Tensor:
        """Apply the linear map ``inp @ w + b`` over the trailing two dimensions.

        Args:
            inp: representations shaped ``(..., t, d)``.
            w: weights shaped ``(..., d, o)``.
            b: bias, added to the ``(..., t, o)`` product with broadcasting.

        Returns:
            Tensor shaped ``(..., t, o)``.
        """
        return torch.einsum('... d o, ... t d -> ... t o', w, inp) + b

    def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor:
        """Return evenly spaced coordinates in [0, 1] shaped ``(1, lookback_len + horizon_len, 1)``."""
        coords = torch.linspace(0, 1, lookback_len + horizon_len)
        return rearrange(coords, 't -> 1 t 1')