This commit is contained in:
wassname
2020-10-29 21:29:43 +08:00
parent 0326f6d33c
commit b7a80f187c
3 changed files with 509 additions and 181 deletions
File diff suppressed because one or more lines are too long
+3 -1
View File
@@ -31,6 +31,7 @@ class Seq2SeqDataSet(torch.utils.data.Dataset):
assert df.index.freq is not None, 'should have freq'
assert_no_objects(df)
self.freq = self.df.index.freq
self.df = df.dropna(subset=columns_target).ffill()
self.window_past = window_past
@@ -93,7 +94,8 @@ class Seq2SeqDataSet(torch.utils.data.Dataset):
y_future = pd.DataFrame(y_future, columns=self.columns_target, index=t_future)
return x_past, y_past, x_future, y_future
def show_batches(self, i=0):
raise Exception('not implemented')
def __len__(self):
return len(self._x) - (self.window_past + self.window_future)
+1
View File
@@ -49,6 +49,7 @@ def predict(model, ds_test, batch_size, device='cpu', scaler=None):
"y_true": (["t_source", "t_ahead",], y_future),
},
coords={"t_source": t_source, "t_ahead": t_ahead, "t_behind": t_behind},
attrs={'freq': ds_test.freq, "model": str(model), "targets": ds_test.columns_target}
)
xrs.append(xr_out)