This commit is contained in:
wassname
2020-10-26 15:33:32 +08:00
parent b40d311e0b
commit 6eda47b76f
7 changed files with 8 additions and 15866 deletions
+6 -2
View File
@@ -28,6 +28,10 @@ class RegressionForecastData:
# Check processing
self.check()
@property
def columns_past(self):
return set(self.df.columns)-set(self.columns_forecast)-set(self.columns_target)
def download(self) -> pd.DataFrame:
"""Implement this method to download data and return raw df"""
@@ -54,8 +58,8 @@ class RegressionForecastData:
def to_datasets(self, window_past: int, window_future: int, valid:bool=False) -> Tuple[Seq2SeqDataSet, Seq2SeqDataSet]:
"""Convert to torch datasets"""
ds_train = Seq2SeqDataSet(df_train, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
ds_test = Seq2SeqDataSet(df_test, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
ds_train = Seq2SeqDataSet(self.df_train, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
ds_test = Seq2SeqDataSet(self.df_test, window_past=window_past, window_future=window_future, columns_target=self.columns_target, columns_past=self.columns_past)
return ds_train, ds_test
def __repr__(self):
+1 -1
View File
@@ -31,7 +31,7 @@ class Seq2SeqDataSet(torch.utils.data.Dataset):
assert df.index.freq is not None, 'should have freq'
assert_no_objects(df)
self.df = df
self.df = df.dropna(subset=columns_target)
self.window_past = window_past
self.window_future = window_future
+1
View File
@@ -1,5 +1,6 @@
import uptide
import pandas as pd
import numpy as np
# https://en.wikipedia.org/wiki/Theory_of_tides#Harmonic_analysis
default_tidal_constituents = [