diff --git a/seq2seq_time/data/dataset.py b/seq2seq_time/data/dataset.py index d344917..1550dca 100644 --- a/seq2seq_time/data/dataset.py +++ b/seq2seq_time/data/dataset.py @@ -41,13 +41,15 @@ class Seq2SeqDataSet(torch.utils.data.Dataset): self.columns_past = columns_past # For speed - self._icol_blank = [df.drop(columns = columns_target).columns.tolist().index(n) for n in columns_past] + self._icol_blank = [df.drop(columns = columns_target).columns.tolist().index(n) for n in columns_past] # the columns to blank for future x self._x = self.df.drop(columns = self.columns_target).values self._y = self.df[columns_target].values # Sometimes we want to have it shuffled, but the same each time np.random.seed(42) self._rand_index = np.random.permutation(len(self)) + + self.index = self.df.index[self.window_past:-1] def get_components(self, i): """Get past and future rows.""" @@ -70,7 +72,7 @@ class Seq2SeqDataSet(torch.utils.data.Dataset): y_future = y[self.window_past:] # Stop it cheating by using future weather measurements. Fill in with last value - x_future[:, self._icol_blank] = x_past[0, self._icol_blank] + x_future[:, self._icol_blank] = x_past[-1, self._icol_blank] # x_future[:, self._icol_blank] = 0 return x_past, y_past, x_future, y_future @@ -113,7 +115,7 @@ class Seq2SeqDataSet(torch.utils.data.Dataset): def __repr__(self): t = self.df.index - return f'<{type(self).__name__}(shape={self.df.shape}, times={t[0]} to {t[1]})>' + return f'<{type(self).__name__}(shape={self.df.shape}, times={t[0]} to {t[-1]})>' class Seq2SeqDataSets(torch.utils.data.Dataset):