Mirror of https://github.com/wassname/seq2seq-time.git
Synced 2026-06-27 21:38:38 +08:00 · 29 lines · 1022 B · Python
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
class BaselineLast(nn.Module):
    """Persistence baseline: repeat the last observed target over the horizon.

    The predictive distribution is a Normal whose mean is ``past_y`` at its
    final time step, and whose standard deviation is a single learnable
    scalar shared by every batch element, future step and channel.
    """

    # Floor for the predicted std. ``torch.distributions.Normal`` rejects a
    # non-positive scale, and the raw parameter is unconstrained in training.
    MIN_STD = 1e-6

    def __init__(self):
        super().__init__()
        # Raw (unconstrained) std parameter, initialised to 1.
        self.std = nn.Parameter(torch.tensor(1.))

    def forward(self, past_x, past_y, future_x, future_y=None):
        """Predict ``Normal(last past_y, std)`` for every future step.

        Args:
            past_x: Unused; accepted so all models share one call signature.
            past_y: Observed targets, assumed ``(batch, time, channels)``.
                Must contain at least one time step.
            future_x: Future inputs; must be 3-D. Only ``future_x.shape[1]``
                (the number of future steps) is used.
            future_y: Unused.

        Returns:
            ``(dist, {})``, where ``dist`` is a ``torch.distributions.Normal``
            whose mean is ``past_y[:, -1:]`` repeated along dim 1 once per
            future step.

        Raises:
            ValueError: If ``future_x`` is not 3-D or ``past_y`` is empty
                along dim 1.
        """
        _, horizon, _ = future_x.shape
        if past_y.shape[1] == 0:
            raise ValueError("past_y must contain at least one time step")
        mean = past_y[:, -1:].repeat(1, horizon, 1)
        # abs + clamp keeps the scale valid if the optimiser pushes the raw
        # parameter to zero or below; for values >= MIN_STD it is a no-op.
        std = self.std.abs().clamp(min=self.MIN_STD).expand_as(mean)
        return torch.distributions.Normal(mean, std), {}
|
|
|
|
class BaselineMean(nn.Module):
    """Simple model that predicts mean with learnable constant uncertainty.

    The predictive distribution is a Normal whose mean is the average of
    ``past_y`` over its time axis (dim 1), repeated for every future step.
    Its standard deviation is a single learnable scalar shared by every
    batch element, future step and channel.
    """

    # Floor for the predicted std. ``torch.distributions.Normal`` rejects a
    # non-positive scale, and the raw parameter is unconstrained in training.
    MIN_STD = 1e-6

    def __init__(self):
        super().__init__()
        # Raw (unconstrained) std parameter, initialised to 1.
        self.std = nn.Parameter(torch.tensor(1.))

    def forward(self, past_x, past_y, future_x, future_y=None):
        """Predict ``Normal(mean of past_y over time, std)`` per future step.

        Args:
            past_x: Unused; accepted so all models share one call signature.
            past_y: Observed targets, assumed ``(batch, time, channels)``.
                Must contain at least one time step.
            future_x: Future inputs; must be 3-D. Only ``future_x.shape[1]``
                (the number of future steps) is used.
            future_y: Unused.

        Returns:
            ``(dist, {})``, where ``dist`` is a ``torch.distributions.Normal``
            whose mean is ``past_y.mean(1, keepdim=True)`` repeated along
            dim 1 once per future step.

        Raises:
            ValueError: If ``future_x`` is not 3-D or ``past_y`` is empty
                along dim 1.
        """
        _, horizon, _ = future_x.shape
        if past_y.shape[1] == 0:
            # Averaging an empty axis would yield NaN, which Normal rejects
            # with a less informative message.
            raise ValueError("past_y must contain at least one time step")
        mean = past_y.mean(1, keepdim=True).repeat(1, horizon, 1)
        # abs + clamp keeps the scale valid if the optimiser pushes the raw
        # parameter to zero or below; for values >= MIN_STD it is a no-op.
        std = self.std.abs().clamp(min=self.MIN_STD).expand_as(mean)
        return torch.distributions.Normal(mean, std), {}
|