This commit is contained in:
wassname
2020-10-26 15:33:32 +08:00
parent b40d311e0b
commit 6eda47b76f
7 changed files with 8 additions and 15866 deletions
+6 -2
View File
@@ -28,6 +28,10 @@ class RegressionForecastData:
# Check processing
self.check()
@property
def columns_past(self):
return set(self.df.columns)-set(self.columns_forecast)-set(self.columns_target)
def download(self) -> pd.DataFrame:
"""Implement this method to download data and return raw df"""
@@ -54,8 +58,8 @@ class RegressionForecastData:
def to_datasets(self, window_past: int, window_future: int, valid:bool=False) -> Tuple[Seq2SeqDataSet, Seq2SeqDataSet]:
"""Convert to torch datasets"""
ds_train = Seq2SeqDataSet(df_train, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
ds_test = Seq2SeqDataSet(df_test, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
ds_train = Seq2SeqDataSet(self.df_train, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
ds_test = Seq2SeqDataSet(self.df_test, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
return ds_train, ds_test
def __repr__(self):