mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 17:50:09 +08:00
all plots hv
This commit is contained in:
+907
-1530
File diff suppressed because one or more lines are too long
@@ -35,6 +35,7 @@
|
||||
# - [ ] val
|
||||
# - [ ] don't overfit
|
||||
# - [ ] TCN
|
||||
# - [ ] make overlap between past and future
|
||||
|
||||
# OPTIONAL: Load the "autoreload" extension so that code can change. But blacklist large modules
|
||||
# %load_ext autoreload
|
||||
@@ -84,9 +85,14 @@ import holoviews as hv
|
||||
from holoviews import opts
|
||||
from holoviews.operation.datashader import datashade, dynspread
|
||||
hv.extension('bokeh')
|
||||
from seq2seq_time.visualization.hv_ggplot import ggplot_theme
|
||||
hv.renderer('bokeh').theme = ggplot_theme
|
||||
|
||||
# holoview datashader timeseries options
|
||||
# %opts RGB [width=800 height=200 active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset"] toolbar="right"]
|
||||
# %opts RGB [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset"] toolbar="right"]
|
||||
# %opts Curve [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset"] toolbar="right"]
|
||||
# %opts Scatter [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset"] toolbar="right"]
|
||||
# %opts Layout [width=800 height=200]
|
||||
# -
|
||||
|
||||
|
||||
@@ -108,95 +114,115 @@ freq = '30T'
|
||||
max_rows = 5e5
|
||||
datasets_root = Path('../data/processed/')
|
||||
window_past
|
||||
|
||||
|
||||
# -
|
||||
|
||||
# ## Plot helpers
|
||||
|
||||
|
||||
|
||||
# +
|
||||
def plot_prediction(ds_preds, i, ax=None, title='', std=False, label='pred', legend=False):
|
||||
"""Plot a prediction into the future, at a single point in time."""
|
||||
d = ds_preds.isel(t_source=i)
|
||||
def hv_plot_std(d: xr.Dataset):
    """Return a holoviews ``Spread`` band of twice the predicted std.

    Reads ``d.t_target`` (x), ``d.y_pred`` (band centre) and
    ``d.y_pred_std`` (doubled to give the band half-width).
    The element is labelled ``'2*std'`` and drawn half-transparent
    with no outline.
    """
    band = (d.t_target, d.y_pred, d.y_pred_std * 2)
    spread = hv.Spread(band, label='2*std')
    return spread.opts(alpha=0.5, line_width=0)
|
||||
|
||||
def hv_plot_pred(d: xr.Dataset):
    """Return a holoviews ``Curve`` of the predicted values.

    Only ``d.t_target`` (x) and ``d.y_pred`` (y) are read. The previous
    version also bound ``d.y_pred_std``, ``d.y_true`` and ``d.t_source`` to
    locals that were never used, which made the function fail on inputs
    lacking those variables for no benefit.
    """
    return hv.Curve({'x': d.t_target, 'y': d.y_pred})
|
||||
|
||||
def hv_plot_true(d: xr.Dataset):
    """Return a holoviews overlay of the observed values around one forecast.

    The past (``d.t_past``, ``d.y_past``) and future (``d.t_target``,
    ``d.y_true``) observations are concatenated into a single black scatter
    labelled ``'true'``. A red vertical line labelled ``'now'`` is overlaid
    at ``d.t_source``, and that timestamp is also used as the x-axis label.

    The y-axis label is taken from ``d.attrs['targets']``. It used to be read
    from a module-level ``ds_preds``, so the label depended on whichever
    dataset happened to be bound globally rather than on the data plotted.
    NOTE(review): this relies on ``d`` carrying the attrs of the dataset it
    was selected from (xarray indexing such as ``isel`` keeps them) — confirm
    for any other way ``d`` is produced.
    """
    # Observed values: past and future joined into one series.
    x = np.concatenate([d.t_past, d.t_target])
    yt = np.concatenate([d.y_past, d.y_true])
    p = hv.Scatter({
        'x': x,
        'y': yt
    }, label='true').opts(color='black')

    # Time at which the forecast was issued.
    now = pd.Timestamp(d.t_source.squeeze().values)

    p = p.opts(
        ylabel=d.attrs['targets'],
        xlabel=f'{now}'
    )

    # Mark the forecast time with a red vertical line.
    p *= hv.VLine(now, label='now').opts(color='red', framewise=True)
    return p
