mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-07-02 01:03:38 +08:00
misc
This commit is contained in:
@@ -28,6 +28,10 @@ class RegressionForecastData:
|
||||
|
||||
# Check processing
|
||||
self.check()
|
||||
|
||||
@property
|
||||
def columns_past(self):
|
||||
return set(self.df.columns)-set(self.columns_forecast)-set(self.columns_target)
|
||||
|
||||
def download(self) -> pd.DataFrame:
|
||||
"""Implement this method to download data and return raw df"""
|
||||
@@ -54,8 +58,8 @@ class RegressionForecastData:
|
||||
|
||||
def to_datasets(self, window_past: int, window_future: int, valid:bool=False) -> Tuple[Seq2SeqDataSet, Seq2SeqDataSet]:
|
||||
"""Convert to torch datasets"""
|
||||
ds_train = Seq2SeqDataSet(df_train, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
|
||||
ds_test = Seq2SeqDataSet(df_test, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
|
||||
ds_train = Seq2SeqDataSet(self.df_train, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
|
||||
ds_test = Seq2SeqDataSet(self.df_test, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
|
||||
return ds_train, ds_test
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
Reference in New Issue
Block a user