plotting in scratch.py and trying TCN

This commit is contained in:
wassname
2022-11-19 20:16:20 +08:00
parent 0a315b3d48
commit 90ee9c3298
11 changed files with 550 additions and 6 deletions
+60
View File
@@ -0,0 +1,60 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
from typing import Optional
import gin
import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange, repeat, reduce
from models.modules.inrplus2 import INRPlus2
from models.modules.regressors import RidgeRegressor
@gin.configurable()
def deeptime2(datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float):
return DeepTIMe2(datetime_feats, layer_size, inr_layers, n_fourier_feats, scales)
class DeepTIMe2(nn.Module):
def __init__(self, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float):
super().__init__()
self.inr = INRPlus2(in_feats=datetime_feats + 1, layers=inr_layers, layer_size=layer_size,
n_fourier_feats=n_fourier_feats, scales=scales)
self.adaptive_weights = RidgeRegressor()
self.datetime_feats = datetime_feats
self.inr_layers = inr_layers
self.layer_size = layer_size
self.n_fourier_feats = n_fourier_feats
self.scales = scales
def forward(self, x: Tensor, x_time: Tensor, y_time: Tensor) -> Tensor:
tgt_horizon_len = y_time.shape[1]
batch_size, lookback_len, _ = x.shape
coords = self.get_coords(lookback_len, tgt_horizon_len).to(x.device)
if y_time.shape[-1] != 0:
time = torch.cat([x_time, y_time], dim=1)
coords = repeat(coords, '1 t 1 -> b t 1', b=time.shape[0])
coords = torch.cat([coords, time], dim=-1)
time_reprs = self.inr(coords)
else:
time_reprs = repeat(self.inr(coords), '1 t d -> b t d', b=batch_size)
lookback_reprs = time_reprs[:, :-tgt_horizon_len]
horizon_reprs = time_reprs[:, -tgt_horizon_len:]
w, b = self.adaptive_weights(lookback_reprs, x)
preds = self.forecast(horizon_reprs, w, b)
return preds
def forecast(self, inp: Tensor, w: Tensor, b: Tensor) -> Tensor:
return torch.einsum('... d o, ... t d -> ... t o', [w, inp]) + b
def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor:
coords = torch.linspace(0, 1, lookback_len + horizon_len)
return rearrange(coords, 't -> 1 t 1')