mean baseline

This commit is contained in:
wassname
2020-11-01 15:49:57 +08:00
parent 27d4cde5bd
commit 4aa8b3a52e
2 changed files with 15 additions and 3 deletions
+13
View File
@@ -13,3 +13,16 @@ class BaselineLast(nn.Module):
mean = past_y[:, -1:].repeat(1, S, 1)
std = (self.std * 1.0).repeat(1, S, 1)
return torch.distributions.Normal(mean, std), {}
class BaselineMean(nn.Module):
"""Simple model that predicts mean with learnable constant uncertainty."""
def __init__(self):
super().__init__()
self.std = nn.Parameter(torch.tensor(1.))
def forward(self, past_x, past_y, future_x, future_y=None):
device = next(self.parameters()).device
B, S, F = future_x.shape
mean = past_y.mean(1, keepdim=True).repeat(1, S, 1)
std = (self.std * 1.0).repeat(1, S, 1)
return torch.distributions.Normal(mean, std), {}
+2 -3
View File
@@ -64,9 +64,8 @@ class TransformerSeq2Seq(nn.Module):
# In transformers the memory and future_x need to be the same length. Lets use a permutation invariant agg on the context
# Then expand it, so it's available as we decode, conditional on future_x
# (C, B, emb_dim) -> (B, emb_dim) -> (T, B, emb_dim)
# In transformers the memory and future_x need to be the same length. Lets use a permutation invariant agg on the context
# Then expand it, so it's available as we decode, conditional on future_x
memory = memory.max(dim=0, keepdim=True)[0].expand_as(future_x)
S, B, H = future_x.shape
memory = memory.max(dim=0, keepdim=True)[0].repeat(1, S, 1)
outputs = self.decoder(future_x, memory, tgt_key_padding_mask=tgt_key_padding_mask)
# [T, B, emb_dim] -> [B, T, emb_dim]