diff --git a/seq2seq_time/models/tcn.py b/seq2seq_time/models/tcn.py index baceed3..4ab5a45 100644 --- a/seq2seq_time/models/tcn.py +++ b/seq2seq_time/models/tcn.py @@ -159,11 +159,11 @@ class TCNSeq2Seq(nn.Module): nlayers=6, kernel_size=2, dropout=0.2, - embedding_dim=2, ): super().__init__() self.tcn = TemporalConvNet( - num_inputs=x_dim+y_dim, + num_inputs=x_dim + y_dim, + kernel_size=kernel_size, num_channels=[hidden_size] * nlayers, dropout=dropout) self._min_std = 0.01