diff --git a/seq2seq_time/models/baseline.py b/seq2seq_time/models/baseline.py index bbff483..42c399c 100644 --- a/seq2seq_time/models/baseline.py +++ b/seq2seq_time/models/baseline.py @@ -13,3 +13,16 @@ class BaselineLast(nn.Module): mean = past_y[:, -1:].repeat(1, S, 1) std = (self.std * 1.0).repeat(1, S, 1) return torch.distributions.Normal(mean, std), {} + +class BaselineMean(nn.Module): + """Simple model that predicts mean with learnable constant uncertainty.""" + def __init__(self): + super().__init__() + self.std = nn.Parameter(torch.tensor(1.)) + + def forward(self, past_x, past_y, future_x, future_y=None): + device = next(self.parameters()).device + B, S, F = future_x.shape + mean = past_y.mean(1, keepdim=True).repeat(1, S, 1) + std = (self.std * 1.0).repeat(1, S, 1) + return torch.distributions.Normal(mean, std), {} diff --git a/seq2seq_time/models/transformer_seq2seq.py b/seq2seq_time/models/transformer_seq2seq.py index d348d5c..76bd7c1 100644 --- a/seq2seq_time/models/transformer_seq2seq.py +++ b/seq2seq_time/models/transformer_seq2seq.py @@ -64,9 +64,8 @@ class TransformerSeq2Seq(nn.Module): # In transformers the memory and future_x need to be the same length. Lets use a permutation invariant agg on the context # Then expand it, so it's available as we decode, conditional on future_x # (C, B, emb_dim) -> (B, emb_dim) -> (T, B, emb_dim) - # In transformers the memory and future_x need to be the same length. Lets use a permutation invariant agg on the context - # Then expand it, so it's available as we decode, conditional on future_x - memory = memory.max(dim=0, keepdim=True)[0].expand_as(future_x) + S, B, H = future_x.shape + memory = memory.max(dim=0, keepdim=True)[0].repeat(1, S, 1) outputs = self.decoder(future_x, memory, tgt_key_padding_mask=tgt_key_padding_mask) # [T, B, emb_dim] -> [B, T, emb_dim]