mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 16:31:46 +08:00
full run
This commit is contained in:
+15846
-754
File diff suppressed because one or more lines are too long
+98
-58
@@ -376,15 +376,15 @@ def free_mem():
|
||||
# +
|
||||
# PARAMS: model
|
||||
## Some datasets are easier, so we vary the hidden size per dataset to prevent overfitting
|
||||
hidden_size={'IMOSCurrentsVel': 6, #?
|
||||
'AppliancesEnergyPrediction': 6, # ?
|
||||
hidden_size={'IMOSCurrentsVel': 8, #?
|
||||
'AppliancesEnergyPrediction': 8, # ?
|
||||
'BejingPM25': 8, # OK
|
||||
'GasSensor': 8, # OK
|
||||
'MetroInterstateTraffic': 16 # OK
|
||||
}
|
||||
dropout=0.0
|
||||
layers=6
|
||||
nhead=2
|
||||
nhead=4
|
||||
|
||||
models = [
|
||||
# lambda xs, ys: BaselineLast(),
|
||||
@@ -392,7 +392,7 @@ models = [
|
||||
lambda xs, ys, hidden_size: Transformer(xs,
|
||||
ys,
|
||||
attention_dropout=dropout,
|
||||
nhead=nhead*2,
|
||||
nhead=nhead,
|
||||
nlayers=layers,
|
||||
hidden_size=hidden_size),
|
||||
|
||||
@@ -403,7 +403,7 @@ models = [
|
||||
lambda xs, ys, hidden_size:TCNSeq(xs, ys, hidden_size=hidden_size, nlayers=layers, dropout=dropout, kernel_size=2),
|
||||
lambda xs, ys, hidden_size: RANP(xs,
|
||||
ys, hidden_dim=hidden_size, dropout=dropout,
|
||||
latent_dim=hidden_size//2, n_decoder_layers=layers),
|
||||
latent_dim=hidden_size//2, n_decoder_layers=layers, n_latent_encoder_layers=layers, n_det_encoder_layers=layers),
|
||||
lambda xs, ys, hidden_size: TransformerSeq2Seq(xs,
|
||||
ys,
|
||||
hidden_size=hidden_size,
|
||||
@@ -419,7 +419,7 @@ models = [
|
||||
lambda xs, ys, hidden_size: LSTMSeq2Seq(xs,
|
||||
ys,
|
||||
hidden_size=hidden_size,
|
||||
lstm_layers=layers,
|
||||
lstm_layers=layers//2,
|
||||
lstm_dropout=dropout),
|
||||
lambda xs, ys, hidden_size: CrossAttention(xs,
|
||||
ys,
|
||||
@@ -480,6 +480,49 @@ max_iters=20000
|
||||
tensorboard_dir = Path(f"../outputs/{timestamp}").resolve()
|
||||
print(f'For tensorboard run:\ntensorboard --logdir="{tensorboard_dir}"')
|
||||
|
||||
# +
|
||||
# DEBUG: sanity check
|
||||
|
||||
for Dataset in datasets:
|
||||
dataset_name = Dataset.__name__
|
||||
dataset = Dataset(datasets_root)
|
||||
ds_train, ds_val, ds_test = dataset.to_datasets(window_past=window_past,
|
||||
window_future=window_future)
|
||||
|
||||
# Init data
|
||||
x_past, y_past, x_future, y_future = ds_train.get_rows(10)
|
||||
xs = x_past.shape[-1]
|
||||
ys = y_future.shape[-1]
|
||||
|
||||
# Loaders
|
||||
dl_train = DataLoader(ds_train,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=num_workers == 0,
|
||||
num_workers=num_workers)
|
||||
dl_val = DataLoader(ds_val,
|
||||
shuffle=True,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers)
|
||||
|
||||
for m_fn in models:
|
||||
free_mem()
|
||||
pt_model = m_fn(xs, ys, hidden_size[dataset_name])
|
||||
model_name = type(pt_model).__name__
|
||||
print(timestamp, dataset_name, model_name)
|
||||
|
||||
# Wrap in lightning
|
||||
model = PL_MODEL(pt_model,
|
||||
lr=3e-4
|
||||
).to(device)
|
||||
trainer = pl.Trainer(
|
||||
fast_dev_run=True,
|
||||
# GPU
|
||||
gpus=1,
|
||||
amp_level='O1',
|
||||
precision=16,
|
||||
)
|
||||
|
||||
# +
|
||||
|
||||
results = defaultdict(dict)
|
||||
@@ -630,7 +673,7 @@ for dataset in ds_predss.keys():
|
||||
n += p.opts(title=dataset, legend_position='top_left')
|
||||
n.cols(1).opts(shared_axes=False)
|
||||
|
||||
1/0
|
||||
|
||||
|
||||
dataset='IMOSCurrentsVel'
|
||||
data_i=844
|
||||
@@ -646,26 +689,7 @@ n.cols(1)
|
||||
|
||||
|
||||
# +
|
||||
# plot_performance(ds_preds, full=True)
|
||||
|
||||
# +
|
||||
def plot_at_i(time_i, dataset, model):
|
||||
d = ds_predss[dataset][model].isel(t_source=time_i)
|
||||
return hv_plot_prediction(d).relabel(label=f"{model}")
|
||||
|
||||
dmap = hv.DynamicMap(plot_at_i, kdims=['t_source', 'dataset', 'model'])
|
||||
t = ds_preds.t_source.values
|
||||
models = list(next(iter(ds_predss.values())).keys())
|
||||
dmap = dmap.redim.values(
|
||||
t_source=range(len(t)),
|
||||
dataset=list(ds_predss.keys()),
|
||||
model=models,
|
||||
)
|
||||
dmap.opts(framewise=True)
|
||||
# -
|
||||
|
||||
1/0
|
||||
|
||||
# 1/0
|
||||
|
||||
# +
|
||||
# Explore predictions with dynamic map
|
||||
@@ -696,36 +720,6 @@ dmap
|
||||
1/0
|
||||
|
||||
|
||||
# +
|
||||
# Explore predictions with dynamic map
|
||||
|
||||
def plot_predictions_ahead(dataset='IMOSCurrentsVel', model='', t_ahead_i=6, start=0, window_steps=1800):
|
||||
d = next(iter(ds_predss[dataset].values())).isel(t_ahead=t_ahead_i).isel(t_source=slice(start, start+window_steps))
|
||||
|
||||
p = hv.Scatter({
|
||||
'x': d.t_target,
|
||||
'y': d.y_true
|
||||
}, label='true').opts(color='black', framewise=True)
|
||||
|
||||
ds_preds = ds_predss[dataset][model]
|
||||
d = ds_preds.isel(t_ahead=t_ahead_i).isel(t_source=slice(start, start+window_steps))
|
||||
x = d.t_target
|
||||
y = d.y_pred
|
||||
s = d.y_pred_std
|
||||
p *= hv.Curve({'x': x, 'y':y}, label=model).relabel(label=f"{model}")
|
||||
p *= hv.Spread((x, y, s * 2),
|
||||
label='2*std').opts(alpha=0.5, line_width=0)
|
||||
|
||||
p = p.opts(title=f"Dataset: {dataset}, model={model}, {d.freq}*{t_ahead_i} ahead", height=250, legend_position='top', ylabel=d.targets)
|
||||
return p.opts(framewise=True)
|
||||
|
||||
dmap = hv.DynamicMap(plot_predictions_ahead, kdims=['dataset', 'model', 't_ahead_i', 'start', 'window_steps'])
|
||||
dmap = dmap.redim.values(dataset=list(ds_predss.keys()), model=models)
|
||||
dmap = dmap.redim.range(t_ahead_i=(0, window_future), start=(0, 5000), window_steps=(10, 5000))
|
||||
dmap = dmap.redim.default(t_ahead_i=10, window_steps=1000)
|
||||
dmap
|
||||
# -
|
||||
|
||||
|
||||
|
||||
# +
|
||||
@@ -752,8 +746,54 @@ def plot_predictions_ahead(dataset='IMOSCurrentsVel', t_ahead_i=6, start=0, wind
|
||||
dmap = hv.DynamicMap(plot_predictions_ahead, kdims=['dataset', 't_ahead_i', 'start', 'window_steps'])
|
||||
dmap = dmap.redim.values(dataset=list(ds_predss.keys()))
|
||||
dmap = dmap.redim.range(t_ahead_i=(0, window_future), start=(0, 5000), window_steps=(10, 5000))
|
||||
dmap = dmap.redim.default(t_ahead_i=10, window_steps=400)
|
||||
dmap = dmap.redim.default(t_ahead_i=10, window_steps=400, dataset='IMOSCurrentsVel')
|
||||
dmap
|
||||
# +
|
||||
# def plot_at_i(time_i, dataset, model):
|
||||
# d = ds_predss[dataset][model].isel(t_source=time_i)
|
||||
# return hv_plot_prediction(d).relabel(label=f"{model}")
|
||||
|
||||
# dmap = hv.DynamicMap(plot_at_i, kdims=['t_source', 'dataset', 'model'])
|
||||
# t = ds_preds.t_source.values
|
||||
# models = list(next(iter(ds_predss.values())).keys())
|
||||
# dmap = dmap.redim.values(
|
||||
# t_source=range(len(t)),
|
||||
# dataset=list(ds_predss.keys()),
|
||||
# model=models,
|
||||
# )
|
||||
# dmap.opts(framewise=True)
|
||||
|
||||
# +
|
||||
# plot_performance(ds_preds, full=True)
|
||||
|
||||
# +
|
||||
# # Explore predictions with dynamic map
|
||||
|
||||
# def plot_predictions_ahead(dataset='IMOSCurrentsVel', model='', t_ahead_i=6, start=0, window_steps=1800):
|
||||
# d = next(iter(ds_predss[dataset].values())).isel(t_ahead=t_ahead_i).isel(t_source=slice(start, start+window_steps))
|
||||
|
||||
# p = hv.Scatter({
|
||||
# 'x': d.t_target,
|
||||
# 'y': d.y_true
|
||||
# }, label='true').opts(color='black', framewise=True)
|
||||
|
||||
# ds_preds = ds_predss[dataset][model]
|
||||
# d = ds_preds.isel(t_ahead=t_ahead_i).isel(t_source=slice(start, start+window_steps))
|
||||
# x = d.t_target
|
||||
# y = d.y_pred
|
||||
# s = d.y_pred_std
|
||||
# p *= hv.Curve({'x': x, 'y':y}, label=model).relabel(label=f"{model}")
|
||||
# p *= hv.Spread((x, y, s * 2),
|
||||
# label='2*std').opts(alpha=0.5, line_width=0)
|
||||
|
||||
# p = p.opts(title=f"Dataset: {dataset}, model={model}, {d.freq}*{t_ahead_i} ahead", height=250, legend_position='top', ylabel=d.targets)
|
||||
# return p.opts(framewise=True)
|
||||
|
||||
# dmap = hv.DynamicMap(plot_predictions_ahead, kdims=['dataset', 'model', 't_ahead_i', 'start', 'window_steps'])
|
||||
# dmap = dmap.redim.values(dataset=list(ds_predss.keys()), model=models)
|
||||
# dmap = dmap.redim.range(t_ahead_i=(0, window_future), start=(0, 5000), window_steps=(10, 5000))
|
||||
# dmap = dmap.redim.default(t_ahead_i=10, window_steps=1000)
|
||||
# dmap
|
||||
# -
|
||||
|
||||
|
||||
@@ -162,6 +162,7 @@ class LatentEncoder(nn.Module):
|
||||
min_std=0.01,
|
||||
batchnorm=False,
|
||||
dropout=0,
|
||||
nhead=8,
|
||||
attention_dropout=0,
|
||||
attention_layers=2,
|
||||
):
|
||||
@@ -178,6 +179,7 @@ class LatentEncoder(nn.Module):
|
||||
self._self_attention = Attention(
|
||||
hidden_dim,
|
||||
attention_layers,
|
||||
n_heads=nhead,
|
||||
rep="identity",
|
||||
dropout=attention_dropout,
|
||||
)
|
||||
@@ -218,6 +220,7 @@ class DeterministicEncoder(nn.Module):
|
||||
attention_layers=2,
|
||||
batchnorm=False,
|
||||
dropout=0,
|
||||
nhead=8,
|
||||
attention_dropout=0,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -232,12 +235,14 @@ class DeterministicEncoder(nn.Module):
|
||||
self._self_attention = Attention(
|
||||
hidden_dim,
|
||||
attention_layers,
|
||||
n_heads=nhead,
|
||||
rep="identity",
|
||||
dropout=attention_dropout,
|
||||
)
|
||||
self._cross_attention = Attention(
|
||||
hidden_dim,
|
||||
x_dim=x_dim,
|
||||
n_heads=nhead,
|
||||
attention_layers=attention_layers,
|
||||
)
|
||||
|
||||
@@ -325,6 +330,7 @@ class RANP(nn.Module):
|
||||
use_deterministic_path=True,
|
||||
min_std=0.01, # To avoid collapse, use a minimum standard deviation; it should be much smaller than the variation in the labels
|
||||
dropout=0,
|
||||
nhead=8,
|
||||
attention_dropout=0,
|
||||
batchnorm=False,
|
||||
attention_layers=2,
|
||||
@@ -353,6 +359,7 @@ class RANP(nn.Module):
|
||||
n_encoder_layers=n_latent_encoder_layers,
|
||||
attention_layers=attention_layers,
|
||||
dropout=dropout,
|
||||
nhead=nhead,
|
||||
attention_dropout=attention_dropout,
|
||||
batchnorm=batchnorm,
|
||||
min_std=min_std,
|
||||
@@ -365,6 +372,7 @@ class RANP(nn.Module):
|
||||
n_d_encoder_layers=n_det_encoder_layers,
|
||||
attention_layers=attention_layers,
|
||||
dropout=dropout,
|
||||
nhead=nhead,
|
||||
batchnorm=batchnorm,
|
||||
attention_dropout=attention_dropout,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user