|
||||
|
||||
def plot_performance(ds_preds, full=False):
|
||||
"""Multiple plots using xr_preds"""
|
||||
plot_prediction(ds_preds, 24, std=True, legend=True)
|
||||
plt.show()
|
||||
|
||||
ds_preds.mean('t_source').plot.scatter('t_ahead_hours', 'nll') # Mean over all predictions
|
||||
n = len(ds_preds.t_source)
|
||||
plt.ylabel('NLL (lower is better)')
|
||||
plt.xlabel('Hours ahead')
|
||||
plt.title(f'NLL vs time ahead (no. samples={n})')
|
||||
plt.show()
|
||||
|
||||
# Make a plot of the NLL over time. Does this solution get worse with time?
|
||||
if full:
|
||||
d = ds_preds.mean('t_ahead').groupby('t_source').mean().plot.scatter('t_source', 'nll')
|
||||
plt.xticks(rotation=45)
|
||||
plt.title('NLL over source time (lower is better)')
|
||||
plt.show()
|
||||
|
||||
# A scatter plot is easy with xarray
|
||||
if full:
|
||||
plt.figure(figsize=(5, 5))
|
||||
ds_preds.plot.scatter('y_true', 'y_pred', s=.01)
|
||||
plt.show()
|
||||
def hv_plot_prediction(d):
    """Overlay the true values, the prediction curve and the 2*std band for ``d``."""
    layers = (hv_plot_true(d), hv_plot_pred(d), hv_plot_std(d))
    overlay = layers[0]
    for layer in layers[1:]:
        overlay = overlay * layer
    return overlay
|
||||
|
||||
|
||||
# -
|
||||
|
||||
def plot_performance(ds_preds, full=False, i=10):
    """Display several holoviews diagnostics for a predictions dataset.

    Always shown:
      * one example forecast (true values, prediction, 2*std band);
      * mean NLL as a function of hours ahead.

    Shown only when ``full`` is true:
      * mean NLL as a function of the source (prediction) time;
      * a datashaded scatter of true vs predicted values.

    Parameters
    ----------
    ds_preds :
        Dataset read here via ``t_source``, ``t_ahead``, ``t_ahead_hours``,
        ``nll``, ``y_true`` and ``y_pred``.
    full : bool
        Also render the two heavier plots.
    i : int
        Position along ``t_source`` of the example forecast. Defaults to the
        previously hard-coded 10 and is clamped to the last available
        position so short prediction sets no longer raise on indexing.
    """
    n = len(ds_preds.t_source)

    # Example forecast; clamp so fewer than i+1 predictions still plot.
    p = hv_plot_prediction(ds_preds.isel(t_source=min(i, n - 1)))
    display(p)

    # Mean NLL per forecast horizon.
    d_ahead = ds_preds.mean(['t_source'])['nll'].groupby('t_ahead_hours').mean()
    nll_vs_tahead = hv.Curve((d_ahead.t_ahead_hours, d_ahead)).redim(
        x='hours ahead', y='nll').opts(
            title=f'NLL vs time ahead (no. samples={n})')
    display(nll_vs_tahead)

    if not full:
        return

    # Make a plot of the NLL over time. Does this solution get worse with time?
    d_source = ds_preds.mean(['t_ahead'])['nll'].groupby('t_source').mean()
    nll_vs_time = hv.Curve(d_source).opts(title='Error vs time of prediction')
    display(nll_vs_time)

    # True vs predicted scatter, datashaded; both axes share the range of y_true.
    tlim = (ds_preds.y_true.min().item(), ds_preds.y_true.max().item())
    true_vs_pred = datashade(
        hv.Scatter((ds_preds.y_true, ds_preds.y_pred))
    ).redim(x='true', y='pred').opts(
        width=400,
        height=400,
        xlim=tlim,
        ylim=tlim,
        title='Scatter plot')
    true_vs_pred = dynspread(true_vs_pred)
    display(true_vs_pred)
