mirror of
https://github.com/wassname/DeepTime.git
synced 2026-06-27 18:40:49 +08:00
150 lines
6.3 KiB
Python
150 lines
6.3 KiB
Python
import os
|
|
from os.path import join
|
|
from typing import Optional, List, Tuple
|
|
|
|
import gin
|
|
import numpy as np
|
|
import pandas as pd
|
|
from sklearn.preprocessing import StandardScaler
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
from utils.time_features import get_time_features
|
|
|
|
|
|
@gin.configurable()
|
|
class ForecastDataset(Dataset):
|
|
def __init__(self,
|
|
flag: str,
|
|
horizon_len: int,
|
|
scale: bool,
|
|
cross_learn: bool,
|
|
data_path: str,
|
|
root_path: Optional[str] = 'storage/datasets',
|
|
features: Optional[str] = 'S',
|
|
target: Optional[str] = 'OT',
|
|
lookback_len: Optional[int] = None,
|
|
lookback_aux_len: Optional[int] = 0,
|
|
lookback_mult: Optional[float] = None,
|
|
time_features: Optional = [],
|
|
normalise_time_features: Optional = True,):
|
|
"""
|
|
:param flag: train/val/test flag
|
|
:param horizon_len: number of time steps in forecast horizon
|
|
:param scale: performs standard scaling
|
|
:param data_path: relative (to root_path) path to data file (.csv)
|
|
:param cross_learn: treats multivariate time series as multiple univar time series
|
|
:param root_path: path to datasets folder
|
|
:param features: multivar (M) or univar (S) forecasting
|
|
:param target: name of target variable for univar forecasting (features=S)
|
|
:param lookback_len: number of time steps in lookback window
|
|
:param lookback_aux_len: number of time steps to append to y from the lookback window
|
|
(for models with decoders which requires initialisation)
|
|
:param lookback_mult: multiplier to decide lookback window length
|
|
"""
|
|
assert flag in ('train', 'val', 'test'), \
|
|
f"flag should be one of (train, val, test)"
|
|
assert features in ('M', 'S'), \
|
|
f"features should be one of (M: multivar, S: univar)"
|
|
assert (lookback_len is not None) ^ (lookback_mult is not None), \
|
|
f"only 'lookback_len' xor 'lookback_mult' should be specified"
|
|
|
|
self.flag = flag
|
|
self.lookback_len = int(horizon_len * lookback_mult) if lookback_mult is not None else lookback_len
|
|
self.lookback_aux_len = lookback_aux_len
|
|
self.horizon_len = horizon_len
|
|
self.scale = scale
|
|
self.cross_learn = cross_learn
|
|
self.data_path = data_path
|
|
self.root_path = root_path
|
|
self.features = features
|
|
self.target = target
|
|
self.time_features = time_features
|
|
self.normalise_time_features = normalise_time_features
|
|
|
|
self.n_dims = None
|
|
self.scaler = None
|
|
self.data_x = None
|
|
self.data_y = None
|
|
self.timestamps = None
|
|
self.n_time = None
|
|
self.n_time_samples = None
|
|
self.load_data()
|
|
|
|
def load_data(self):
|
|
df_raw = pd.read_csv(join(self.root_path, self.data_path)) # (n_obs, date + n_feats)
|
|
cols = list(df_raw.columns)
|
|
cols.remove('date')
|
|
cols.remove(self.target)
|
|
df_raw = df_raw[['date'] + cols + [self.target]]
|
|
border1s, border2s, border1, border2 = self.get_borders(df_raw)
|
|
|
|
if self.features == 'M':
|
|
df_data = df_raw[cols + [self.target]]
|
|
self.n_dims = len(cols + [self.target])
|
|
elif self.features == 'S':
|
|
df_data = df_raw[[self.target]]
|
|
self.n_dims = 1
|
|
else:
|
|
raise ValueError
|
|
|
|
self.scaler = StandardScaler()
|
|
if self.scale:
|
|
train_data = df_data[border1s[0]:border2s[0]]
|
|
self.scaler.fit(train_data.values)
|
|
data = self.scaler.transform(df_data.values)
|
|
else:
|
|
data = df_data.values
|
|
|
|
self.data_x = data[border1:border2]
|
|
self.data_y = data[border1:border2]
|
|
self.timestamps = get_time_features(pd.to_datetime(df_raw.date[border1:border2].values),
|
|
normalise=self.normalise_time_features,
|
|
features=self.time_features)
|
|
self.n_time = len(self.data_x)
|
|
self.n_time_samples = self.n_time - self.lookback_len - self.horizon_len + 1
|
|
|
|
def get_borders(self, df_raw: pd.DataFrame) -> Tuple[List[int], List[int], List[int], List[int]]:
|
|
set_type = {'train': 0, 'val': 1, 'test': 2}[self.flag]
|
|
if self.data_path.startswith('ETT-small/ETTh'):
|
|
border1s = [0, 12 * 30 * 24 - self.lookback_len, 12 * 30 * 24 + 4 * 30 * 24 - self.lookback_len]
|
|
border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
|
|
elif self.data_path.startswith('ETT-small/ETTm'):
|
|
border1s = [0, 12 * 30 * 24 * 4 - self.lookback_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.lookback_len]
|
|
border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
|
|
else:
|
|
num_train = int(len(df_raw) * 0.7)
|
|
num_test = int(len(df_raw) * 0.2)
|
|
num_val = len(df_raw) - num_train - num_test
|
|
border1s = [0, num_train - self.lookback_len, len(df_raw) - num_test - self.lookback_len]
|
|
border2s = [num_train, num_train + num_val, len(df_raw)]
|
|
border1 = border1s[set_type]
|
|
border2 = border2s[set_type]
|
|
return border1s, border2s, border1, border2
|
|
|
|
def __len__(self):
|
|
if self.cross_learn:
|
|
return self.n_time_samples * self.n_dims
|
|
return self.n_time_samples
|
|
|
|
def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
|
if self.cross_learn:
|
|
dim_idx = idx // self.n_time_samples
|
|
dim_slice = slice(dim_idx, dim_idx + 1)
|
|
idx = idx % self.n_time_samples
|
|
else:
|
|
dim_slice = slice(None)
|
|
|
|
x_start = idx
|
|
x_end = x_start + self.lookback_len
|
|
y_start = x_end - self.lookback_aux_len
|
|
y_end = y_start + self.lookback_aux_len + self.horizon_len
|
|
|
|
x = self.data_x[x_start:x_end, dim_slice]
|
|
y = self.data_y[y_start:y_end, dim_slice]
|
|
x_time = self.timestamps[x_start:x_end]
|
|
y_time = self.timestamps[y_start:y_end]
|
|
|
|
return x, y, x_time, y_time
|
|
|
|
def inverse_transform(self, data):
|
|
return self.scaler.inverse_transform(data) |