mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 16:31:46 +08:00
multi datasets
This commit is contained in:
+1243
-1730
File diff suppressed because one or more lines are too long
@@ -55,6 +55,8 @@ import torch.utils.data
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
plt.rcParams['figure.figsize'] = (12.0, 4.0)
|
||||
plt.style.use('ggplot')
|
||||
|
||||
from pathlib import Path
|
||||
from tqdm.auto import tqdm
|
||||
@@ -158,11 +160,11 @@ dfs = list(dfs)
|
||||
# df = df.resample(freq).first().dropna() # Where empty we will backfill, this will respect causality, and mostly maintain the mean
|
||||
|
||||
# df = df.tail(int(max_rows)).copy() # Just use last X rows
|
||||
# df = pd.concat(dfs, 0)
|
||||
df = dfs[0]
|
||||
df = pd.concat(dfs[:6], 0)
|
||||
# df = dfs[0]
|
||||
# -
|
||||
|
||||
|
||||
df.LCLid.unique()
|
||||
|
||||
|
||||
|
||||
@@ -188,22 +190,22 @@ df_norm
|
||||
output_scaler = next(filter(lambda r:r[0][0] in columns_target, scaler.features))[-1]
|
||||
output_scaler
|
||||
|
||||
# # Resample
|
||||
df_norm = df_norm.resample(freq).first().fillna(0)
|
||||
dfs_norm = [d.resample(freq).first().ffill().dropna() for _, d in df_norm.groupby('LCLid')]
|
||||
len(dfs_norm)
|
||||
|
||||
# +
|
||||
# split data, with the test in the future
|
||||
n_split = -int(len(df)*0.2)
|
||||
df_train = df_norm[:n_split]
|
||||
df_test = df_norm[n_split:]
|
||||
n_split = -int(len(dfs_norm)*0.2)
|
||||
df_train = dfs_norm[:n_split]
|
||||
df_test = dfs_norm[n_split:]
|
||||
|
||||
# Show split
|
||||
df_train['energy(kWh/hh)'].plot(label='train')
|
||||
df_test['energy(kWh/hh)'].plot(label='test')
|
||||
pd.concat(df_train)['energy(kWh/hh)'].plot(label='train')
|
||||
pd.concat(df_test)['energy(kWh/hh)'].plot(label='test')
|
||||
plt.ylabel('energy(kWh/hh)')
|
||||
plt.legend()
|
||||
# -
|
||||
df_norm
|
||||
|
||||
|
||||
|
||||
# ### Dataset
|
||||
@@ -214,11 +216,11 @@ columns_blank=['visibility',
|
||||
'windBearing', 'temperature', 'dewPoint', 'pressure',
|
||||
'apparentTemperature', 'windSpeed', 'humidity']
|
||||
|
||||
ds_train = Seq2SeqDataSet(df_train,
|
||||
ds_train = Seq2SeqDataSets(df_train,
|
||||
window_past=window_past,
|
||||
window_future=window_future,
|
||||
columns_blank=columns_blank)
|
||||
ds_test = Seq2SeqDataSet(df_test,
|
||||
ds_test = Seq2SeqDataSets(df_test,
|
||||
window_past=window_past,
|
||||
window_future=window_future,
|
||||
columns_blank=columns_blank)
|
||||
@@ -228,7 +230,7 @@ print(ds_test)
|
||||
# we can treat it like an array
|
||||
ds_train[0]
|
||||
len(ds_train)
|
||||
ds_train[0]
|
||||
ds_train[-1]
|
||||
|
||||
# +
|
||||
# We can get rows
|
||||
@@ -325,7 +327,6 @@ class PL_Seq2Seq(pl.LightningModule):
|
||||
return torch.optim.Adam(self.parameters(), lr=1e-4)
|
||||
|
||||
|
||||
|
||||
# -
|
||||
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
@@ -341,9 +342,9 @@ output_size = y_future.shape[-1]
|
||||
model = PL_Seq2Seq(input_size=input_size,
|
||||
input_size_decoder=input_size,
|
||||
output_size=output_size,
|
||||
hidden_size=16,
|
||||
lstm_layers=1,
|
||||
lstm_dropout=0.5).to(device)
|
||||
hidden_size=32,
|
||||
lstm_layers=2,
|
||||
lstm_dropout=0.25).to(device)
|
||||
|
||||
logger = CSVLogger("logs", name="seq2seq")
|
||||
trainer = pl.Trainer(gpus=1,
|
||||
@@ -351,7 +352,7 @@ trainer = pl.Trainer(gpus=1,
|
||||
dl_train = DataLoader(ds_train,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=4)
|
||||
num_workers=8)
|
||||
dl_test = DataLoader(ds_test, batch_size=batch_size, num_workers=4)
|
||||
trainer.fit(model, dl_train, dl_test)
|
||||
# -
|
||||
@@ -364,7 +365,7 @@ df_histe
|
||||
# ## Predict
|
||||
#
|
||||
|
||||
ds_preds = predict(model.to(device), ds_test, batch_size, device=device, scaler=output_scaler)
|
||||
ds_preds = predict(model.to(device), ds_test.datasets[0], batch_size, device=device, scaler=output_scaler)
|
||||
ds_preds
|
||||
|
||||
|
||||
|
||||
@@ -115,7 +115,7 @@ class Seq2SeqDataSets(torch.utils.data.Dataset):
|
||||
for d in self.datasets:
|
||||
l += len(d)
|
||||
if i < l:
|
||||
return d[i]
|
||||
return d[l-i]
|
||||
raise IndexError
|
||||
|
||||
def get_rows(self, i):
|
||||
@@ -123,7 +123,7 @@ class Seq2SeqDataSets(torch.utils.data.Dataset):
|
||||
for d in self.datasets:
|
||||
l += len(d)
|
||||
if i < l:
|
||||
return d.get_rows(i)
|
||||
return d.get_rows(l-i)
|
||||
raise IndexError
|
||||
|
||||
def __len__(self):
|
||||
|
||||
Reference in New Issue
Block a user