fix indexing bugs

This commit is contained in:
wassname
2023-05-08 20:51:12 +08:00
committed by GitHub
parent 075ccf7c3e
commit d38636419b
+5 -3
View File
@@ -41,13 +41,15 @@ class Seq2SeqDataSet(torch.utils.data.Dataset):
self.columns_past = columns_past
# For speed
self._icol_blank = [df.drop(columns = columns_target).columns.tolist().index(n) for n in columns_past]
self._icol_blank = [df.drop(columns = columns_target).columns.tolist().index(n) for n in columns_past] # the columns to blank for future x
self._x = self.df.drop(columns = self.columns_target).values
self._y = self.df[columns_target].values
# Sometimes we want to have it shuffled, but the same each time
np.random.seed(42)
self._rand_index = np.random.permutation(len(self))
self.index = self.df.index[self.window_past:-1]
def get_components(self, i):
"""Get past and future rows."""
@@ -70,7 +72,7 @@ class Seq2SeqDataSet(torch.utils.data.Dataset):
y_future = y[self.window_past:]
# Stop it cheating by using future weather measurements. Fill in with last value
x_future[:, self._icol_blank] = x_past[0, self._icol_blank]
x_future[:, self._icol_blank] = x_past[-1, self._icol_blank]
# x_future[:, self._icol_blank] = 0
return x_past, y_past, x_future, y_future
@@ -113,7 +115,7 @@ class Seq2SeqDataSet(torch.utils.data.Dataset):
def __repr__(self):
t = self.df.index
return f'<{type(self).__name__}(shape={self.df.shape}, times={t[0]} to {t[1]})>'
return f'<{type(self).__name__}(shape={self.df.shape}, times={t[0]} to {t[-1]})>'
class Seq2SeqDataSets(torch.utils.data.Dataset):