diff --git a/notebooks/05.0-mc-leaderboard.ipynb b/notebooks/05.0-mc-leaderboard.ipynb index b37cbf5..f28c616 100644 --- a/notebooks/05.0-mc-leaderboard.ipynb +++ b/notebooks/05.0-mc-leaderboard.ipynb @@ -38,7 +38,8 @@ "- [ ] show test train\n", "- [ ] val\n", "- [ ] don't overfit\n", - "- [ ] TCN" + "- [ ] TCN\n", + "- [ ] make overlap between past and future" ] }, { @@ -46,8 +47,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:04:56.455187Z", - "start_time": "2020-10-30T06:04:56.043007Z" + "end_time": "2020-10-30T06:47:04.839427Z", + "start_time": "2020-10-30T06:47:04.401726Z" } }, "outputs": [], @@ -69,8 +70,8 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:04:58.117548Z", - "start_time": "2020-10-30T06:04:56.459312Z" + "end_time": "2020-10-30T06:47:06.443003Z", + "start_time": "2020-10-30T06:47:04.843915Z" } }, "outputs": [], @@ -101,8 +102,8 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:04:58.159177Z", - "start_time": "2020-10-30T06:04:58.123924Z" + "end_time": "2020-10-30T06:47:06.477685Z", + "start_time": "2020-10-30T06:47:06.448612Z" } }, "outputs": [], @@ -116,8 +117,8 @@ "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:04:58.207947Z", - "start_time": "2020-10-30T06:04:58.162906Z" + "end_time": "2020-10-30T06:47:06.512676Z", + "start_time": "2020-10-30T06:47:06.480950Z" } }, "outputs": [ @@ -141,8 +142,8 @@ "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:04:58.247847Z", - "start_time": "2020-10-30T06:04:58.212316Z" + "end_time": "2020-10-30T06:47:06.561982Z", + "start_time": "2020-10-30T06:47:06.516605Z" } }, "outputs": [], @@ -153,23 +154,15 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 55, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:01.478145Z", - "start_time": "2020-10-30T06:04:58.251662Z" + "end_time": "2020-10-30T13:23:43.253896Z", + "start_time": "2020-10-30T13:23:42.642309Z" }, "lines_to_next_cell": 2 }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/holoviews/operation/datashader.py:5: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n", - " from collections import Callable\n" - ] - }, { "data": { "application/javascript": [ @@ -1644,9 +1637,14 @@ "from holoviews import opts\n", "from holoviews.operation.datashader import datashade, dynspread\n", "hv.extension('bokeh')\n", + "from seq2seq_time.visualization.hv_ggplot import ggplot_theme\n", + "hv.renderer('bokeh').theme = ggplot_theme\n", "\n", "# holoview datashader timeseries options\n", - "%opts RGB [width=800 height=200 active_tools=[\"xwheel_zoom\"] default_tools=[\"xpan\",\"xwheel_zoom\", \"reset\"] toolbar=\"right\"]" + "%opts RGB [width=800 height=200 show_grid=True active_tools=[\"xwheel_zoom\"] default_tools=[\"xpan\",\"xwheel_zoom\", \"reset\"] toolbar=\"right\"]\n", + "%opts Curve [width=800 height=200 show_grid=True active_tools=[\"xwheel_zoom\"] default_tools=[\"xpan\",\"xwheel_zoom\", \"reset\"] toolbar=\"right\"]\n", + "%opts Scatter [width=800 height=200 show_grid=True active_tools=[\"xwheel_zoom\"] default_tools=[\"xpan\",\"xwheel_zoom\", \"reset\"] toolbar=\"right\"]\n", + "%opts Layout [width=800 height=200]" ] }, { @@ -1654,8 +1652,8 @@ "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:01.534412Z", - "start_time": "2020-10-30T06:05:01.482602Z" + "end_time": "2020-10-30T06:47:09.979298Z", + "start_time": "2020-10-30T06:47:09.927575Z" } }, "outputs": [ @@ -1690,8 +1688,8 @@ "execution_count": 8, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:01.619419Z", - "start_time": "2020-10-30T06:05:01.539179Z" + "end_time": "2020-10-30T06:47:10.070962Z", + "start_time": "2020-10-30T06:47:09.983029Z" } }, "outputs": [ @@ -1737,105 +1735,118 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 191, "metadata": { "ExecuteTime": { - "end_time": "2020-10-29T13:29:14.460346Z", - "start_time": "2020-10-29T13:29:14.371348Z" + "end_time": "2020-10-30T14:08:39.113905Z", + "start_time": "2020-10-30T14:08:39.047075Z" } }, "outputs": [], - "source": [] + "source": [ + "def hv_plot_std(d: xr.Dataset):\n", + " xf = d.t_target\n", + " yp = d.y_pred\n", + " s = d.y_pred_std\n", + " return hv.Spread((xf, yp, s * 2),\n", + " label='2*std').opts(alpha=0.5, line_width=0)\n", + "\n", + "def hv_plot_pred(d: xr.Dataset):\n", + " # Get arrays\n", + " xf = d.t_target\n", + " yp = d.y_pred\n", + " s = d.y_pred_std\n", + " return hv.Curve({'x': xf, 'y': yp})\n", + "\n", + "def hv_plot_true(d: xr.Dataset):\n", + " \"\"\"Plot a prediction into the future, at a single point in time.\"\"\" \n", + " \n", + " # Plot true\n", + " x = np.concatenate([d.t_past, d.t_target])\n", + " yt = np.concatenate([d.y_past, d.y_true])\n", + " p = hv.Scatter({\n", + " 'x': x,\n", + " 'y': yt\n", + " }, label='true').opts(color='black')\n", + "\n", + "\n", + " \n", + " now=pd.Timestamp(d.t_source.squeeze().values)\n", + " \n", + " p = p.opts(\n", + " ylabel=ds_preds.attrs['targets'],\n", + " xlabel=f'{now}'\n", + " )\n", + "\n", + " \n", + " # plot a red line for now\n", + " p *= hv.VLine(now, label='now').opts(color='red', framewise=True)\n", + "\n", + " return p\n", + "\n", + "def hv_plot_prediction(d):\n", + " p = hv_plot_true(d)\n", + " p *= hv_plot_pred(d)\n", + " p *= hv_plot_std(d)\n", + " return p" + ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 220, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:01.703395Z", - "start_time": "2020-10-30T06:05:01.629810Z" + "end_time": "2020-10-30T14:13:40.986924Z", + "start_time": "2020-10-30T14:13:40.901669Z" }, "lines_to_end_of_cell_marker": 2, "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "def plot_prediction(ds_preds, i, ax=None, title='', std=False, label='pred', legend=False):\n", - " \"\"\"Plot a prediction into the future, at a single point in time.\"\"\" \n", - " d = ds_preds.isel(t_source=i)\n", - "\n", - " # Get arrays\n", - " xf = d.t_target\n", - " yp = d.y_pred\n", - " s = d.y_pred_std\n", - " yt = d.y_true\n", - " now = d.t_source.squeeze()\n", - " \n", - " plt.scatter(xf, yt, c='k', s=6, label='true' if legend else None)\n", - " ylim = plt.ylim()\n", - "\n", - " # plot prediction\n", - " if std:\n", - " plt.fill_between(xf, yp-2*s, yp+2*s, alpha=0.25,\n", - " facecolor=\"b\",\n", - " interpolate=True,\n", - " label=\"2 std\" if legend else None,)\n", - " plt.plot(xf, yp, label=label)\n", - "\n", - " # plot true\n", - " plt.scatter(\n", - " d.t_past,\n", - " d.y_past,\n", - " c='k',\n", - " s=6\n", - " )\n", - " \n", - " # plot a red line for now\n", - " plt.vlines(x=now, ymin=ylim[0], ymax=ylim[1], color='grey', ls='--')\n", - " plt.ylim(*ylim)\n", - "\n", - " now=pd.Timestamp(now.values)\n", - " plt.title(title or f'Prediction NLL={d.nll.mean().item():2.2g}')\n", - " plt.xticks(rotation=0) \n", - " if legend:\n", - " plt.legend()\n", - " plt.xlabel(f'{now}')\n", - " plt.ylabel(ds_preds.attrs['targets'])\n", - " return now\n", - "\n", "def plot_performance(ds_preds, full=False):\n", " \"\"\"Multiple plots using xr_preds\"\"\"\n", - " plot_prediction(ds_preds, 24, std=True, legend=True)\n", - " plt.show()\n", + " p = hv_plot_prediction(ds_preds.isel(t_source=10))\n", + " display(p)\n", "\n", - " ds_preds.mean('t_source').plot.scatter('t_ahead_hours', 'nll') # Mean over all predictions\n", " n = len(ds_preds.t_source)\n", - " plt.ylabel('NLL (lower is better)')\n", - " plt.xlabel('Hours ahead')\n", - " plt.title(f'NLL vs time ahead (no. samples={n})')\n", - " plt.show()\n", + " d_ahead = ds_preds.mean(['t_source'])['nll'].groupby('t_ahead_hours').mean()\n", + " nll_vs_tahead = (hv.Curve(\n", + " (d_ahead.t_ahead_hours,\n", + " d_ahead)).redim(x='hours ahead',\n", + " y='nll').opts(\n", + " title=f'NLL vs time ahead (no. samples={n})'))\n", + " display(nll_vs_tahead)\n", "\n", " # Make a plot of the NLL over time. Does this solution get worse with time?\n", " if full:\n", - " d = ds_preds.mean('t_ahead').groupby('t_source').mean().plot.scatter('t_source', 'nll')\n", - " plt.xticks(rotation=45)\n", - " plt.title('NLL over source time (lower is better)')\n", - " plt.show()\n", + " d_source = ds_preds.mean(['t_ahead'])['nll'].groupby('t_source').mean()\n", + " nll_vs_time = (hv.Curve(d_source).opts(\n", + " title='Error vs time of prediction'))\n", + " display(nll_vs_time)\n", "\n", " # A scatter plot is easy with xarray\n", " if full:\n", - " plt.figure(figsize=(5, 5))\n", - " ds_preds.plot.scatter('y_true', 'y_pred', s=.01)\n", - " plt.show()" + " tlim = (ds_preds.y_true.min().item(), ds_preds.y_true.max().item())\n", + " true_vs_pred = datashade(hv.Scatter(\n", + " (ds_preds.y_true,\n", + " ds_preds.y_pred))).redim(x='true', y='pred').opts(width=400,\n", + " height=400,\n", + " xlim=tlim,\n", + " ylim=tlim,\n", + " title='Scatter plot')\n", + " true_vs_pred = dynspread(true_vs_pred)\n", + " true_vs_pred\n", + " display(true_vs_pred)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 279, "metadata": { "ExecuteTime": { - "end_time": "2020-10-30T06:05:01.757797Z", - "start_time": "2020-10-30T06:05:01.707827Z" + "end_time": "2020-10-30T14:20:32.513001Z", + "start_time": "2020-10-30T14:20:32.427238Z" } }, "outputs": [], @@ -1844,15 +1855,193 @@ " try:\n", " df_hist = pd.read_csv(trainer.logger.experiment.metrics_file_path)\n", " df_hist['epoch'] = df_hist['epoch'].ffill()\n", + " \n", " df_histe = df_hist.set_index('epoch').groupby('epoch').mean()\n", " if len(df_histe)>1:\n", - " df_histe[['loss/train', 'loss/val']].plot(title='history')\n", - " plt.show()\n", + " p = hv.Curve(df_histe, kdims=['epoch'], vdims=['loss/train']).relabel('train')\n", + " p *= hv.Curve(df_histe, kdims=['epoch'], vdims=['loss/val']).relabel('val')\n", + " display(p.opts(ylabel='loss'))\n", " return df_histe\n", - " except Exception:\n", + " except Exception as e:\n", + " print(e)\n", " pass" ] }, + { + "cell_type": "code", + "execution_count": 280, + "metadata": { + "ExecuteTime": { + "end_time": "2020-10-30T14:20:32.934663Z", + "start_time": "2020-10-30T14:20:32.670013Z" + } + }, + "outputs": [ + { + "data": {}, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.holoviews_exec.v0+json": "", + "text/html": [ + "
| \n", + " | loss/train | \n", + "model_loss/train | \n", + "step | \n", + "loss/val | \n", + "
|---|---|---|---|---|
| epoch | \n", + "\n", + " | \n", + " | \n", + " | \n", + " |
| 0.0 | \n", + "1.306088 | \n", + "1.348648 | \n", + "191.857143 | \n", + "1.703918 | \n", + "
| 1.0 | \n", + "1.166388 | \n", + "1.201709 | \n", + "491.857143 | \n", + "2.081715 | \n", + "
| 2.0 | \n", + "1.147306 | \n", + "1.177450 | \n", + "791.857143 | \n", + "2.530396 | \n", + "
| 3.0 | \n", + "1.054349 | \n", + "1.083807 | \n", + "1091.857143 | \n", + "2.719204 | \n", + "
| 4.0 | \n", + "1.054608 | \n", + "1.082419 | \n", + "1391.857143 | \n", + "2.729672 | \n", + "
| 5.0 | \n", + "1.013083 | \n", + "1.040525 | \n", + "1691.857143 | \n", + "3.024244 | \n", + "
| 6.0 | \n", + "1.003189 | \n", + "1.033043 | \n", + "1991.857143 | \n", + "3.000752 | \n", + "