|
||||
def plot_hist(trainer):
|
||||
try:
|
||||
df_hist = pd.read_csv(trainer.logger.experiment.metrics_file_path)
|
||||
df_hist['epoch'] = df_hist['epoch'].ffill()
|
||||
|
||||
df_histe = df_hist.set_index('epoch').groupby('epoch').mean()
|
||||
if len(df_histe)>1:
|
||||
df_histe[['loss/train', 'loss/val']].plot(title='history')
|
||||
plt.show()
|
||||
p = hv.Curve(df_histe, kdims=['epoch'], vdims=['loss/train']).relabel('train')
|
||||
p *= hv.Curve(df_histe, kdims=['epoch'], vdims=['loss/val']).relabel('val')
|
||||
display(p.opts(ylabel='loss'))
|
||||
return df_histe
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
|
||||
df_hist = plot_hist(trainer)
|
||||
df_hist
|
||||
|
||||
# ## Datasets
|
||||
|
||||
|
||||
@@ -351,18 +377,7 @@ models = [
|
||||
from collections import defaultdict
|
||||
results = defaultdict(dict)
|
||||
|
||||
# +
|
||||
# tmp
|
||||
model = Transformer(input_size,
|
||||
output_size,
|
||||
attention_dropout=0.4,
|
||||
nhead=2,
|
||||
nlayers=4,
|
||||
hidden_size=16)
|
||||
|
||||
x_past, y_past, x_future, y_future = next(iter(dl_val))
|
||||
model(x_past, y_past, x_future, y_future)
|
||||
# -
|
||||
|
||||
from seq2seq_time.metrics import rmse, smape
|
||||
|
||||
@@ -398,8 +413,8 @@ for Dataset in datasets:
|
||||
# Wrap in lightning
|
||||
patience = 3
|
||||
model = PL_MODEL(pt_model,
|
||||
lr=3e-3, patience=patience,
|
||||
weight_decay=1e-5).to(device)
|
||||
lr=3e-4, patience=patience,
|
||||
weight_decay=4e-5).to(device)
|
||||
|
||||
# Trainer
|
||||
trainer = pl.Trainer(
|
||||
@@ -479,7 +494,15 @@ d.style.apply(bold_min)
|
||||
print(f'Symmetric mean absolute percentage error (SMAPE)\nover {window_future} steps')
|
||||
d=df_results.xs('smape', level=1).T.round(2)
|
||||
d.style.apply(bold_min)
|
||||
|
||||
# # Plots
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# +
|
||||
|
||||
# # plots
|
||||
# Load saved preds
|
||||
results = defaultdict(dict)
|
||||
@@ -495,42 +518,34 @@ for Dataset in datasets:
|
||||
if len(fs)>0:
|
||||
ds_preds = xr.open_dataset(fs[-1])
|
||||
results[dataset_name][model_name] = ds_preds
|
||||
# -
|
||||
|
||||
data_i = 100
|
||||
|
||||
|
||||
|
||||
# Plot mean of predictions
|
||||
n = hv.Layout()
|
||||
for dataset in results.keys():
|
||||
d = next(iter(results[dataset].values())).isel(t_source=data_i)
|
||||
p = hv_plot_true(d)
|
||||
for model in results[dataset].keys():
|
||||
ds_preds = results[dataset][model]
|
||||
plot_prediction(ds_preds, data_i, label=f"{model}")
|
||||
plt.title(dataset)
|
||||
plt.legend()
|
||||
plt.show()
|
||||
d = ds_preds.isel(t_source=data_i)
|
||||
p *= hv_plot_pred(d).relabel(label=f"{model}")
|
||||
n += p.opts(title=dataset, legend_position='top_left')
|
||||
n.cols(1).opts(shared_axes=False)
|
||||
|
||||
# +
|
||||
dataset='BejingPM25'
|
||||
n = len(results[dataset].keys())
|
||||
|
||||
plt.figure(figsize=(8, 1.5*n))
|
||||
plt.suptitle(f'Plots with confidence for {dataset} ')
|
||||
n = hv.Layout()
|
||||
for i, model in enumerate(results[dataset].keys()):
|
||||
plt.subplot(n, 1, i+1)
|
||||
ds_preds = results[dataset][model]
|
||||
if i==n-1:
|
||||
# The last one has the legend
|
||||
plot_prediction(ds_preds, data_i, title=f"{model}", std=True, legend=True)
|
||||
else:
|
||||
plot_prediction(ds_preds, data_i, title=f"{model}", std=True, )
|
||||
|
||||
# share the x axis
|
||||
locs, _ = plt.xticks()
|
||||
plt.xticks(locs, labels=[])
|
||||
plt.xlabel(None)
|
||||
plt.subplots_adjust()
|
||||
# -
|
||||
|
||||
d = ds_preds.isel(t_source=data_i)
|
||||
p = hv_plot_true(d)
|
||||
p *= hv_plot_pred(d).relabel('pred')
|
||||
p *= hv_plot_std(d)
|
||||
n += p.opts(title=f'{dataset} {model}', legend_position='top_left')
|
||||
n.cols(1)
|
||||
|
||||
plot_performance(ds_preds, full=True)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
#-----------------------------------------------------------------------------
|
||||
# Copyright (c) 2012 - 2020, Anaconda, Inc., and Bokeh Contributors.
|
||||
# All rights reserved.
|
||||
#
|
||||
# The full license is in the file LICENSE.txt, distributed with this software.
|
||||
#-----------------------------------------------------------------------------
|
||||
# see https://raw.githubusercontent.com/bokeh/bokeh/ffdd1114e6aace02bb6c61748390e9b62522a8d9/bokeh/themes/_ggplot.py
|
||||
# see https://github.com/bokeh/bokeh/pull/10150
|
||||
from bokeh.themes import Theme
|
||||
# ggplot-like Bokeh theme, assembled one model section at a time.
# Values follow the upstream Bokeh `_ggplot.py` referenced in the header.

_figure_attrs = dict(
    background_fill_color="#E5E5E5",
    border_fill_color="#FFFFFF",
    outline_line_color="#000000",
    outline_line_alpha=0.25,
)

_grid_attrs = dict(
    grid_line_color="#FFFFFF",
    grid_line_alpha=1,
)

_axis_attrs = dict(
    # ticks
    major_tick_line_alpha=0.3,
    major_tick_line_color="#000000",
    minor_tick_line_alpha=0.4,
    minor_tick_line_color="#000000",
    # axis line
    axis_line_alpha=1,
    axis_line_color="#000000",
    # tick labels
    major_label_text_color="#000000",
    major_label_text_font="Helvetica",
    major_label_text_font_size="1.025em",
    # axis label
    axis_label_standoff=10,
    axis_label_text_color="#000000",
    axis_label_text_font="Helvetica",
    axis_label_text_font_size="1.25em",
    axis_label_text_font_style="normal",
)

_legend_attrs = dict(
    spacing=8,
    glyph_width=15,
    label_standoff=8,
    label_text_color="#000000",
    label_text_font="Arial",
    label_text_font_size="0.95em",
    border_line_alpha=1,
    background_fill_alpha=0.25,
    background_fill_color="#000000",
)

_colorbar_attrs = dict(
    title_text_color="#E0E0E0",
    title_text_font="Helvetica",
    title_text_font_size="1.025em",
    title_text_font_style="normal",
    major_label_text_color="#E0E0E0",
    major_label_text_font="Arial",
    major_label_text_font_size="1.025em",
    background_fill_color="#000000",
    major_tick_line_alpha=0,
    bar_line_alpha=0,
)

_title_attrs = dict(
    text_color="#000000",
    text_font="Helvetica",
    text_font_size="1.10em",
)

# Theme description: one entry per Bokeh model name under "attrs".
json = {
    "attrs": {
        "Figure": _figure_attrs,
        "Grid": _grid_attrs,
        "Axis": _axis_attrs,
        "Legend": _legend_attrs,
        "ColorBar": _colorbar_attrs,
        "Title": _title_attrs,
    }
}

ggplot_theme = Theme(json=json)
|
||||
Reference in New Issue
Block a user