diff --git a/notebooks/05.0-mc-leaderboard.ipynb b/notebooks/05.0-mc-leaderboard.ipynb index b37cbf5..f28c616 100644 --- a/notebooks/05.0-mc-leaderboard.ipynb +++ b/notebooks/05.0-mc-leaderboard.ipynb @@ -38,7 +38,8 @@ "- [ ] show test train\n", "- [ ] val\n", "- [ ] don't overfit\n", - "- [ ] TCN" + "- [ ] TCN\n", + "- [ ] make overlap between past and future" ] }, { @@ -46,8 +47,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:04:56.455187Z", - "start_time": "2020-10-30T06:04:56.043007Z" + "end_time": "2020-10-30T06:47:04.839427Z", + "start_time": "2020-10-30T06:47:04.401726Z" } }, "outputs": [], @@ -69,8 +70,8 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:04:58.117548Z", - "start_time": "2020-10-30T06:04:56.459312Z" + "end_time": "2020-10-30T06:47:06.443003Z", + "start_time": "2020-10-30T06:47:04.843915Z" } }, "outputs": [], @@ -101,8 +102,8 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:04:58.159177Z", - "start_time": "2020-10-30T06:04:58.123924Z" + "end_time": "2020-10-30T06:47:06.477685Z", + "start_time": "2020-10-30T06:47:06.448612Z" } }, "outputs": [], @@ -116,8 +117,8 @@ "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:04:58.207947Z", - "start_time": "2020-10-30T06:04:58.162906Z" + "end_time": "2020-10-30T06:47:06.512676Z", + "start_time": "2020-10-30T06:47:06.480950Z" } }, "outputs": [ @@ -141,8 +142,8 @@ "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:04:58.247847Z", - "start_time": "2020-10-30T06:04:58.212316Z" + "end_time": "2020-10-30T06:47:06.561982Z", + "start_time": "2020-10-30T06:47:06.516605Z" } }, "outputs": [], @@ -153,23 +154,15 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 55, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:01.478145Z", - "start_time": "2020-10-30T06:04:58.251662Z" + "end_time": "2020-10-30T13:23:43.253896Z", + "start_time": "2020-10-30T13:23:42.642309Z" }, "lines_to_next_cell": 2 }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/holoviews/operation/datashader.py:5: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n", - " from collections import Callable\n" - ] - }, { "data": { "application/javascript": [ @@ -1644,9 +1637,14 @@ "from holoviews import opts\n", "from holoviews.operation.datashader import datashade, dynspread\n", "hv.extension('bokeh')\n", + "from seq2seq_time.visualization.hv_ggplot import ggplot_theme\n", + "hv.renderer('bokeh').theme = ggplot_theme\n", "\n", "# holoview datashader timeseries options\n", - "%opts RGB [width=800 height=200 active_tools=[\"xwheel_zoom\"] default_tools=[\"xpan\",\"xwheel_zoom\", \"reset\"] toolbar=\"right\"]" + "%opts RGB [width=800 height=200 show_grid=True active_tools=[\"xwheel_zoom\"] default_tools=[\"xpan\",\"xwheel_zoom\", \"reset\"] toolbar=\"right\"]\n", + "%opts Curve [width=800 height=200 show_grid=True active_tools=[\"xwheel_zoom\"] default_tools=[\"xpan\",\"xwheel_zoom\", \"reset\"] toolbar=\"right\"]\n", + "%opts Scatter [width=800 height=200 show_grid=True active_tools=[\"xwheel_zoom\"] default_tools=[\"xpan\",\"xwheel_zoom\", \"reset\"] toolbar=\"right\"]\n", + "%opts Layout [width=800 height=200]" ] }, { @@ -1654,8 +1652,8 @@ "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:01.534412Z", - "start_time": "2020-10-30T06:05:01.482602Z" + "end_time": "2020-10-30T06:47:09.979298Z", + "start_time": "2020-10-30T06:47:09.927575Z" } }, "outputs": [ @@ -1690,8 +1688,8 @@ "execution_count": 8, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:01.619419Z", - "start_time": "2020-10-30T06:05:01.539179Z" + "end_time": "2020-10-30T06:47:10.070962Z", + "start_time": "2020-10-30T06:47:09.983029Z" } }, "outputs": [ @@ -1737,105 +1735,118 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 191, "metadata": { "ExecuteTime": { - "end_time": "2020-10-29T13:29:14.460346Z", - "start_time": "2020-10-29T13:29:14.371348Z" + "end_time": "2020-10-30T14:08:39.113905Z", + "start_time": "2020-10-30T14:08:39.047075Z" } }, "outputs": [], - "source": [] + "source": [ + "def hv_plot_std(d: xr.Dataset):\n", + " xf = d.t_target\n", + " yp = d.y_pred\n", + " s = d.y_pred_std\n", + " return hv.Spread((xf, yp, s * 2),\n", + " label='2*std').opts(alpha=0.5, line_width=0)\n", + "\n", + "def hv_plot_pred(d: xr.Dataset):\n", + " # Get arrays\n", + " xf = d.t_target\n", + " yp = d.y_pred\n", + " s = d.y_pred_std\n", + " return hv.Curve({'x': xf, 'y': yp})\n", + "\n", + "def hv_plot_true(d: xr.Dataset):\n", + " \"\"\"Plot a prediction into the future, at a single point in time.\"\"\" \n", + " \n", + " # Plot true\n", + " x = np.concatenate([d.t_past, d.t_target])\n", + " yt = np.concatenate([d.y_past, d.y_true])\n", + " p = hv.Scatter({\n", + " 'x': x,\n", + " 'y': yt\n", + " }, label='true').opts(color='black')\n", + "\n", + "\n", + " \n", + " now=pd.Timestamp(d.t_source.squeeze().values)\n", + " \n", + " p = p.opts(\n", + " ylabel=ds_preds.attrs['targets'],\n", + " xlabel=f'{now}'\n", + " )\n", + "\n", + " \n", + " # plot a red line for now\n", + " p *= hv.VLine(now, label='now').opts(color='red', framewise=True)\n", + "\n", + " return p\n", + "\n", + "def hv_plot_prediction(d):\n", + " p = hv_plot_true(d)\n", + " p *= hv_plot_pred(d)\n", + " p *= hv_plot_std(d)\n", + " return p" + ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 220, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:01.703395Z", - "start_time": "2020-10-30T06:05:01.629810Z" + "end_time": "2020-10-30T14:13:40.986924Z", + "start_time": "2020-10-30T14:13:40.901669Z" }, "lines_to_end_of_cell_marker": 2, "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "def plot_prediction(ds_preds, i, ax=None, title='', std=False, label='pred', legend=False):\n", - " \"\"\"Plot a prediction into the future, at a single point in time.\"\"\" \n", - " d = ds_preds.isel(t_source=i)\n", - "\n", - " # Get arrays\n", - " xf = d.t_target\n", - " yp = d.y_pred\n", - " s = d.y_pred_std\n", - " yt = d.y_true\n", - " now = d.t_source.squeeze()\n", - " \n", - " plt.scatter(xf, yt, c='k', s=6, label='true' if legend else None)\n", - " ylim = plt.ylim()\n", - "\n", - " # plot prediction\n", - " if std:\n", - " plt.fill_between(xf, yp-2*s, yp+2*s, alpha=0.25,\n", - " facecolor=\"b\",\n", - " interpolate=True,\n", - " label=\"2 std\" if legend else None,)\n", - " plt.plot(xf, yp, label=label)\n", - "\n", - " # plot true\n", - " plt.scatter(\n", - " d.t_past,\n", - " d.y_past,\n", - " c='k',\n", - " s=6\n", - " )\n", - " \n", - " # plot a red line for now\n", - " plt.vlines(x=now, ymin=ylim[0], ymax=ylim[1], color='grey', ls='--')\n", - " plt.ylim(*ylim)\n", - "\n", - " now=pd.Timestamp(now.values)\n", - " plt.title(title or f'Prediction NLL={d.nll.mean().item():2.2g}')\n", - " plt.xticks(rotation=0) \n", - " if legend:\n", - " plt.legend()\n", - " plt.xlabel(f'{now}')\n", - " plt.ylabel(ds_preds.attrs['targets'])\n", - " return now\n", - "\n", "def plot_performance(ds_preds, full=False):\n", " \"\"\"Multiple plots using xr_preds\"\"\"\n", - " plot_prediction(ds_preds, 24, std=True, legend=True)\n", - " plt.show()\n", + " p = hv_plot_prediction(ds_preds.isel(t_source=10))\n", + " display(p)\n", "\n", - " ds_preds.mean('t_source').plot.scatter('t_ahead_hours', 'nll') # Mean over all predictions\n", " n = len(ds_preds.t_source)\n", - " plt.ylabel('NLL (lower is better)')\n", - " plt.xlabel('Hours ahead')\n", - " plt.title(f'NLL vs time ahead (no. samples={n})')\n", - " plt.show()\n", + " d_ahead = ds_preds.mean(['t_source'])['nll'].groupby('t_ahead_hours').mean()\n", + " nll_vs_tahead = (hv.Curve(\n", + " (d_ahead.t_ahead_hours,\n", + " d_ahead)).redim(x='hours ahead',\n", + " y='nll').opts(\n", + " title=f'NLL vs time ahead (no. samples={n})'))\n", + " display(nll_vs_tahead)\n", "\n", " # Make a plot of the NLL over time. Does this solution get worse with time?\n", " if full:\n", - " d = ds_preds.mean('t_ahead').groupby('t_source').mean().plot.scatter('t_source', 'nll')\n", - " plt.xticks(rotation=45)\n", - " plt.title('NLL over source time (lower is better)')\n", - " plt.show()\n", + " d_source = ds_preds.mean(['t_ahead'])['nll'].groupby('t_source').mean()\n", + " nll_vs_time = (hv.Curve(d_source).opts(\n", + " title='Error vs time of prediction'))\n", + " display(nll_vs_time)\n", "\n", " # A scatter plot is easy with xarray\n", " if full:\n", - " plt.figure(figsize=(5, 5))\n", - " ds_preds.plot.scatter('y_true', 'y_pred', s=.01)\n", - " plt.show()" + " tlim = (ds_preds.y_true.min().item(), ds_preds.y_true.max().item())\n", + " true_vs_pred = datashade(hv.Scatter(\n", + " (ds_preds.y_true,\n", + " ds_preds.y_pred))).redim(x='true', y='pred').opts(width=400,\n", + " height=400,\n", + " xlim=tlim,\n", + " ylim=tlim,\n", + " title='Scatter plot')\n", + " true_vs_pred = dynspread(true_vs_pred)\n", + " true_vs_pred\n", + " display(true_vs_pred)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 279, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:01.757797Z", - "start_time": "2020-10-30T06:05:01.707827Z" + "end_time": "2020-10-30T14:20:32.513001Z", + "start_time": "2020-10-30T14:20:32.427238Z" } }, "outputs": [], @@ -1844,15 +1855,193 @@ " try:\n", " df_hist = pd.read_csv(trainer.logger.experiment.metrics_file_path)\n", " df_hist['epoch'] = df_hist['epoch'].ffill()\n", + " \n", " df_histe = df_hist.set_index('epoch').groupby('epoch').mean()\n", " if len(df_histe)>1:\n", - " df_histe[['loss/train', 'loss/val']].plot(title='history')\n", - " plt.show()\n", + " p = hv.Curve(df_histe, kdims=['epoch'], vdims=['loss/train']).relabel('train')\n", + " p *= hv.Curve(df_histe, kdims=['epoch'], vdims=['loss/val']).relabel('val')\n", + " display(p.opts(ylabel='loss'))\n", " return df_histe\n", - " except Exception:\n", + " except Exception as e:\n", + " print(e)\n", " pass" ] }, + { + "cell_type": "code", + "execution_count": 280, + "metadata": { + "ExecuteTime": { + "end_time": "2020-10-30T14:20:32.934663Z", + "start_time": "2020-10-30T14:20:32.670013Z" + } + }, + "outputs": [ + { + "data": {}, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.holoviews_exec.v0+json": "", + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "
\n", + "" + ], + "text/plain": [ + ":Overlay\n", + " .Curve.Train :Curve [epoch] (loss/train)\n", + " .Curve.Val :Curve [epoch] (loss/val)" + ] + }, + "metadata": { + "application/vnd.holoviews_exec.v0+json": { + "id": "134882" + } + }, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
loss/trainmodel_loss/trainsteploss/val
epoch
0.01.3060881.348648191.8571431.703918
1.01.1663881.201709491.8571432.081715
2.01.1473061.177450791.8571432.530396
3.01.0543491.0838071091.8571432.719204
4.01.0546081.0824191391.8571432.729672
5.01.0130831.0405251691.8571433.024244
6.01.0031891.0330431991.8571433.000752
\n", + "
" + ], + "text/plain": [ + " loss/train model_loss/train step loss/val\n", + "epoch \n", + "0.0 1.306088 1.348648 191.857143 1.703918\n", + "1.0 1.166388 1.201709 491.857143 2.081715\n", + "2.0 1.147306 1.177450 791.857143 2.530396\n", + "3.0 1.054349 1.083807 1091.857143 2.719204\n", + "4.0 1.054608 1.082419 1391.857143 2.729672\n", + "5.0 1.013083 1.040525 1691.857143 3.024244\n", + "6.0 1.003189 1.033043 1991.857143 3.000752" + ] + }, + "execution_count": 280, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_hist = plot_hist(trainer)\n", + "df_hist" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1872,8 +2061,8 @@ "execution_count": 11, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:02.099220Z", - "start_time": "2020-10-30T06:05:01.763911Z" + "end_time": "2020-10-30T06:47:10.603240Z", + "start_time": "2020-10-30T06:47:10.241440Z" } }, "outputs": [ @@ -1901,11 +2090,11 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 50, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:09.451517Z", - "start_time": "2020-10-30T06:05:02.104070Z" + "end_time": "2020-10-30T13:09:41.533230Z", + "start_time": "2020-10-30T13:09:36.631759Z" } }, "outputs": [ @@ -1925,18 +2114,18 @@ "data": { "application/vnd.holoviews_exec.v0+json": "", "text/html": [ - "
\n", + "
\n", "\n", "\n", "\n", "\n", "\n", - "
\n", + "
\n", "
\n", "" + ], "text/plain": [ - "
" + ":Layout\n", + " .Overlay.I :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.BaselineLast :Curve [x] (y)\n", + " .Curve.RANP :Curve [x] (y)\n", + " .Curve.LSTM :Curve [x] (y)\n", + " .Curve.LSTMSeq2Seq :Curve [x] (y)\n", + " .Curve.TransformerSeq2Seq :Curve [x] (y)\n", + " .Curve.Transformer :Curve [x] (y)\n", + " .Curve.TransformerProcess :Curve [x] (y)\n", + " .Overlay.II :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.BaselineLast :Curve [x] (y)\n", + " .Curve.RANP :Curve [x] (y)\n", + " .Curve.LSTM :Curve [x] (y)\n", + " .Curve.LSTMSeq2Seq :Curve [x] (y)\n", + " .Curve.TransformerSeq2Seq :Curve [x] (y)\n", + " .Curve.Transformer :Curve [x] (y)\n", + " .Curve.TransformerProcess :Curve [x] (y)\n", + " .Overlay.III :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.BaselineLast :Curve [x] (y)\n", + " .Curve.RANP :Curve [x] (y)\n", + " .Curve.LSTM :Curve [x] (y)\n", + " .Curve.LSTMSeq2Seq :Curve [x] (y)\n", + " .Curve.TransformerSeq2Seq :Curve [x] (y)\n", + " .Curve.Transformer :Curve [x] (y)\n", + " .Curve.TransformerProcess :Curve [x] (y)\n", + " .Overlay.IV :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.BaselineLast :Curve [x] (y)\n", + " .Curve.RANP :Curve [x] (y)\n", + " .Curve.LSTM :Curve [x] (y)\n", + " .Curve.LSTMSeq2Seq :Curve [x] (y)\n", + " .Curve.TransformerSeq2Seq :Curve [x] (y)\n", + " .Curve.Transformer :Curve [x] (y)\n", + " .Curve.TransformerProcess :Curve [x] (y)\n", + " .Overlay.V :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.BaselineLast :Curve [x] (y)\n", + " .Curve.RANP :Curve [x] (y)\n", + " .Curve.LSTM :Curve [x] (y)\n", + " .Curve.LSTMSeq2Seq :Curve [x] (y)\n", + " .Curve.TransformerSeq2Seq :Curve [x] (y)\n", + " .Curve.Transformer :Curve [x] (y)\n", + " .Curve.TransformerProcess :Curve [x] (y)" ] }, - "metadata": {}, - "output_type": "display_data" + "execution_count": 174, + "metadata": { + "application/vnd.holoviews_exec.v0+json": { + "id": "99955" + } + }, + "output_type": "execute_result" } ], "source": [ "# Plot mean of predictions\n", + "n = hv.Layout()\n", "for dataset in results.keys():\n", + " d = next(iter(results[dataset].values())).isel(t_source=data_i)\n", + " p = hv_plot_true(d)\n", " for model in results[dataset].keys():\n", " ds_preds = results[dataset][model]\n", - " plot_prediction(ds_preds, data_i, label=f\"{model}\")\n", - " plt.title(dataset)\n", - " plt.legend()\n", - " plt.show()" + " d = ds_preds.isel(t_source=data_i)\n", + " p *= hv_plot_pred(d).relabel(label=f\"{model}\")\n", + " n += p.opts(title=dataset, legend_position='top_left')\n", + "n.cols(1).opts(shared_axes=False)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 190, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:25:32.323248Z", - "start_time": "2020-10-30T06:25:31.817Z" - }, - "lines_to_next_cell": 0, - "scrolled": false - }, - "outputs": [], - "source": [ - "dataset='BejingPM25'\n", - "n = len(results[dataset].keys())\n", - "\n", - "plt.figure(figsize=(8, 1.5*n))\n", - "plt.suptitle(f'Plots with confidence for {dataset} ')\n", - "for i, model in enumerate(results[dataset].keys()):\n", - " plt.subplot(n, 1, i+1)\n", - " ds_preds = results[dataset][model]\n", - " if i==n-1:\n", - " # The last one has the legend\n", - " plot_prediction(ds_preds, data_i, title=f\"{model}\", std=True, legend=True)\n", - " else:\n", - " plot_prediction(ds_preds, data_i, title=f\"{model}\", std=True, )\n", - " \n", - " # share the x axis\n", - " locs, _ = plt.xticks()\n", - " plt.xticks(locs, labels=[])\n", - " plt.xlabel(None)\n", - "plt.subplots_adjust()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2020-10-29T13:27:13.984227Z", - "start_time": "2020-10-29T13:27:13.901770Z" + "end_time": "2020-10-30T14:07:27.883279Z", + "start_time": "2020-10-30T14:07:24.819203Z" } }, - "outputs": [], - "source": [] + "outputs": [ + { + "data": {}, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.holoviews_exec.v0+json": "", + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "
\n", + "" + ], + "text/plain": [ + ":Layout\n", + " .Overlay.I :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.Pred :Curve [x] (y)\n", + " .Spread.A_2_times_std :Spread [x] (y,yerror)\n", + " .Overlay.II :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.Pred :Curve [x] (y)\n", + " .Spread.A_2_times_std :Spread [x] (y,yerror)\n", + " .Overlay.III :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.Pred :Curve [x] (y)\n", + " .Spread.A_2_times_std :Spread [x] (y,yerror)\n", + " .Overlay.IV :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.Pred :Curve [x] (y)\n", + " .Spread.A_2_times_std :Spread [x] (y,yerror)\n", + " .Overlay.V :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.Pred :Curve [x] (y)\n", + " .Spread.A_2_times_std :Spread [x] (y,yerror)\n", + " .Overlay.VI :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.Pred :Curve [x] (y)\n", + " .Spread.A_2_times_std :Spread [x] (y,yerror)\n", + " .Overlay.VII :Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.Pred :Curve [x] (y)\n", + " .Spread.A_2_times_std :Spread [x] (y,yerror)" + ] + }, + "execution_count": 190, + "metadata": { + "application/vnd.holoviews_exec.v0+json": { + "id": "119739" + } + }, + "output_type": "execute_result" + } + ], + "source": [ + "dataset='BejingPM25'\n", + "n = hv.Layout()\n", + "for i, model in enumerate(results[dataset].keys()):\n", + " ds_preds = results[dataset][model]\n", + " d = ds_preds.isel(t_source=data_i)\n", + " p = hv_plot_true(d)\n", + " p *= hv_plot_pred(d).relabel('pred')\n", + " p *= hv_plot_std(d)\n", + " n += p.opts(title=f'{dataset} {model}', legend_position='top_left')\n", + "n.cols(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 222, + "metadata": { + "ExecuteTime": { + "end_time": "2020-10-30T14:14:24.464233Z", + "start_time": "2020-10-30T14:14:22.441273Z" + } + }, + "outputs": [ + { + "data": {}, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.holoviews_exec.v0+json": "", + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "
\n", + "" + ], + "text/plain": [ + ":Overlay\n", + " .Scatter.True :Scatter [x] (y)\n", + " .VLine.Now :VLine [x,y]\n", + " .Curve.I :Curve [x] (y)\n", + " .Spread.A_2_times_std :Spread [x] (y,yerror)" + ] + }, + "metadata": { + "application/vnd.holoviews_exec.v0+json": { + "id": "129752" + } + }, + "output_type": "display_data" + }, + { + "data": {}, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.holoviews_exec.v0+json": "", + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "
\n", + "" + ], + "text/plain": [ + ":Curve [hours ahead] (nll)" + ] + }, + "metadata": { + "application/vnd.holoviews_exec.v0+json": { + "id": "130205" + } + }, + "output_type": "display_data" + }, + { + "data": {}, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.holoviews_exec.v0+json": "", + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "
\n", + "" + ], + "text/plain": [ + ":Curve [t_source] (nll)" + ] + }, + "metadata": { + "application/vnd.holoviews_exec.v0+json": { + "id": "130300" + } + }, + "output_type": "display_data" + }, + { + "data": {}, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.holoviews_exec.v0+json": "", + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "
\n", + "" + ], + "text/plain": [ + ":DynamicMap []\n", + " :RGB [true,pred] (R,G,B,A)" + ] + }, + "metadata": { + "application/vnd.holoviews_exec.v0+json": { + "id": "130479" + } + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(ds_preds, full=True)" + ] }, { "cell_type": "code", @@ -4584,7 +3961,7 @@ "height": "calc(100% - 180px)", "left": "10px", "top": "150px", - "width": "209.197px" + "width": "209.2px" }, "toc_section_display": true, "toc_window_display": true diff --git a/notebooks/05.0-mc-leaderboard.py b/notebooks/05.0-mc-leaderboard.py index b12dd9c..897755b 100644 --- a/notebooks/05.0-mc-leaderboard.py +++ b/notebooks/05.0-mc-leaderboard.py @@ -35,6 +35,7 @@ # - [ ] val # - [ ] don't overfit # - [ ] TCN +# - [ ] make overlap between past and future # OPTIONAL: Load the "autoreload" extension so that code can change. But blacklist large modules # %load_ext autoreload @@ -84,9 +85,14 @@ import holoviews as hv from holoviews import opts from holoviews.operation.datashader import datashade, dynspread hv.extension('bokeh') +from seq2seq_time.visualization.hv_ggplot import ggplot_theme +hv.renderer('bokeh').theme = ggplot_theme # holoview datashader timeseries options -# %opts RGB [width=800 height=200 active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset"] toolbar="right"] +# %opts RGB [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset"] toolbar="right"] +# %opts Curve [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset"] toolbar="right"] +# %opts Scatter [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset"] toolbar="right"] +# %opts Layout [width=800 height=200] # - @@ -108,95 +114,115 @@ freq = '30T' max_rows = 5e5 datasets_root = Path('../data/processed/') window_past + + # - # ## Plot helpers - - # + -def plot_prediction(ds_preds, i, ax=None, title='', std=False, label='pred', legend=False): - """Plot a prediction into the future, at a single point in time.""" - d = ds_preds.isel(t_source=i) +def hv_plot_std(d: xr.Dataset): + xf = d.t_target + yp = d.y_pred + s = d.y_pred_std + return hv.Spread((xf, yp, s * 2), + label='2*std').opts(alpha=0.5, line_width=0) +def hv_plot_pred(d: xr.Dataset): # Get arrays xf = d.t_target yp = d.y_pred s = d.y_pred_std - yt = d.y_true - now = d.t_source.squeeze() + return hv.Curve({'x': xf, 'y': yp}) + +def hv_plot_true(d: xr.Dataset): + """Plot a prediction into the future, at a single point in time.""" - plt.scatter(xf, yt, c='k', s=6, label='true' if legend else None) - ylim = plt.ylim() + # Plot true + x = np.concatenate([d.t_past, d.t_target]) + yt = np.concatenate([d.y_past, d.y_true]) + p = hv.Scatter({ + 'x': x, + 'y': yt + }, label='true').opts(color='black') - # plot prediction - if std: - plt.fill_between(xf, yp-2*s, yp+2*s, alpha=0.25, - facecolor="b", - interpolate=True, - label="2 std" if legend else None,) - plt.plot(xf, yp, label=label) - # plot true - plt.scatter( - d.t_past, - d.y_past, - c='k', - s=6 + + now=pd.Timestamp(d.t_source.squeeze().values) + + p = p.opts( + ylabel=ds_preds.attrs['targets'], + xlabel=f'{now}' ) + # plot a red line for now - plt.vlines(x=now, ymin=ylim[0], ymax=ylim[1], color='grey', ls='--') - plt.ylim(*ylim) + p *= hv.VLine(now, label='now').opts(color='red', framewise=True) - now=pd.Timestamp(now.values) - plt.title(title or f'Prediction NLL={d.nll.mean().item():2.2g}') - plt.xticks(rotation=0) - if legend: - plt.legend() - plt.xlabel(f'{now}') - plt.ylabel(ds_preds.attrs['targets']) - return now + return p -def plot_performance(ds_preds, full=False): - """Multiple plots using xr_preds""" - plot_prediction(ds_preds, 24, std=True, legend=True) - plt.show() - - ds_preds.mean('t_source').plot.scatter('t_ahead_hours', 'nll') # Mean over all predictions - n = len(ds_preds.t_source) - plt.ylabel('NLL (lower is better)') - plt.xlabel('Hours ahead') - plt.title(f'NLL vs time ahead (no. samples={n})') - plt.show() - - # Make a plot of the NLL over time. Does this solution get worse with time? - if full: - d = ds_preds.mean('t_ahead').groupby('t_source').mean().plot.scatter('t_source', 'nll') - plt.xticks(rotation=45) - plt.title('NLL over source time (lower is better)') - plt.show() - - # A scatter plot is easy with xarray - if full: - plt.figure(figsize=(5, 5)) - ds_preds.plot.scatter('y_true', 'y_pred', s=.01) - plt.show() +def hv_plot_prediction(d): + p = hv_plot_true(d) + p *= hv_plot_pred(d) + p *= hv_plot_std(d) + return p # - + +def plot_performance(ds_preds, full=False): + """Multiple plots using xr_preds""" + p = hv_plot_prediction(ds_preds.isel(t_source=10)) + display(p) + + n = len(ds_preds.t_source) + d_ahead = ds_preds.mean(['t_source'])['nll'].groupby('t_ahead_hours').mean() + nll_vs_tahead = (hv.Curve( + (d_ahead.t_ahead_hours, + d_ahead)).redim(x='hours ahead', + y='nll').opts( + title=f'NLL vs time ahead (no. samples={n})')) + display(nll_vs_tahead) + + # Make a plot of the NLL over time. Does this solution get worse with time? + if full: + d_source = ds_preds.mean(['t_ahead'])['nll'].groupby('t_source').mean() + nll_vs_time = (hv.Curve(d_source).opts( + title='Error vs time of prediction')) + display(nll_vs_time) + + # A scatter plot is easy with xarray + if full: + tlim = (ds_preds.y_true.min().item(), ds_preds.y_true.max().item()) + true_vs_pred = datashade(hv.Scatter( + (ds_preds.y_true, + ds_preds.y_pred))).redim(x='true', y='pred').opts(width=400, + height=400, + xlim=tlim, + ylim=tlim, + title='Scatter plot') + true_vs_pred = dynspread(true_vs_pred) + true_vs_pred + display(true_vs_pred) def plot_hist(trainer): try: df_hist = pd.read_csv(trainer.logger.experiment.metrics_file_path) df_hist['epoch'] = df_hist['epoch'].ffill() + df_histe = df_hist.set_index('epoch').groupby('epoch').mean() if len(df_histe)>1: - df_histe[['loss/train', 'loss/val']].plot(title='history') - plt.show() + p = hv.Curve(df_histe, kdims=['epoch'], vdims=['loss/train']).relabel('train') + p *= hv.Curve(df_histe, kdims=['epoch'], vdims=['loss/val']).relabel('val') + display(p.opts(ylabel='loss')) return df_histe - except Exception: + except Exception as e: + print(e) pass + +df_hist = plot_hist(trainer) +df_hist + # ## Datasets @@ -351,18 +377,7 @@ models = [ from collections import defaultdict results = defaultdict(dict) -# + -# tmp -model = Transformer(input_size, - output_size, - attention_dropout=0.4, - nhead=2, - nlayers=4, - hidden_size=16) -x_past, y_past, x_future, y_future = next(iter(dl_val)) -model(x_past, y_past, x_future, y_future) -# - from seq2seq_time.metrics import rmse, smape @@ -398,8 +413,8 @@ for Dataset in datasets: # Wrap in lightning patience = 3 model = PL_MODEL(pt_model, - lr=3e-3, patience=patience, - weight_decay=1e-5).to(device) + lr=3e-4, patience=patience, + weight_decay=4e-5).to(device) # Trainer trainer = pl.Trainer( @@ -479,7 +494,15 @@ d.style.apply(bold_min) print(f'Symmetric mean absolute percentage error (SMAPE)\nover {window_future} steps') d=df_results.xs('smape', level=1).T.round(2) d.style.apply(bold_min) + # # Plots + + + + + +# + + # # plots # Load saved preds results = defaultdict(dict) @@ -495,42 +518,34 @@ for Dataset in datasets: if len(fs)>0: ds_preds = xr.open_dataset(fs[-1]) results[dataset_name][model_name] = ds_preds +# - data_i = 100 - - # Plot mean of predictions +n = hv.Layout() for dataset in results.keys(): + d = next(iter(results[dataset].values())).isel(t_source=data_i) + p = hv_plot_true(d) for model in results[dataset].keys(): ds_preds = results[dataset][model] - plot_prediction(ds_preds, data_i, label=f"{model}") - plt.title(dataset) - plt.legend() - plt.show() + d = ds_preds.isel(t_source=data_i) + p *= hv_plot_pred(d).relabel(label=f"{model}") + n += p.opts(title=dataset, legend_position='top_left') +n.cols(1).opts(shared_axes=False) -# + dataset='BejingPM25' -n = len(results[dataset].keys()) - -plt.figure(figsize=(8, 1.5*n)) -plt.suptitle(f'Plots with confidence for {dataset} ') +n = hv.Layout() for i, model in enumerate(results[dataset].keys()): - plt.subplot(n, 1, i+1) ds_preds = results[dataset][model] - if i==n-1: - # The last one has the legend - plot_prediction(ds_preds, data_i, title=f"{model}", std=True, legend=True) - else: - plot_prediction(ds_preds, data_i, title=f"{model}", std=True, ) - - # share the x axis - locs, _ = plt.xticks() - plt.xticks(locs, labels=[]) - plt.xlabel(None) -plt.subplots_adjust() -# - - + d = ds_preds.isel(t_source=data_i) + p = hv_plot_true(d) + p *= hv_plot_pred(d).relabel('pred') + p *= hv_plot_std(d) + n += p.opts(title=f'{dataset} {model}', legend_position='top_left') +n.cols(1) + +plot_performance(ds_preds, full=True) diff --git a/seq2seq_time/visualization/hv_ggplot.py b/seq2seq_time/visualization/hv_ggplot.py new file mode 100644 index 0000000..6f8fe95 --- /dev/null +++ b/seq2seq_time/visualization/hv_ggplot.py @@ -0,0 +1,81 @@ +#----------------------------------------------------------------------------- +# Copyright (c) 2012 - 2020, Anaconda, Inc., and Bokeh Contributors. +# All rights reserved. +# +# The full license is in the file LICENSE.txt, distributed with this software. +#----------------------------------------------------------------------------- +# see https://raw.githubusercontent.com/bokeh/bokeh/ffdd1114e6aace02bb6c61748390e9b62522a8d9/bokeh/themes/_ggplot.py +# see https://github.com/bokeh/bokeh/pull/10150 +from bokeh.themes import Theme +json = { + "attrs": { + "Figure" : { + "background_fill_color": "#E5E5E5", + "border_fill_color": "#FFFFFF", + "outline_line_color": "#000000", + "outline_line_alpha": 0.25 + }, + + "Grid": { + "grid_line_color": "#FFFFFF", + "grid_line_alpha": 1, + }, + + "Axis": { + "major_tick_line_alpha": 0.3, + "major_tick_line_color": "#000000", + + "minor_tick_line_alpha": 0.4, + "minor_tick_line_color": "#000000", + + "axis_line_alpha": 1, + "axis_line_color": "#000000", + + "major_label_text_color": "#000000", + "major_label_text_font": "Helvetica", + "major_label_text_font_size": "1.025em", + + "axis_label_standoff": 10, + "axis_label_text_color": "#000000", + "axis_label_text_font": "Helvetica", + "axis_label_text_font_size": "1.25em", + "axis_label_text_font_style": "normal" + }, + + "Legend": { + "spacing": 8, + "glyph_width": 15, + + "label_standoff": 8, + "label_text_color": "#000000", + "label_text_font": "Arial", + "label_text_font_size": "0.95em", + + "border_line_alpha": 1, + "background_fill_alpha": 0.25, + "background_fill_color": "#000000" + }, + + "ColorBar": { + "title_text_color": "#E0E0E0", + "title_text_font": "Helvetica", + "title_text_font_size": "1.025em", + "title_text_font_style": "normal", + + "major_label_text_color": "#E0E0E0", + "major_label_text_font": "Arial", + "major_label_text_font_size": "1.025em", + + "background_fill_color": "#000000", + "major_tick_line_alpha": 0, + "bar_line_alpha": 0 + }, + + "Title": { + "text_color": "#000000", + "text_font": "Helvetica", + "text_font_size": "1.10em" + } + } +} +ggplot_theme = Theme(json=json)