mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 16:31:46 +08:00
working
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -369,9 +369,7 @@ def plot_prediction(ds_preds, i):
|
||||
|
||||
def plot_performance(ds_preds, full=False):
|
||||
"""Multiple plots using xr_preds"""
|
||||
print(f'mean_NLL {ds_preds.nll.mean().item():2.2f}')
|
||||
plot_prediction(ds_preds, 24)
|
||||
# plot_prediction(ds_preds, 480)
|
||||
|
||||
ds_preds.mean('t_source').plot.scatter('t_ahead_hours', 'nll') # Mean over all predictions
|
||||
n = len(ds_preds.t_source)
|
||||
@@ -481,7 +479,7 @@ from seq2seq_time.models.baseline import BaselineLast
|
||||
from seq2seq_time.models.transformer import Transformer
|
||||
from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq
|
||||
from seq2seq_time.models.transformer_seq import TransformerSeq
|
||||
from seq2seq_time.models.anp import RANP
|
||||
from seq2seq_time.models.neural_process import RANP
|
||||
# ## Plots
|
||||
# +
|
||||
models = [
|
||||
@@ -529,7 +527,11 @@ trainer = pl.Trainer(gpus=1,
|
||||
)
|
||||
trainer.fit(model, dl_train, dl_test)
|
||||
print(plot_hist(trainer))
|
||||
ds_preds = predict(model.to(device), ds_test.datasets[0], batch_size, device=device, scaler=output_scaler)
|
||||
ds_predss = predict_multi(model.to(device),
|
||||
ds_test.datasets,
|
||||
batch_size*8,
|
||||
device=device,
|
||||
scaler=output_scaler)
|
||||
print(f'baseline nll: {ds_preds.nll.mean().item():2.2g}')
|
||||
|
||||
for pt_model in models:
|
||||
@@ -560,46 +562,35 @@ for pt_model in models:
|
||||
|
||||
|
||||
|
||||
ds_preds = predict(model.to(device),
|
||||
ds_test.datasets[0],
|
||||
batch_size,
|
||||
ds_predss = predict_multi(model.to(device),
|
||||
ds_test.datasets,
|
||||
batch_size*8,
|
||||
device=device,
|
||||
scaler=output_scaler)
|
||||
|
||||
print(name)
|
||||
print(f'mean_NLL {ds_preds.nll.mean().item():2.2f}')
|
||||
print(f'mean_NLL {ds_predss.nll.mean().item():2.2f}')
|
||||
|
||||
# Performance
|
||||
ds_preds = ds_predss.isel(block=0)
|
||||
print(plot_hist(trainer))
|
||||
plot_performance(ds_preds)
|
||||
|
||||
# %debug
|
||||
|
||||
ds_preds = predict(model.to(device),q
|
||||
|
||||
ds_test.datasets[0],
|
||||
batch_size,
|
||||
device=device,
|
||||
scaler=output_scaler)
|
||||
|
||||
# +
|
||||
# ds_predss = predict_multi(model.to(device),
|
||||
# ds_test.datasets,
|
||||
# batch_size,
|
||||
# ds_preds = predict(model.to(device),
|
||||
# ds_test.datasets[0],
|
||||
# batch_size*8,
|
||||
# device=device,
|
||||
# scaler=output_scaler)
|
||||
# -
|
||||
|
||||
ds_test.datasets[0].df.index.value_counts()
|
||||
ds_predss = predict_multi(model.to(device),
|
||||
ds_test.datasets,
|
||||
batch_size*8,
|
||||
device=device,
|
||||
scaler=output_scaler)
|
||||
|
||||
# TODO: why are there duplicate t_source entries?
|
||||
ds_preds.sel(t_source='2013-11-11 00:30:00')
|
||||
|
||||
# TODO why duplicates?
|
||||
d = ds_preds.isel(t_ahead=0)
|
||||
d.t_source.to_series().sort_index()#.value_counts()
|
||||
# np.unique
|
||||
# d
|
||||
ds_pred_block = ds_predss.isel(block=1)
|
||||
|
||||
# # holoviews pred
|
||||
|
||||
@@ -613,9 +604,9 @@ def plot_prediction_now(t_source):
|
||||
|
||||
# Allow t_source to be passed as an integer index
|
||||
if isinstance(t_source, int):
|
||||
t_source = ds_preds.t_source[t_source].to_pandas()
|
||||
t_source = ds_pred_block.t_source[t_source].to_pandas()
|
||||
|
||||
d = ds_preds.sel(t_source=t_source)
|
||||
d = ds_pred_block.sel(t_source=t_source)
|
||||
|
||||
# Sometimes there are duplicate times, take the first
|
||||
if len(d.t_source.shape) and d.t_source.shape[0] > 0:
|
||||
@@ -651,56 +642,58 @@ def plot_prediction_now(t_source):
|
||||
|
||||
|
||||
dmap_pred = (hv.DynamicMap(plot_prediction_now, kdims=['t_source'])
|
||||
.redim.values(t_source=ds_preds.t_source.to_pandas())
|
||||
.redim.values(t_source=ds_pred_block.t_source.to_pandas())
|
||||
.opts(width=800,
|
||||
height=300,
|
||||
))
|
||||
dmap_pred
|
||||
|
||||
|
||||
# +
|
||||
def plot_predictions_vs_time(it_ahead):
|
||||
"""Plot predictions vs time with holoviews"""
|
||||
|
||||
d = ds_preds.isel(t_ahead=it_ahead).groupby('t_source').first()
|
||||
print(d)
|
||||
|
||||
p = hv.Scatter({
|
||||
'x': d.t_source,
|
||||
'y': d.y_true
|
||||
}, label='true').opts(color='black')
|
||||
|
||||
# Get arrays
|
||||
xf = d.t_source.values
|
||||
yp = d.y_pred
|
||||
s = d.y_pred_std
|
||||
p *= hv.Curve({
|
||||
'x': xf,
|
||||
'y': yp
|
||||
}, label='pred').opts(color='blue')
|
||||
p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),
|
||||
vdims=['y', 'y2'],
|
||||
label='2*std').opts(alpha=0.5, line_width=0)
|
||||
|
||||
|
||||
return p.opts(title=f'Prediction at {it_ahead * pd.Timedelta(freq)} ahead. NLL={d.nll.mean().item():2.2f}')
|
||||
|
||||
|
||||
dmap_preds = (hv.DynamicMap(plot_predictions_vs_time, kdims=['it_ahead'])
|
||||
.redim.values(it_ahead=range(ds_preds.t_ahead.shape[0]))
|
||||
.opts(width=800,
|
||||
height=300,
|
||||
))
|
||||
dmap_preds
|
||||
# TODO fixme
|
||||
# -
|
||||
|
||||
d = ds_preds.mean('t_source')['nll'].groupby('t_ahead_hours').mean()
|
||||
d = ds_preds.mean(['t_source', 'block'])['nll'].groupby('t_ahead_hours').mean()
|
||||
nll_vs_tahead = hv.Curve((d.t_ahead_hours, d)).redim(x='hours ahead', y='nll').opts(width=800)
|
||||
nll_vs_tahead
|
||||
|
||||
# +
|
||||
# d = ds_preds.mean('t_ahead')['nll'].groupby('t_source').mean()
|
||||
# def plot_predictions_vs_time(it_ahead):
|
||||
# """Plot predictions vs time with holoviews"""
|
||||
|
||||
# d = ds_pred_block.isel(t_ahead=it_ahead).groupby('t_source').first()
|
||||
# # print(d)
|
||||
|
||||
# p = hv.Scatter({
|
||||
# 'x': d.t_source,
|
||||
# 'y': d.y_true
|
||||
# }, label='true').opts(color='black')
|
||||
|
||||
# # Get arrays
|
||||
# xf = d.t_source.values
|
||||
# yp = d.y_pred
|
||||
# s = d.y_pred_std
|
||||
# p *= hv.Curve({
|
||||
# 'x': xf,
|
||||
# 'y': yp
|
||||
# }, label='pred').opts(color='blue')
|
||||
# p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),
|
||||
# vdims=['y', 'y2'],
|
||||
# label='2*std').opts(alpha=0.5, line_width=0)
|
||||
|
||||
|
||||
# return p.opts(title=f'Prediction at {it_ahead * pd.Timedelta(freq)} ahead. NLL={d.nll.mean().item():2.2f}')
|
||||
|
||||
|
||||
# dmap_preds = (hv.DynamicMap(plot_predictions_vs_time, kdims=['it_ahead'])
|
||||
# .redim.values(it_ahead=range(ds_pred_block.t_ahead.shape[0]))
|
||||
# .opts(width=800,
|
||||
# height=300,
|
||||
# ))
|
||||
# dmap_preds
|
||||
# # TODO fixme
|
||||
# -
|
||||
|
||||
|
||||
|
||||
# +
|
||||
# d = ds_preds.mean(['t_ahead', 'block'])['nll'].groupby('t_source').mean()
|
||||
# nll_vs_time = hv.Curve(d).opts(width=800)
|
||||
# nll_vs_time
|
||||
|
||||
|
||||
@@ -36,4 +36,4 @@ class LSTMSeq2Seq(nn.Module):
|
||||
log_sigma = self.std(outputs)
|
||||
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
|
||||
y_dist = torch.distributions.Normal(mean, sigma)
|
||||
return y_dist
|
||||
return y_dist, {}
|
||||
|
||||
@@ -7,6 +7,7 @@ import math
|
||||
|
||||
|
||||
class LSTMBlock(nn.Module):
|
||||
"""Wrapper to return only lstm output."""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
@@ -437,14 +438,21 @@ class RANP(nn.Module):
|
||||
if self._use_rnn:
|
||||
# see https://arxiv.org/abs/1910.09323 where x is substituted with h = RNN(x)
|
||||
# x needs to be provided as [B, T, H]
|
||||
future_x, _ = self._lstm(future_x)
|
||||
past_x, _ = self._lstm(past_x)
|
||||
S = past_x.shape[1]
|
||||
x = torch.cat([past_x, future_x], 1)
|
||||
x, _ = self._lstm(x)
|
||||
past_x = x[:, :S]
|
||||
future_x = x[:, S:]
|
||||
# future_x, _ = self._lstm(future_x)
|
||||
# past_x, _ = self._lstm(past_x)
|
||||
|
||||
dist_prior, log_var_prior = self._latent_encoder(past_x, past_y)
|
||||
|
||||
if future_y is not None:
|
||||
if (future_y is not None):
|
||||
dist_post, log_var_post = self._latent_encoder(future_x, future_y)
|
||||
z = dist_post.loc
|
||||
|
||||
if self.training:
|
||||
z = dist_prior.rsample()
|
||||
else:
|
||||
z = dist_prior.loc
|
||||
|
||||
@@ -471,3 +479,4 @@ class RANP(nn.Module):
|
||||
].mean()
|
||||
loss = (kl_loss - log_p).mean()
|
||||
return dist, {'loss':loss}
|
||||
|
||||
@@ -35,9 +35,10 @@ def predict(model, ds_test, batch_size, device='cpu', scaler=None):
|
||||
|
||||
# Make an xarray.Dataset for the data
|
||||
bs = y_future.shape[0]
|
||||
t_source = ds_test.df.index[i:i+bs].values
|
||||
t_ahead = pd.timedelta_range(0, periods=ds_test.window_future, freq=freq).values
|
||||
t_behind = pd.timedelta_range(end=-pd.Timedelta(freq), periods=ds_test.window_past, freq=freq)
|
||||
wp = ds_test.window_past
|
||||
t_source = ds_test.df.index[wp + i*bs -1:wp+ i*bs+bs -1].values
|
||||
t_ahead = pd.timedelta_range(1, periods=ds_test.window_future, freq=freq).values
|
||||
t_behind = pd.timedelta_range(end=0, periods=ds_test.window_past, freq=freq)
|
||||
xr_out = xr.Dataset(
|
||||
{
|
||||
# Format> name: ([dimensions,...], array),
|
||||
@@ -77,5 +78,5 @@ def predict_multi(model, datasets, batch_size, device='cpu', scaler=None):
|
||||
d,
|
||||
batch_size,
|
||||
device=device,
|
||||
scaler=output_scaler) for d in tqdm(datasets)]
|
||||
scaler=scaler) for d in tqdm(datasets, desc='predict_multi')]
|
||||
return xr.concat(ds_preds, dim='block')
|
||||
|
||||
Reference in New Issue
Block a user