Files
2020-11-01 15:49:57 +08:00

29 lines
1022 B
Python

import torch
from torch import nn
from torch.nn import functional as F
class BaselineLast(nn.Module):
def __init__(self):
super().__init__()
self.std = nn.Parameter(torch.tensor(1.))
def forward(self, past_x, past_y, future_x, future_y=None):
device = next(self.parameters()).device
B, S, F = future_x.shape
mean = past_y[:, -1:].repeat(1, S, 1)
std = (self.std * 1.0).repeat(1, S, 1)
return torch.distributions.Normal(mean, std), {}
class BaselineMean(nn.Module):
"""Simple model that predicts mean with learnable constant uncertainty."""
def __init__(self):
super().__init__()
self.std = nn.Parameter(torch.tensor(1.))
def forward(self, past_x, past_y, future_x, future_y=None):
device = next(self.parameters()).device
B, S, F = future_x.shape
mean = past_y.mean(1, keepdim=True).repeat(1, S, 1)
std = (self.std * 1.0).repeat(1, S, 1)
return torch.distributions.Normal(mean, std), {}