From 04229f8bf64738536e3576290c32783c57b567fe Mon Sep 17 00:00:00 2001 From: wassname Date: Sun, 8 Nov 2020 10:53:49 +0800 Subject: [PATCH] randomize indexing --- seq2seq_time/data/dataset.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/seq2seq_time/data/dataset.py b/seq2seq_time/data/dataset.py index 7f2ff7c..d344917 100644 --- a/seq2seq_time/data/dataset.py +++ b/seq2seq_time/data/dataset.py @@ -32,7 +32,8 @@ class Seq2SeqDataSet(torch.utils.data.Dataset): assert_no_objects(df) self.freq = df.index.freq.freqstr - self.df = df.dropna(subset=columns_target).ffill() + self.df = df.dropna(subset = columns_target).ffill() + self.window_past = window_past self.window_future = window_future @@ -43,6 +44,10 @@ class Seq2SeqDataSet(torch.utils.data.Dataset): self._icol_blank = [df.drop(columns = columns_target).columns.tolist().index(n) for n in columns_past] self._x = self.df.drop(columns = self.columns_target).values self._y = self.df[columns_target].values + + # Sometimes we want to have it shuffled, but the same each time + np.random.seed(42) + self._rand_index = np.random.permutation(len(self)) def get_components(self, i): """Get past and future rows.""" @@ -71,23 +76,25 @@ class Seq2SeqDataSet(torch.utils.data.Dataset): return x_past, y_past, x_future, y_future - def __getitem__(self, i): + def __getitem__(self, j): """This is how python implements square brackets""" - if i<0: + if j<0: # Handle negative integers - i = len(self)+i + j = len(self)+j + i = self._rand_index[j] data = self.get_components(i) # From dataframe to torch return [d.astype(np.float32) for d in data] - def get_rows(self, i): + def get_rows(self, j): """ Output pandas dataframes for display purposes. """ - if i<0: + if j<0: # Handle negative integers - i = len(self)+i + j = len(self) + j + i = self._rand_index[j] x_cols = list(self.df.drop(columns=self.columns_target).columns) + ['tsp_days', 'is_past'] x_past, y_past, x_future, y_future = self.get_components(i) t_past = self.df.index[i:i+self.window_past]