mirror of
https://github.com/wassname/DeepTime.git
synced 2026-06-27 16:28:35 +08:00
61 lines
2.4 KiB
Python
61 lines
2.4 KiB
Python
# Copyright (c) 2022, salesforce.com, inc.
|
|
# All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
|
|
from typing import Optional
|
|
|
|
import gin
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from einops import rearrange, repeat, reduce
|
|
|
|
from models.modules.inrplus2 import INRPlus2
|
|
from models.modules.regressors import RidgeRegressor
|
|
|
|
|
|
@gin.configurable()
def deeptime2(datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float):
    """Build a ``DeepTIMe2`` model from the given hyperparameters.

    The function is registered with gin through ``@gin.configurable()`` and
    simply forwards every argument, unchanged, to the ``DeepTIMe2``
    constructor.

    Args:
        datetime_feats: Passed through as ``DeepTIMe2(datetime_feats=...)``.
        layer_size: Passed through as ``DeepTIMe2(layer_size=...)``.
        inr_layers: Passed through as ``DeepTIMe2(inr_layers=...)``.
        n_fourier_feats: Passed through as ``DeepTIMe2(n_fourier_feats=...)``.
        scales: Passed through as ``DeepTIMe2(scales=...)``.

    Returns:
        The newly constructed ``DeepTIMe2`` instance.
    """
    model = DeepTIMe2(
        datetime_feats=datetime_feats,
        layer_size=layer_size,
        inr_layers=inr_layers,
        n_fourier_feats=n_fourier_feats,
        scales=scales,
    )
    return model
|
|
|
|
|
|
class DeepTIMe2(nn.Module):
    """Forecaster combining a time-index network with a per-window linear read-out.

    ``forward`` encodes every step of the lookback + horizon window with
    ``INRPlus2``, asks ``RidgeRegressor`` for a weight/bias pair fitted on the
    lookback part, and applies that pair to the horizon part.
    """

    def __init__(self, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float):
        """Create the sub-modules and keep the hyperparameters as attributes.

        Args:
            datetime_feats: Number of datetime features per time step. The INR
                is built with ``datetime_feats + 1`` inputs because ``forward``
                prepends one time coordinate to those features.
            layer_size: Forwarded to ``INRPlus2`` as ``layer_size``.
            inr_layers: Forwarded to ``INRPlus2`` as ``layers``.
            n_fourier_feats: Forwarded to ``INRPlus2`` as ``n_fourier_feats``.
            scales: Forwarded to ``INRPlus2`` as ``scales``.
        """
        super().__init__()
        self.inr = INRPlus2(in_feats=datetime_feats + 1, layers=inr_layers, layer_size=layer_size,
                            n_fourier_feats=n_fourier_feats, scales=scales)
        self.adaptive_weights = RidgeRegressor()

        self.datetime_feats = datetime_feats
        self.inr_layers = inr_layers
        self.layer_size = layer_size
        self.n_fourier_feats = n_fourier_feats
        self.scales = scales

    def forward(self, x: Tensor, x_time: Tensor, y_time: Tensor) -> Tensor:
        """Predict the horizon values for a batch of lookback windows.

        Args:
            x: 3-D lookback window; its shape is unpacked as
                ``(batch, lookback_len, _)``.
            x_time: Datetime features of the lookback steps. It is
                concatenated with ``y_time`` along dim 1, so it presumably has
                ``lookback_len`` steps -- confirm against the data loader.
            y_time: Datetime features of the horizon steps. Dim 1 gives the
                horizon length; a last dim of size 0 means "no datetime
                features", in which case only the time coordinate feeds the INR.

        Returns:
            The output of ``forecast`` applied to the horizon representations.
        """
        horizon_len = y_time.shape[1]
        batch_size, lookback_len, _ = x.shape
        coords = self.get_coords(lookback_len, horizon_len).to(x.device)

        if y_time.shape[-1] != 0:
            # Datetime features differ per sample, so broadcast the shared
            # coordinate over the batch and encode each sample separately.
            time = torch.cat([x_time, y_time], dim=1)
            coords = coords.expand(time.shape[0], -1, -1)
            coords = torch.cat([coords, time], dim=-1)
            time_reprs = self.inr(coords)
        else:
            # Without datetime features the encoding is identical for every
            # sample: run the INR once and broadcast the result.
            time_reprs = self.inr(coords).expand(batch_size, -1, -1)

        # Split on lookback_len rather than with negative horizon indices:
        # for a zero-length horizon ``[:, :-0]`` would be empty and
        # ``[:, -0:]`` would be the whole window.
        lookback_reprs = time_reprs[:, :lookback_len]
        horizon_reprs = time_reprs[:, lookback_len:]

        w, b = self.adaptive_weights(lookback_reprs, x)
        return self.forecast(horizon_reprs, w, b)

    def forecast(self, inp: Tensor, w: Tensor, b: Tensor) -> Tensor:
        """Apply the linear read-out ``inp @ w + b`` over the trailing two dims.

        Args:
            inp: Representations with trailing dims ``(t, d)``.
            w: Weights with trailing dims ``(d, o)``; leading dims are matched
                with those of ``inp`` by the einsum ellipsis.
            b: Bias added to the product (standard broadcasting applies).

        Returns:
            Tensor with trailing dims ``(t, o)``.
        """
        return torch.einsum('... d o, ... t d -> ... t o', w, inp) + b

    def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor:
        """Return evenly spaced time coordinates in ``[0, 1]`` for the full window.

        Args:
            lookback_len: Number of lookback steps.
            horizon_len: Number of horizon steps.

        Returns:
            Tensor of shape ``(1, lookback_len + horizon_len, 1)``.
        """
        coords = torch.linspace(0, 1, lookback_len + horizon_len)
        # Add singleton batch and feature axes: (t,) -> (1, t, 1).
        return coords[None, :, None]
|