mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 16:31:46 +08:00
mean baseline
This commit is contained in:
@@ -13,3 +13,16 @@ class BaselineLast(nn.Module):
|
||||
mean = past_y[:, -1:].repeat(1, S, 1)
|
||||
std = (self.std * 1.0).repeat(1, S, 1)
|
||||
return torch.distributions.Normal(mean, std), {}
|
||||
|
||||
class BaselineMean(nn.Module):
|
||||
"""Simple model that predicts mean with learnable constant uncertainty."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.std = nn.Parameter(torch.tensor(1.))
|
||||
|
||||
def forward(self, past_x, past_y, future_x, future_y=None):
|
||||
device = next(self.parameters()).device
|
||||
B, S, F = future_x.shape
|
||||
mean = past_y.mean(1, keepdim=True).repeat(1, S, 1)
|
||||
std = (self.std * 1.0).repeat(1, S, 1)
|
||||
return torch.distributions.Normal(mean, std), {}
|
||||
|
||||
@@ -64,9 +64,8 @@ class TransformerSeq2Seq(nn.Module):
|
||||
# In transformers the memory and future_x need to be the same length. Lets use a permutation invariant agg on the context
|
||||
# Then expand it, so it's available as we decode, conditional on future_x
|
||||
# (C, B, emb_dim) -> (B, emb_dim) -> (T, B, emb_dim)
|
||||
# In transformers the memory and future_x need to be the same length. Lets use a permutation invariant agg on the context
|
||||
# Then expand it, so it's available as we decode, conditional on future_x
|
||||
memory = memory.max(dim=0, keepdim=True)[0].expand_as(future_x)
|
||||
S, B, H = future_x.shape
|
||||
memory = memory.max(dim=0, keepdim=True)[0].repeat(1, S, 1)
|
||||
outputs = self.decoder(future_x, memory, tgt_key_padding_mask=tgt_key_padding_mask)
|
||||
|
||||
# [T, B, emb_dim] -> [B, T, emb_dim]
|
||||
|
||||
Reference in New Issue
Block a user