mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 18:06:49 +08:00
model size
This commit is contained in:
+28159
-1581
File diff suppressed because one or more lines are too long
@@ -70,8 +70,12 @@ from tqdm.auto import tqdm
|
||||
import pytorch_lightning as pl
|
||||
# -
|
||||
|
||||
|
||||
|
||||
import warnings
|
||||
warnings.simplefilter('once')
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
warnings.simplefilter(action='ignore', category=DeprecationWarning)
|
||||
|
||||
from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets
|
||||
from seq2seq_time.predict import predict, predict_multi
|
||||
@@ -84,7 +88,7 @@ import logging, sys
|
||||
import holoviews as hv
|
||||
from holoviews import opts
|
||||
from holoviews.operation.datashader import datashade, dynspread
|
||||
hv.extension('bokeh')
|
||||
hv.extension('bokeh', inline=False)
|
||||
from seq2seq_time.visualization.hv_ggplot import ggplot_theme
|
||||
hv.renderer('bokeh').theme = ggplot_theme
|
||||
|
||||
@@ -96,9 +100,6 @@ hv.renderer('bokeh').theme = ggplot_theme
|
||||
# -
|
||||
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# ## Parameters
|
||||
|
||||
# +
|
||||
@@ -151,7 +152,7 @@ def hv_plot_true(d: xr.Dataset):
|
||||
now=pd.Timestamp(d.t_source.squeeze().values)
|
||||
|
||||
p = p.opts(
|
||||
ylabel=ds_preds.attrs['targets'],
|
||||
ylabel=str(ds_preds.attrs['targets']),
|
||||
xlabel=f'{now}'
|
||||
)
|
||||
|
||||
@@ -220,13 +221,43 @@ def plot_hist(trainer):
|
||||
pass
|
||||
|
||||
|
||||
df_hist = plot_hist(trainer)
|
||||
df_hist
|
||||
# +
|
||||
def df_bold_min(data):
|
||||
'''
|
||||
highlight the maximum in a Series or DataFrame
|
||||
|
||||
|
||||
Usage:
|
||||
`df.style.apply(df_bold_min)`
|
||||
'''
|
||||
attr = 'font-weight: bold'
|
||||
#remove % and cast to float
|
||||
data = data.replace('%','', regex=True).astype(float)
|
||||
if data.ndim == 1: # Series from .apply(axis=0) or axis=1
|
||||
is_min = data == data.min()
|
||||
return [attr if v else '' for v in is_min]
|
||||
else: # from .apply(axis=None)
|
||||
is_min = data == data.min().min()
|
||||
return pd.DataFrame(np.where(is_min, attr, ''),
|
||||
index=data.index, columns=data.columns)
|
||||
|
||||
def display_results(results, metric='nll', strformat="{:2.2f}"):
|
||||
df_results = pd.concat({k:pd.DataFrame(v) for k,v in results.items()}).T
|
||||
df_results = df_results.rename_axis(index='models', columns=metric)
|
||||
|
||||
# display metric
|
||||
display(df_results
|
||||
.xs(metric, axis=1, level=1)
|
||||
.style.format(strformat)
|
||||
.apply(df_bold_min)
|
||||
)
|
||||
return df_results
|
||||
|
||||
|
||||
# -
|
||||
|
||||
# ## Datasets
|
||||
|
||||
|
||||
|
||||
# +
|
||||
from seq2seq_time.data.data import IMOSCurrentsVel, AppliancesEnergyPrediction, BejingPM25, GasSensor, MetroInterstateTraffic
|
||||
|
||||
@@ -234,6 +265,12 @@ datasets = [BejingPM25, GasSensor, AppliancesEnergyPrediction, MetroInterstateTr
|
||||
datasets
|
||||
# -
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# View train, test, val splits
|
||||
l = hv.Layout()
|
||||
for dataset in datasets:
|
||||
@@ -249,8 +286,7 @@ for dataset in datasets:
|
||||
datashade(hv.Scatter(d.df_test[d.columns_target[0]]),
|
||||
cmap='blue'))
|
||||
p = p.opts(title=f"{dataset}")
|
||||
l += p
|
||||
l.cols(1)
|
||||
display(p)
|
||||
|
||||
# ## Lightning
|
||||
|
||||
@@ -316,7 +352,7 @@ from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq
|
||||
from seq2seq_time.models.transformer_seq import TransformerSeq
|
||||
from seq2seq_time.models.neural_process import RANP
|
||||
from seq2seq_time.models.transformer_process import TransformerProcess
|
||||
from seq2seq_time.models.tcn import TemporalConvNet
|
||||
from seq2seq_time.models.tcn import TCNSeq2Seq
|
||||
# ## Plots
|
||||
# +
|
||||
import gc
|
||||
@@ -327,56 +363,80 @@ def free_mem():
|
||||
gc.collect()
|
||||
|
||||
|
||||
# -
|
||||
# +
|
||||
hidden_size = 32
|
||||
dropout=0.25
|
||||
layers=6
|
||||
nhead=8
|
||||
|
||||
models = [
|
||||
lambda: BaselineLast(),
|
||||
# lambda: TransformerAutoR(input_size,
|
||||
# output_size, hidden_out_size=32),
|
||||
lambda: RANP(input_size,
|
||||
output_size, hidden_dim=64, dropout=0.5,
|
||||
latent_dim=32, n_decoder_layers=4),
|
||||
lambda: LSTM(input_size,
|
||||
output_size,
|
||||
hidden_size=32,
|
||||
lstm_layers=3,
|
||||
lstm_dropout=0.4),
|
||||
lambda: LSTMSeq2Seq(input_size,
|
||||
output_size,
|
||||
hidden_size=64,
|
||||
lstm_layers=2,
|
||||
lstm_dropout=0.4),
|
||||
lambda: TransformerSeq2Seq(input_size,
|
||||
output_size,
|
||||
hidden_size=64,
|
||||
nhead=8,
|
||||
nlayers=4,
|
||||
attention_dropout=0.4),
|
||||
lambda: Transformer(input_size,
|
||||
output_size,
|
||||
attention_dropout=0.4,
|
||||
nhead=8,
|
||||
nlayers=6,
|
||||
hidden_size=64),
|
||||
lambda :TransformerProcess(input_size,
|
||||
output_size, hidden_size=16,
|
||||
latent_dim=8, dropout=0.5,
|
||||
nlayers=4,)
|
||||
# lambda :TemporalConvNet()
|
||||
lambda xs, ys: BaselineLast(),
|
||||
# lambda xs, ys: TransformerAutoR(xs,
|
||||
# ys, hidden_out_size=hidden_size),
|
||||
lambda xs, ys: RANP(xs,
|
||||
ys, hidden_dim=hidden_size, dropout=dropout,
|
||||
latent_dim=hidden_size//4, n_decoder_layers=layers),
|
||||
# lambda xs, ys: LSTM(xs,
|
||||
# ys,
|
||||
# hidden_size=hidden_size,
|
||||
# lstm_layers=layers,
|
||||
# lstm_dropout=dropout),
|
||||
# lambda xs, ys: LSTMSeq2Seq(xs,
|
||||
# ys,
|
||||
# hidden_size=hidden_size,
|
||||
# lstm_layers=layers,
|
||||
# lstm_dropout=dropout),
|
||||
# lambda xs, ys: TransformerSeq2Seq(xs,
|
||||
# ys,
|
||||
# hidden_size=hidden_size,
|
||||
# nhead=nhead,
|
||||
# nlayers=layers,
|
||||
# attention_dropout=dropout),
|
||||
lambda xs, ys: Transformer(xs,
|
||||
ys,
|
||||
attention_dropout=dropout,
|
||||
nhead=nhead,
|
||||
nlayers=layers,
|
||||
hidden_size=hidden_size),
|
||||
# lambda xs, ys:TransformerProcess(xs,
|
||||
# ys, hidden_size=hidden_size,
|
||||
# latent_dim=hidden_size//4, dropout=dropout,
|
||||
# nlayers=layers,)
|
||||
lambda xs, ys:TCNSeq2Seq(xs, ys, hidden_size=hidden_size, nlayers=layers, dropout=dropout)
|
||||
]
|
||||
# models
|
||||
|
||||
|
||||
|
||||
# +
|
||||
# GasSensor(datasets_root)
|
||||
# -
|
||||
|
||||
|
||||
|
||||
# ## Train
|
||||
|
||||
from collections import defaultdict
|
||||
results = defaultdict(dict)
|
||||
|
||||
# +
|
||||
# Summarize each models shape and weights
|
||||
Dataset = datasets[0]
|
||||
dataset = Dataset(datasets_root)
|
||||
ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,
|
||||
window_future=window_future)
|
||||
dl_val = DataLoader(ds_val, batch_size=batch_size)
|
||||
x_past, y_past, x_future, y_future = next(iter(dl_val))
|
||||
xs = x_past.shape[-1]
|
||||
ys = y_future.shape[-1]
|
||||
|
||||
from seq2seq_time.torchsummaryX import summary
|
||||
sizes=[]
|
||||
for m_fn in models:
|
||||
pt_model = m_fn(xs, ys)
|
||||
model_name = type(pt_model).__name__
|
||||
with torch.no_grad():
|
||||
df_summary, df_total = summary(pt_model, x_past, y_past, x_future, y_future, print_summary=False)
|
||||
sizes.append(df_total.rename(columns={'Totals':model_name}))
|
||||
df_model_sizes = pd.concat(sizes, 1)
|
||||
df_model_sizes
|
||||
# -
|
||||
|
||||
|
||||
|
||||
from seq2seq_time.metrics import rmse, smape
|
||||
@@ -400,6 +460,7 @@ for Dataset in datasets:
|
||||
pin_memory=num_workers == 0,
|
||||
num_workers=num_workers)
|
||||
dl_val = DataLoader(ds_val,
|
||||
shuffle=True,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers)
|
||||
|
||||
@@ -411,10 +472,11 @@ for Dataset in datasets:
|
||||
print(dataset_name, model_name)
|
||||
|
||||
# Wrap in lightning
|
||||
patience = 3
|
||||
patience = 5
|
||||
model = PL_MODEL(pt_model,
|
||||
lr=3e-4, patience=patience,
|
||||
weight_decay=4e-5).to(device)
|
||||
# weight_decay=4e-5
|
||||
).to(device)
|
||||
|
||||
# Trainer
|
||||
trainer = pl.Trainer(
|
||||
@@ -424,11 +486,11 @@ for Dataset in datasets:
|
||||
amp_level='O1',
|
||||
precision=16,
|
||||
|
||||
limit_train_batches=300,
|
||||
limit_val_batches=30,
|
||||
limit_train_batches=500,
|
||||
limit_val_batches=150,
|
||||
logger=CSVLogger("../outputs", name=f'{dataset_name}_{model_name}'),
|
||||
callbacks=[
|
||||
EarlyStopping(monitor='loss/val', patience=patience * 2, verbose=True),
|
||||
EarlyStopping(monitor='loss/val', patience=patience * 2),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -455,8 +517,7 @@ for Dataset in datasets:
|
||||
nll=ds_preds.nll.mean().item()
|
||||
)
|
||||
results[dataset_name][model_name] = metrics
|
||||
df_results = pd.concat({k:pd.DataFrame(v) for k,v in results.items()})
|
||||
display(df_results)
|
||||
display_results(results, 'nll')
|
||||
|
||||
dset_to_nc(ds_preds, Path(trainer.logger.experiment.log_dir)/'ds_preds.nc')
|
||||
model.cpu()
|
||||
@@ -465,36 +526,23 @@ for Dataset in datasets:
|
||||
|
||||
df_results = pd.concat({k:pd.DataFrame(v) for k,v in results.items()})
|
||||
display(df_results)
|
||||
|
||||
|
||||
# -
|
||||
|
||||
# # Leaderboard
|
||||
|
||||
def bold_min(data):
|
||||
'''
|
||||
highlight the maximum in a Series or DataFrame
|
||||
'''
|
||||
attr = 'font-weight: bold'
|
||||
#remove % and cast to float
|
||||
data = data.replace('%','', regex=True).astype(float)
|
||||
if data.ndim == 1: # Series from .apply(axis=0) or axis=1
|
||||
is_min = data == data.min()
|
||||
return [attr if v else '' for v in is_min]
|
||||
else: # from .apply(axis=None)
|
||||
is_min = data == data.min().min()
|
||||
return pd.DataFrame(np.where(is_min, attr, ''),
|
||||
index=data.index, columns=data.columns)
|
||||
|
||||
|
||||
print(f'Negative Log-Likelihood (NLL).\nover {window_future} steps')
|
||||
d=df_results.xs('nll', level=1).T.round(2)
|
||||
d.style.apply(bold_min)
|
||||
# df_results = pd.concat({k:pd.DataFrame(v) for k,v in results.items()})
|
||||
df_results
|
||||
results
|
||||
|
||||
# # Leaderboard
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
print(f'Symmetric mean absolute percentage error (SMAPE)\nover {window_future} steps')
|
||||
d=df_results.xs('smape', level=1).T.round(2)
|
||||
d.style.apply(bold_min)
|
||||
|
||||
display_results(results, 'nll')
|
||||
# # Plots
|
||||
|
||||
|
||||
@@ -505,7 +553,7 @@ d.style.apply(bold_min)
|
||||
|
||||
# # plots
|
||||
# Load saved preds
|
||||
results = defaultdict(dict)
|
||||
ds_predss = defaultdict(dict)
|
||||
for Dataset in datasets:
|
||||
dataset_name = Dataset.__name__
|
||||
for m_fn in models:
|
||||
@@ -517,18 +565,18 @@ for Dataset in datasets:
|
||||
fs = sorted(save_dir.glob("**/ds_preds.nc"))
|
||||
if len(fs)>0:
|
||||
ds_preds = xr.open_dataset(fs[-1])
|
||||
results[dataset_name][model_name] = ds_preds
|
||||
ds_predss[dataset_name][model_name] = ds_preds
|
||||
# -
|
||||
|
||||
data_i = 100
|
||||
|
||||
# Plot mean of predictions
|
||||
n = hv.Layout()
|
||||
for dataset in results.keys():
|
||||
d = next(iter(results[dataset].values())).isel(t_source=data_i)
|
||||
for dataset in ds_predss.keys():
|
||||
d = next(iter(ds_predss[dataset].values())).isel(t_source=data_i)
|
||||
p = hv_plot_true(d)
|
||||
for model in results[dataset].keys():
|
||||
ds_preds = results[dataset][model]
|
||||
ds_preds = ds_predss[dataset][model]
|
||||
d = ds_preds.isel(t_source=data_i)
|
||||
p *= hv_plot_pred(d).relabel(label=f"{model}")
|
||||
n += p.opts(title=dataset, legend_position='top_left')
|
||||
@@ -536,8 +584,8 @@ n.cols(1).opts(shared_axes=False)
|
||||
|
||||
dataset='BejingPM25'
|
||||
n = hv.Layout()
|
||||
for i, model in enumerate(results[dataset].keys()):
|
||||
ds_preds = results[dataset][model]
|
||||
for i, model in enumerate(ds_predss[dataset].keys()):
|
||||
ds_preds = ds_predss[dataset][model]
|
||||
d = ds_preds.isel(t_source=data_i)
|
||||
p = hv_plot_true(d)
|
||||
p *= hv_plot_pred(d).relabel('pred')
|
||||
@@ -551,3 +599,7 @@ plot_performance(ds_preds, full=True)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user