This commit is contained in:
wassname
2020-10-20 06:49:15 +08:00
parent 7b6c729db5
commit f9851e123b
5 changed files with 1131 additions and 1475 deletions
File diff suppressed because one or more lines are too long
+64 -71
View File
@@ -369,9 +369,7 @@ def plot_prediction(ds_preds, i):
def plot_performance(ds_preds, full=False):
"""Multiple plots using xr_preds"""
print(f'mean_NLL {ds_preds.nll.mean().item():2.2f}')
plot_prediction(ds_preds, 24)
# plot_prediction(ds_preds, 480)
ds_preds.mean('t_source').plot.scatter('t_ahead_hours', 'nll') # Mean over all predictions
n = len(ds_preds.t_source)
@@ -481,7 +479,7 @@ from seq2seq_time.models.baseline import BaselineLast
from seq2seq_time.models.transformer import Transformer
from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq
from seq2seq_time.models.transformer_seq import TransformerSeq
from seq2seq_time.models.anp import RANP
from seq2seq_time.models.neural_process import RANP
# ## Plots
# +
models = [
@@ -529,7 +527,11 @@ trainer = pl.Trainer(gpus=1,
)
trainer.fit(model, dl_train, dl_test)
print(plot_hist(trainer))
ds_preds = predict(model.to(device), ds_test.datasets[0], batch_size, device=device, scaler=output_scaler)
ds_predss = predict_multi(model.to(device),
ds_test.datasets,
batch_size*8,
device=device,
scaler=output_scaler)
print(f'baseline nll: {ds_preds.nll.mean().item():2.2g}')
for pt_model in models:
@@ -560,46 +562,35 @@ for pt_model in models:
ds_preds = predict(model.to(device),
ds_test.datasets[0],
batch_size,
ds_predss = predict_multi(model.to(device),
ds_test.datasets,
batch_size*8,
device=device,
scaler=output_scaler)
print(name)
print(f'mean_NLL {ds_preds.nll.mean().item():2.2f}')
print(f'mean_NLL {ds_predss.nll.mean().item():2.2f}')
# Performance
ds_preds = ds_predss.isel(block=0)
print(plot_hist(trainer))
plot_performance(ds_preds)
# %debug
ds_preds = predict(model.to(device),q
ds_test.datasets[0],
batch_size,
device=device,
scaler=output_scaler)
# +
# ds_predss = predict_multi(model.to(device),
# ds_test.datasets,
# batch_size,
# ds_preds = predict(model.to(device),
# ds_test.datasets[0],
# batch_size*8,
# device=device,
# scaler=output_scaler)
# -
ds_test.datasets[0].df.index.value_counts()
ds_predss = predict_multi(model.to(device),
ds_test.datasets,
batch_size*8,
device=device,
scaler=output_scaler)
# TODO why dup?
ds_preds.sel(t_source='2013-11-11 00:30:00')
# TODO why duplicates?
d = ds_preds.isel(t_ahead=0)
d.t_source.to_series().sort_index()#.value_counts()
# np.unique
# d
ds_pred_block = ds_predss.isel(block=1)
# # holoviews pred
@@ -613,9 +604,9 @@ def plot_prediction_now(t_source):
# Let us pass in an int
if isinstance(t_source, int):
t_source = ds_preds.t_source[t_source].to_pandas()
t_source = ds_pred_block.t_source[t_source].to_pandas()
d = ds_preds.sel(t_source=t_source)
d = ds_pred_block.sel(t_source=t_source)
# Sometimes there are duplicate times, take the first
if len(d.t_source.shape) and d.t_source.shape[0] > 0:
@@ -651,56 +642,58 @@ def plot_prediction_now(t_source):
dmap_pred = (hv.DynamicMap(plot_prediction_now, kdims=['t_source'])
.redim.values(t_source=ds_preds.t_source.to_pandas())
.redim.values(t_source=ds_pred_block.t_source.to_pandas())
.opts(width=800,
height=300,
))
dmap_pred
# +
def plot_predictions_vs_time(it_ahead):
"""Plot predictions vs time with holoviews"""
d = ds_preds.isel(t_ahead=it_ahead).groupby('t_source').first()
print(d)
p = hv.Scatter({
'x': d.t_source,
'y': d.y_true
}, label='true').opts(color='black')
# Get arrays
xf = d.t_source.values
yp = d.y_pred
s = d.y_pred_std
p *= hv.Curve({
'x': xf,
'y': yp
}, label='pred').opts(color='blue')
p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),
vdims=['y', 'y2'],
label='2*std').opts(alpha=0.5, line_width=0)
return p.opts(title=f'Prediction at {it_ahead * pd.Timedelta(freq)} ahead. NLL={d.nll.mean().item():2.2f}')
dmap_preds = (hv.DynamicMap(plot_predictions_vs_time, kdims=['it_ahead'])
.redim.values(it_ahead=range(ds_preds.t_ahead.shape[0]))
.opts(width=800,
height=300,
))
dmap_preds
# TODO fixme
# -
d = ds_preds.mean('t_source')['nll'].groupby('t_ahead_hours').mean()
d = ds_preds.mean(['t_source', 'block'])['nll'].groupby('t_ahead_hours').mean()
nll_vs_tahead = hv.Curve((d.t_ahead_hours, d)).redim(x='hours ahead', y='nll').opts(width=800)
nll_vs_tahead
# +
# d = ds_preds.mean('t_ahead')['nll'].groupby('t_source').mean()
# def plot_predictions_vs_time(it_ahead):
# """Plot predictions vs time with holoviews"""
# d = ds_pred_block.isel(t_ahead=it_ahead).groupby('t_source').first()
# # print(d)
# p = hv.Scatter({
# 'x': d.t_source,
# 'y': d.y_true
# }, label='true').opts(color='black')
# # Get arrays
# xf = d.t_source.values
# yp = d.y_pred
# s = d.y_pred_std
# p *= hv.Curve({
# 'x': xf,
# 'y': yp
# }, label='pred').opts(color='blue')
# p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),
# vdims=['y', 'y2'],
# label='2*std').opts(alpha=0.5, line_width=0)
# return p.opts(title=f'Prediction at {it_ahead * pd.Timedelta(freq)} ahead. NLL={d.nll.mean().item():2.2f}')
# dmap_preds = (hv.DynamicMap(plot_predictions_vs_time, kdims=['it_ahead'])
# .redim.values(it_ahead=range(ds_pred_block.t_ahead.shape[0]))
# .opts(width=800,
# height=300,
# ))
# dmap_preds
# # TODO fixme
# -
# +
# d = ds_preds.mean(['t_ahead', 'block'])['nll'].groupby('t_source').mean()
# nll_vs_time = hv.Curve(d).opts(width=800)
# nll_vs_time
+1 -1
View File
@@ -36,4 +36,4 @@ class LSTMSeq2Seq(nn.Module):
log_sigma = self.std(outputs)
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
y_dist = torch.distributions.Normal(mean, sigma)
return y_dist
return y_dist, {}
@@ -7,6 +7,7 @@ import math
class LSTMBlock(nn.Module):
"""Wrapper to return only lstm output."""
def __init__(
self,
in_channels,
@@ -437,14 +438,21 @@ class RANP(nn.Module):
if self._use_rnn:
# see https://arxiv.org/abs/1910.09323 where x is substituted with h = RNN(x)
# x need to be provided as [B, T, H]
future_x, _ = self._lstm(future_x)
past_x, _ = self._lstm(past_x)
S = past_x.shape[1]
x = torch.cat([past_x, future_x], 1)
x, _ = self._lstm(x)
past_x = x[:, :S]
future_x = x[:, S:]
# future_x, _ = self._lstm(future_x)
# past_x, _ = self._lstm(past_x)
dist_prior, log_var_prior = self._latent_encoder(past_x, past_y)
if future_y is not None:
if (future_y is not None):
dist_post, log_var_post = self._latent_encoder(future_x, future_y)
z = dist_post.loc
if self.training:
z = dist_prior.rsample()
else:
z = dist_prior.loc
@@ -471,3 +479,4 @@ class RANP(nn.Module):
].mean()
loss = (kl_loss - log_p).mean()
return dist, {'loss':loss}
+5 -4
View File
@@ -35,9 +35,10 @@ def predict(model, ds_test, batch_size, device='cpu', scaler=None):
# Make an xarray.Dataset for the data
bs = y_future.shape[0]
t_source = ds_test.df.index[i:i+bs].values
t_ahead = pd.timedelta_range(0, periods=ds_test.window_future, freq=freq).values
t_behind = pd.timedelta_range(end=-pd.Timedelta(freq), periods=ds_test.window_past, freq=freq)
wp = ds_test.window_past
t_source = ds_test.df.index[wp + i*bs -1:wp+ i*bs+bs -1].values
t_ahead = pd.timedelta_range(1, periods=ds_test.window_future, freq=freq).values
t_behind = pd.timedelta_range(end=0, periods=ds_test.window_past, freq=freq)
xr_out = xr.Dataset(
{
# Format> name: ([dimensions,...], array),
@@ -77,5 +78,5 @@ def predict_multi(model, datasets, batch_size, device='cpu', scaler=None):
d,
batch_size,
device=device,
scaler=output_scaler) for d in tqdm(datasets)]
scaler=scaler) for d in tqdm(datasets, desc='predict_multi')]
return xr.concat(ds_preds, dim='block')