mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 18:06:49 +08:00
plots
This commit is contained in:
+505
-180
File diff suppressed because one or more lines are too long
@@ -31,6 +31,7 @@ class Seq2SeqDataSet(torch.utils.data.Dataset):
|
||||
assert df.index.freq is not None, 'should have freq'
|
||||
assert_no_objects(df)
|
||||
|
||||
self.freq = self.df.index.freq
|
||||
self.df = df.dropna(subset=columns_target).ffill()
|
||||
|
||||
self.window_past = window_past
|
||||
@@ -93,7 +94,8 @@ class Seq2SeqDataSet(torch.utils.data.Dataset):
|
||||
y_future = pd.DataFrame(y_future, columns=self.columns_target, index=t_future)
|
||||
return x_past, y_past, x_future, y_future
|
||||
|
||||
|
||||
def show_batches(self, i=0):
|
||||
raise Exception('not implemented')
|
||||
|
||||
def __len__(self):
|
||||
return len(self._x) - (self.window_past + self.window_future)
|
||||
|
||||
@@ -49,6 +49,7 @@ def predict(model, ds_test, batch_size, device='cpu', scaler=None):
|
||||
"y_true": (["t_source", "t_ahead",], y_future),
|
||||
},
|
||||
coords={"t_source": t_source, "t_ahead": t_ahead, "t_behind": t_behind},
|
||||
attrs={'freq': ds_test.freq, "model": str(model), "targets": ds_test.columns_target}
|
||||
)
|
||||
xrs.append(xr_out)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user