diff --git a/notebooks/02.0-mike-RNN_Timeseries_Seq2Seq.ipynb b/notebooks/02.0-mike-RNN_Timeseries_Seq2Seq.ipynb
index 929f798..b03d877 100644
--- a/notebooks/02.0-mike-RNN_Timeseries_Seq2Seq.ipynb
+++ b/notebooks/02.0-mike-RNN_Timeseries_Seq2Seq.ipynb
@@ -29,8 +29,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:44:37.619727Z",
- "start_time": "2020-10-19T09:44:37.200497Z"
+ "end_time": "2020-10-19T13:24:29.474642Z",
+ "start_time": "2020-10-19T13:24:29.025860Z"
}
},
"outputs": [],
@@ -52,8 +52,8 @@
"execution_count": 2,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:44:38.709728Z",
- "start_time": "2020-10-19T09:44:37.625401Z"
+ "end_time": "2020-10-19T13:24:30.583721Z",
+ "start_time": "2020-10-19T13:24:29.478450Z"
}
},
"outputs": [],
@@ -83,8 +83,8 @@
"execution_count": 3,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:44:38.741179Z",
- "start_time": "2020-10-19T09:44:38.714268Z"
+ "end_time": "2020-10-19T13:24:30.616617Z",
+ "start_time": "2020-10-19T13:24:30.588643Z"
}
},
"outputs": [],
@@ -95,11 +95,11 @@
},
{
"cell_type": "code",
- "execution_count": 78,
+ "execution_count": 4,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T12:34:06.047830Z",
- "start_time": "2020-10-19T12:34:05.980011Z"
+ "end_time": "2020-10-19T13:24:31.183770Z",
+ "start_time": "2020-10-19T13:24:30.622322Z"
}
},
"outputs": [
@@ -122,8 +122,8 @@
"execution_count": 5,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:44:39.357695Z",
- "start_time": "2020-10-19T09:44:39.306658Z"
+ "end_time": "2020-10-19T13:24:31.218369Z",
+ "start_time": "2020-10-19T13:24:31.187185Z"
}
},
"outputs": [
@@ -158,8 +158,8 @@
"execution_count": 6,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:44:39.442584Z",
- "start_time": "2020-10-19T09:44:39.362549Z"
+ "end_time": "2020-10-19T13:24:31.286536Z",
+ "start_time": "2020-10-19T13:24:31.222958Z"
}
},
"outputs": [
@@ -196,8 +196,8 @@
"execution_count": 7,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:44:39.507875Z",
- "start_time": "2020-10-19T09:44:39.446523Z"
+ "end_time": "2020-10-19T13:24:31.338216Z",
+ "start_time": "2020-10-19T13:24:31.290344Z"
},
"lines_to_next_cell": 0
},
@@ -290,8 +290,8 @@
"execution_count": 8,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:44:57.542614Z",
- "start_time": "2020-10-19T09:44:39.513851Z"
+ "end_time": "2020-10-19T13:24:49.503564Z",
+ "start_time": "2020-10-19T13:24:31.344006Z"
},
"lines_to_next_cell": 0,
"scrolled": true
@@ -319,18 +319,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "block_104 26161\n",
- "block_1 26161\n",
- "block_102 26161\n",
"block_107 26161\n",
- "block_105 26161\n",
- "block_0 26161\n",
- "block_106 26161\n",
"block_100 26161\n",
"block_10 26161\n",
- "block_108 26161\n",
+ "block_1 26161\n",
+ "block_105 26161\n",
+ "block_0 26161\n",
+ "block_102 26161\n",
"block_103 26161\n",
+ "block_108 26161\n",
+ "block_106 26161\n",
"block_101 26161\n",
+ "block_104 26161\n",
"Name: block, dtype: int64\n"
]
},
@@ -749,8 +749,8 @@
"execution_count": 9,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:45:00.737432Z",
- "start_time": "2020-10-19T09:44:57.552213Z"
+ "end_time": "2020-10-19T13:24:52.741132Z",
+ "start_time": "2020-10-19T13:24:49.518745Z"
},
"lines_to_next_cell": 2
},
@@ -2290,8 +2290,8 @@
"execution_count": 10,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:45:00.790696Z",
- "start_time": "2020-10-19T09:45:00.743705Z"
+ "end_time": "2020-10-19T13:24:52.807091Z",
+ "start_time": "2020-10-19T13:24:52.747179Z"
}
},
"outputs": [
@@ -2322,8 +2322,8 @@
"execution_count": 11,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:45:01.141061Z",
- "start_time": "2020-10-19T09:45:00.796566Z"
+ "end_time": "2020-10-19T13:24:53.146412Z",
+ "start_time": "2020-10-19T13:24:52.811284Z"
}
},
"outputs": [
@@ -2579,8 +2579,8 @@
"execution_count": 12,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:45:01.989650Z",
- "start_time": "2020-10-19T09:45:01.146532Z"
+ "end_time": "2020-10-19T13:24:53.972592Z",
+ "start_time": "2020-10-19T13:24:53.151057Z"
}
},
"outputs": [
@@ -2965,8 +2965,8 @@
"execution_count": 13,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:45:02.066974Z",
- "start_time": "2020-10-19T09:45:01.994653Z"
+ "end_time": "2020-10-19T13:24:54.040879Z",
+ "start_time": "2020-10-19T13:24:53.976041Z"
}
},
"outputs": [
@@ -3006,8 +3006,8 @@
"execution_count": 14,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:45:02.328455Z",
- "start_time": "2020-10-19T09:45:02.071659Z"
+ "end_time": "2020-10-19T13:24:54.291172Z",
+ "start_time": "2020-10-19T13:24:54.045260Z"
}
},
"outputs": [
@@ -3037,8 +3037,8 @@
"execution_count": 15,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:45:02.388481Z",
- "start_time": "2020-10-19T09:45:02.332008Z"
+ "end_time": "2020-10-19T13:24:54.344965Z",
+ "start_time": "2020-10-19T13:24:54.295630Z"
}
},
"outputs": [],
@@ -3055,8 +3055,8 @@
"execution_count": 16,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T09:45:05.419601Z",
- "start_time": "2020-10-19T09:45:02.392108Z"
+ "end_time": "2020-10-19T13:24:57.313123Z",
+ "start_time": "2020-10-19T13:24:54.348463Z"
}
},
"outputs": [
@@ -3083,12 +3083,12 @@
"\n",
"\n",
"\n",
- "
\n",
+ " \n",
"\n",
""
- ],
- "text/plain": [
- ":DynamicMap [t_source]\n",
- " :Overlay\n",
- " .Scatter.True :Scatter [x] (y)\n",
- " .Curve.Pred :Curve [x] (y)\n",
- " .Area.A_2_times_std :Area [x] (y,y2)\n",
- " .VLine.Now :VLine [x,y]"
- ]
- },
- "execution_count": 32,
- "metadata": {
- "application/vnd.holoviews_exec.v0+json": {
- "id": "1443"
- }
- },
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"def plot_prediction_now(t_source):\n",
" \"\"\"Plot predictions with holoviews\"\"\"\n",
"\n",
" # Let us pass in an int\n",
" if isinstance(t_source, int):\n",
- " t_source = ds_preds.t_source[t_source].to_pandas()\n",
+ " t_source = ds_pred_block.t_source[t_source].to_pandas()\n",
"\n",
- " d = ds_preds.sel(t_source=t_source)\n",
+ " d = ds_pred_block.sel(t_source=t_source)\n",
"\n",
" # Sometimes there are duplicate times, take the first\n",
" if len(d.t_source.shape) and d.t_source.shape[0] > 0:\n",
@@ -5500,7 +5410,7 @@
"\n",
"\n",
"dmap_pred = (hv.DynamicMap(plot_prediction_now, kdims=['t_source'])\n",
- " .redim.values(t_source=ds_preds.t_source.to_pandas())\n",
+ " .redim.values(t_source=ds_pred_block.t_source.to_pandas())\n",
" .opts(width=800,\n",
" height=300, \n",
" ))\n",
@@ -5509,350 +5419,102 @@
},
{
"cell_type": "code",
- "execution_count": 45,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T10:57:43.721383Z",
- "start_time": "2020-10-19T10:57:42.548727Z"
+ "end_time": "2020-10-19T13:53:20.561774Z",
+ "start_time": "2020-10-19T13:24:29.300Z"
}
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
- " and should_run_async(code)\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n",
- "Dimensions: (t_behind: 96, t_source: 274)\n",
- "Coordinates:\n",
- " t_ahead timedelta64[ns] 00:00:00\n",
- " * t_behind (t_behind) timedelta64[ns] -2 days +00:00:00 ... -1 days +...\n",
- " t_ahead_hours float64 0.0\n",
- " * t_source (t_source) datetime64[ns] 2013-11-11 ... 2013-11-16T16:30:00\n",
- "Data variables:\n",
- " y_past (t_source, t_behind) float32 16.394 13.901 ... 15.281 16.59\n",
- " nll (t_source) float32 -0.42425382 1.825292 ... 0.45063046\n",
- " y_pred (t_source) float32 16.014568 24.125784 ... 21.401958\n",
- " y_pred_std (t_source) float64 1.658 2.507 3.558 ... 1.876 2.285 2.487\n",
- " y_true (t_source) float32 16.267 19.307001 23.038 ... 16.59 18.963\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/bokeh/core/property/bases.py:241: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.\n",
- " return new == old\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "Dimensions: (t_behind: 96, t_source: 274)\n",
- "Coordinates:\n",
- " t_ahead timedelta64[ns] 04:30:00\n",
- " * t_behind (t_behind) timedelta64[ns] -2 days +00:00:00 ... -1 days +...\n",
- " t_ahead_hours float64 4.0\n",
- " * t_source (t_source) datetime64[ns] 2013-11-11 ... 2013-11-16T16:30:00\n",
- "Data variables:\n",
- " y_past (t_source, t_behind) float32 16.394 13.901 ... 15.281 16.59\n",
- " nll (t_source) float32 10.724 0.0037053227 ... 0.3161078\n",
- " y_pred (t_source) float32 10.76828 21.676603 ... 22.7903 22.420483\n",
- " y_pred_std (t_source) float64 0.6646 2.448 3.175 ... 2.839 2.836 2.874\n",
- " y_true (t_source) float32 14.034 22.449 ... 19.858002 20.595001\n",
- "\n",
- "Dimensions: (t_behind: 96, t_source: 274)\n",
- "Coordinates:\n",
- " t_ahead timedelta64[ns] 06:30:00\n",
- " * t_behind (t_behind) timedelta64[ns] -2 days +00:00:00 ... -1 days +...\n",
- " t_ahead_hours float64 6.0\n",
- " * t_source (t_source) datetime64[ns] 2013-11-11 ... 2013-11-16T16:30:00\n",
- "Data variables:\n",
- " y_past (t_source, t_behind) float32 16.394 13.901 ... 15.281 16.59\n",
- " nll (t_source) float32 0.7417145 0.36029536 ... 0.22455204\n",
- " y_pred (t_source) float32 15.744131 21.442015 ... 21.529743\n",
- " y_pred_std (t_source) float64 1.607 2.483 2.928 ... 2.9 2.847 2.859\n",
- " y_true (t_source) float32 18.243 19.243 27.181002 ... 21.012 22.904\n",
- "\n",
- "Dimensions: (t_behind: 96, t_source: 274)\n",
- "Coordinates:\n",
- " t_ahead timedelta64[ns] 09:00:00\n",
- " * t_behind (t_behind) timedelta64[ns] -2 days +00:00:00 ... -1 days +...\n",
- " t_ahead_hours float64 9.0\n",
- " * t_source (t_source) datetime64[ns] 2013-11-11 ... 2013-11-16T16:30:00\n",
- "Data variables:\n",
- " y_past (t_source, t_behind) float32 16.394 13.901 ... 15.281 16.59\n",
- " nll (t_source) float32 -0.07550305 0.59446204 ... 0.7492668\n",
- " y_pred (t_source) float32 23.871077 28.067057 ... 22.720411\n",
- " y_pred_std (t_source) float64 2.377 3.221 1.669 ... 2.932 3.001 3.046\n",
- " y_true (t_source) float32 23.867 25.310999 18.483 ... 16.969 19.449\n",
- "\n",
- "Dimensions: (t_behind: 96, t_source: 274)\n",
- "Coordinates:\n",
- " t_ahead timedelta64[ns] 10:30:00\n",
- " * t_behind (t_behind) timedelta64[ns] -2 days +00:00:00 ... -1 days +...\n",
- " t_ahead_hours float64 10.0\n",
- " * t_source (t_source) datetime64[ns] 2013-11-11 ... 2013-11-16T16:30:00\n",
- "Data variables:\n",
- " y_past (t_source, t_behind) float32 16.394 13.901 ... 15.281 16.59\n",
- " nll (t_source) float32 -0.08419824 2.6844807 ... 0.8413604\n",
- " y_pred (t_source) float32 22.059172 38.408897 ... 26.288809 28.46659\n",
- " y_pred_std (t_source) float64 2.337 3.194 1.007 ... 3.247 3.347 3.389\n",
- " y_true (t_source) float32 21.762001 31.318 ... 29.953001 32.059998\n",
- "\n",
- "Dimensions: (t_behind: 96, t_source: 274)\n",
- "Coordinates:\n",
- " t_ahead timedelta64[ns] 16:30:00\n",
- " * t_behind (t_behind) timedelta64[ns] -2 days +00:00:00 ... -1 days +...\n",
- " t_ahead_hours float64 16.0\n",
- " * t_source (t_source) datetime64[ns] 2013-11-11 ... 2013-11-16T16:30:00\n",
- "Data variables:\n",
- " y_past (t_source, t_behind) float32 16.394 13.901 ... 15.281 16.59\n",
- " nll (t_source) float32 0.1073612 0.22418153 ... 0.93859696\n",
- " y_pred (t_source) float32 24.415854 13.307469 ... 16.607029\n",
- " y_pred_std (t_source) float64 2.836 1.394 2.472 ... 2.501 2.243 1.79\n",
- " y_true (t_source) float32 24.737999 15.107 ... 23.396 19.491001\n",
- "\n",
- "Dimensions: (t_behind: 96, t_source: 274)\n",
- "Coordinates:\n",
- " t_ahead timedelta64[ns] 17:00:00\n",
- " * t_behind (t_behind) timedelta64[ns] -2 days +00:00:00 ... -1 days +...\n",
- " t_ahead_hours float64 17.0\n",
- " * t_source (t_source) datetime64[ns] 2013-11-11 ... 2013-11-16T16:30:00\n",
- "Data variables:\n",
- " y_past (t_source, t_behind) float32 16.394 13.901 ... 15.281 16.59\n",
- " nll (t_source) float32 0.22898376 -0.8886603 ... -0.15025973\n",
- " y_pred (t_source) float32 27.6007 11.929974 ... 16.62101 15.153604\n",
- " y_pred_std (t_source) float64 3.214 1.054 2.7 3.47 ... 2.263 1.795 1.606\n",
- " y_true (t_source) float32 27.36 11.9279995 ... 19.491001 16.433\n",
- "\n",
- "Dimensions: (t_behind: 96, t_source: 274)\n",
- "Coordinates:\n",
- " t_ahead timedelta64[ns] 07:30:00\n",
- " * t_behind (t_behind) timedelta64[ns] -2 days +00:00:00 ... -1 days +...\n",
- " t_ahead_hours float64 7.0\n",
- " * t_source (t_source) datetime64[ns] 2013-11-11 ... 2013-11-16T16:30:00\n",
- "Data variables:\n",
- " y_past (t_source, t_behind) float32 16.394 13.901 ... 15.281 16.59\n",
- " nll (t_source) float32 0.047929287 1.9480975 ... 0.35706586\n",
- " y_pred (t_source) float32 23.83615 22.037819 ... 21.589767 21.60583\n",
- " y_pred_std (t_source) float64 2.305 2.585 2.597 ... 2.898 2.837 2.903\n",
- " y_true (t_source) float32 22.556 16.946 26.053 ... 23.376001 19.626"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.holoviews_exec.v0+json": "",
- "text/html": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- ""
- ],
- "text/plain": [
- ":DynamicMap [it_ahead]\n",
- " :Overlay\n",
- " .Scatter.True :Scatter [x] (y)\n",
- " .Curve.Pred :Curve [x] (y)\n",
- " .Area.A_2_times_std :Area [x] (y,y2)"
- ]
- },
- "execution_count": 45,
- "metadata": {
- "application/vnd.holoviews_exec.v0+json": {
- "id": "7323"
- }
- },
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "def plot_predictions_vs_time(it_ahead):\n",
- " \"\"\"Plot predictions vs time with holoviews\"\"\"\n",
- "\n",
- " d = ds_preds.isel(t_ahead=it_ahead).groupby('t_source').first()\n",
- " print(d)\n",
- "\n",
- " p = hv.Scatter({\n",
- " 'x': d.t_source,\n",
- " 'y': d.y_true\n",
- " }, label='true').opts(color='black')\n",
- "\n",
- " # Get arrays\n",
- " xf = d.t_source.values\n",
- " yp = d.y_pred\n",
- " s = d.y_pred_std\n",
- " p *= hv.Curve({\n",
- " 'x': xf,\n",
- " 'y': yp\n",
- " }, label='pred').opts(color='blue')\n",
- " p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),\n",
- " vdims=['y', 'y2'],\n",
- " label='2*std').opts(alpha=0.5, line_width=0)\n",
- "\n",
- "\n",
- " return p.opts(title=f'Prediction at {it_ahead * pd.Timedelta(freq)} ahead. NLL={d.nll.mean().item():2.2f}')\n",
- "\n",
- "\n",
- "dmap_preds = (hv.DynamicMap(plot_predictions_vs_time, kdims=['it_ahead'])\n",
- " .redim.values(it_ahead=range(ds_preds.t_ahead.shape[0]))\n",
- " .opts(width=800,\n",
- " height=300, \n",
- " ))\n",
- "dmap_preds\n",
- "# TODO fixme"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 44,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-19T10:57:06.875280Z",
- "start_time": "2020-10-19T10:57:06.661607Z"
- }
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
- " and should_run_async(code)\n"
- ]
- },
- {
- "data": {},
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.holoviews_exec.v0+json": "",
- "text/html": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "
\n",
- "
\n",
- ""
- ],
- "text/plain": [
- ":Curve [hours ahead] (nll)"
- ]
- },
- "execution_count": 44,
- "metadata": {
- "application/vnd.holoviews_exec.v0+json": {
- "id": "7143"
- }
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "d = ds_preds.mean('t_source')['nll'].groupby('t_ahead_hours').mean()\n",
+ "d = ds_preds.mean(['t_source', 'block'])['nll'].groupby('t_ahead_hours').mean()\n",
"nll_vs_tahead = hv.Curve((d.t_ahead_hours, d)).redim(x='hours ahead', y='nll').opts(width=800)\n",
"nll_vs_tahead"
]
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T10:53:35.125269Z",
- "start_time": "2020-10-19T10:53:35.053232Z"
+ "end_time": "2020-10-19T13:53:20.564734Z",
+ "start_time": "2020-10-19T13:24:29.300Z"
}
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
- " and should_run_async(code)\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "# d = ds_preds.mean('t_ahead')['nll'].groupby('t_source').mean()\n",
+ "# def plot_predictions_vs_time(it_ahead):\n",
+ "# \"\"\"Plot predictions vs time with holoviews\"\"\"\n",
+ "\n",
+ "# d = ds_pred_block.isel(t_ahead=it_ahead).groupby('t_source').first()\n",
+ "# # print(d)\n",
+ "\n",
+ "# p = hv.Scatter({\n",
+ "# 'x': d.t_source,\n",
+ "# 'y': d.y_true\n",
+ "# }, label='true').opts(color='black')\n",
+ "\n",
+ "# # Get arrays\n",
+ "# xf = d.t_source.values\n",
+ "# yp = d.y_pred\n",
+ "# s = d.y_pred_std\n",
+ "# p *= hv.Curve({\n",
+ "# 'x': xf,\n",
+ "# 'y': yp\n",
+ "# }, label='pred').opts(color='blue')\n",
+ "# p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),\n",
+ "# vdims=['y', 'y2'],\n",
+ "# label='2*std').opts(alpha=0.5, line_width=0)\n",
+ "\n",
+ "\n",
+ "# return p.opts(title=f'Prediction at {it_ahead * pd.Timedelta(freq)} ahead. NLL={d.nll.mean().item():2.2f}')\n",
+ "\n",
+ "\n",
+ "# dmap_preds = (hv.DynamicMap(plot_predictions_vs_time, kdims=['it_ahead'])\n",
+ "# .redim.values(it_ahead=range(ds_pred_block.t_ahead.shape[0]))\n",
+ "# .opts(width=800,\n",
+ "# height=300, \n",
+ "# ))\n",
+ "# dmap_preds\n",
+ "# # TODO fixme"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-10-19T13:15:31.423090Z",
+ "start_time": "2020-10-19T13:15:31.260955Z"
+ }
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-10-19T13:53:20.568851Z",
+ "start_time": "2020-10-19T13:24:29.300Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# d = ds_preds.mean(['t_ahead', 'block'])['nll'].groupby('t_source').mean()\n",
"# nll_vs_time = hv.Curve(d).opts(width=800)\n",
"# nll_vs_time"
]
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T10:53:37.325570Z",
- "start_time": "2020-10-19T10:53:37.258789Z"
+ "end_time": "2020-10-19T13:53:20.572004Z",
+ "start_time": "2020-10-19T13:24:29.300Z"
},
"scrolled": true
},
@@ -5878,23 +5540,14 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-19T10:47:51.248969Z",
- "start_time": "2020-10-19T10:47:51.198312Z"
+ "end_time": "2020-10-19T13:53:20.574629Z",
+ "start_time": "2020-10-19T13:24:29.300Z"
}
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
- " and should_run_async(code)\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"\n",
"# # Run learning rate finder\n",
diff --git a/notebooks/02.0-mike-RNN_Timeseries_Seq2Seq.py b/notebooks/02.0-mike-RNN_Timeseries_Seq2Seq.py
index 0612317..5194d09 100644
--- a/notebooks/02.0-mike-RNN_Timeseries_Seq2Seq.py
+++ b/notebooks/02.0-mike-RNN_Timeseries_Seq2Seq.py
@@ -369,9 +369,7 @@ def plot_prediction(ds_preds, i):
def plot_performance(ds_preds, full=False):
"""Multiple plots using xr_preds"""
- print(f'mean_NLL {ds_preds.nll.mean().item():2.2f}')
plot_prediction(ds_preds, 24)
-# plot_prediction(ds_preds, 480)
ds_preds.mean('t_source').plot.scatter('t_ahead_hours', 'nll') # Mean over all predictions
n = len(ds_preds.t_source)
@@ -481,7 +479,7 @@ from seq2seq_time.models.baseline import BaselineLast
from seq2seq_time.models.transformer import Transformer
from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq
from seq2seq_time.models.transformer_seq import TransformerSeq
-from seq2seq_time.models.anp import RANP
+from seq2seq_time.models.neural_process import RANP
# ## Plots
# +
models = [
@@ -529,7 +527,11 @@ trainer = pl.Trainer(gpus=1,
)
trainer.fit(model, dl_train, dl_test)
print(plot_hist(trainer))
-ds_preds = predict(model.to(device), ds_test.datasets[0], batch_size, device=device, scaler=output_scaler)
+ds_predss = predict_multi(model.to(device),
+ ds_test.datasets,
+ batch_size*8,
+ device=device,
+ scaler=output_scaler)
print(f'baseline nll: {ds_preds.nll.mean().item():2.2g}')
for pt_model in models:
@@ -560,46 +562,35 @@ for pt_model in models:
- ds_preds = predict(model.to(device),
- ds_test.datasets[0],
- batch_size,
+ ds_predss = predict_multi(model.to(device),
+ ds_test.datasets,
+ batch_size*8,
device=device,
scaler=output_scaler)
print(name)
- print(f'mean_NLL {ds_preds.nll.mean().item():2.2f}')
+ print(f'mean_NLL {ds_predss.nll.mean().item():2.2f}')
# Performance
+ ds_preds = ds_predss.isel(block=0)
print(plot_hist(trainer))
plot_performance(ds_preds)
-# %debug
-
-ds_preds = predict(model.to(device),q
-
- ds_test.datasets[0],
- batch_size,
- device=device,
- scaler=output_scaler)
-
# +
-# ds_predss = predict_multi(model.to(device),
-# ds_test.datasets,
-# batch_size,
+# ds_preds = predict(model.to(device),
+# ds_test.datasets[0],
+# batch_size*8,
# device=device,
# scaler=output_scaler)
# -
-ds_test.datasets[0].df.index.value_counts()
+ds_predss = predict_multi(model.to(device),
+ ds_test.datasets,
+ batch_size*8,
+ device=device,
+ scaler=output_scaler)
-# TODO why dup?
-ds_preds.sel(t_source='2013-11-11 00:30:00')
-
-# TODO why duplicates?
-d = ds_preds.isel(t_ahead=0)
-d.t_source.to_series().sort_index()#.value_counts()
-# np.unique
-# d
+ds_pred_block = ds_predss.isel(block=1)
# # holoviews pred
@@ -613,9 +604,9 @@ def plot_prediction_now(t_source):
# Let us pass in an int
if isinstance(t_source, int):
- t_source = ds_preds.t_source[t_source].to_pandas()
+ t_source = ds_pred_block.t_source[t_source].to_pandas()
- d = ds_preds.sel(t_source=t_source)
+ d = ds_pred_block.sel(t_source=t_source)
# Sometimes there are duplicate times, take the first
if len(d.t_source.shape) and d.t_source.shape[0] > 0:
@@ -651,56 +642,58 @@ def plot_prediction_now(t_source):
dmap_pred = (hv.DynamicMap(plot_prediction_now, kdims=['t_source'])
- .redim.values(t_source=ds_preds.t_source.to_pandas())
+ .redim.values(t_source=ds_pred_block.t_source.to_pandas())
.opts(width=800,
height=300,
))
dmap_pred
-
-
-# +
-def plot_predictions_vs_time(it_ahead):
- """Plot predictions vs time with holoviews"""
-
- d = ds_preds.isel(t_ahead=it_ahead).groupby('t_source').first()
- print(d)
-
- p = hv.Scatter({
- 'x': d.t_source,
- 'y': d.y_true
- }, label='true').opts(color='black')
-
- # Get arrays
- xf = d.t_source.values
- yp = d.y_pred
- s = d.y_pred_std
- p *= hv.Curve({
- 'x': xf,
- 'y': yp
- }, label='pred').opts(color='blue')
- p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),
- vdims=['y', 'y2'],
- label='2*std').opts(alpha=0.5, line_width=0)
-
-
- return p.opts(title=f'Prediction at {it_ahead * pd.Timedelta(freq)} ahead. NLL={d.nll.mean().item():2.2f}')
-
-
-dmap_preds = (hv.DynamicMap(plot_predictions_vs_time, kdims=['it_ahead'])
- .redim.values(it_ahead=range(ds_preds.t_ahead.shape[0]))
- .opts(width=800,
- height=300,
- ))
-dmap_preds
-# TODO fixme
# -
-d = ds_preds.mean('t_source')['nll'].groupby('t_ahead_hours').mean()
+d = ds_preds.mean(['t_source', 'block'])['nll'].groupby('t_ahead_hours').mean()
nll_vs_tahead = hv.Curve((d.t_ahead_hours, d)).redim(x='hours ahead', y='nll').opts(width=800)
nll_vs_tahead
# +
-# d = ds_preds.mean('t_ahead')['nll'].groupby('t_source').mean()
+# def plot_predictions_vs_time(it_ahead):
+# """Plot predictions vs time with holoviews"""
+
+# d = ds_pred_block.isel(t_ahead=it_ahead).groupby('t_source').first()
+# # print(d)
+
+# p = hv.Scatter({
+# 'x': d.t_source,
+# 'y': d.y_true
+# }, label='true').opts(color='black')
+
+# # Get arrays
+# xf = d.t_source.values
+# yp = d.y_pred
+# s = d.y_pred_std
+# p *= hv.Curve({
+# 'x': xf,
+# 'y': yp
+# }, label='pred').opts(color='blue')
+# p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),
+# vdims=['y', 'y2'],
+# label='2*std').opts(alpha=0.5, line_width=0)
+
+
+# return p.opts(title=f'Prediction at {it_ahead * pd.Timedelta(freq)} ahead. NLL={d.nll.mean().item():2.2f}')
+
+
+# dmap_preds = (hv.DynamicMap(plot_predictions_vs_time, kdims=['it_ahead'])
+# .redim.values(it_ahead=range(ds_pred_block.t_ahead.shape[0]))
+# .opts(width=800,
+# height=300,
+# ))
+# dmap_preds
+# # TODO fixme
+# -
+
+
+
+# +
+# d = ds_preds.mean(['t_ahead', 'block'])['nll'].groupby('t_source').mean()
# nll_vs_time = hv.Curve(d).opts(width=800)
# nll_vs_time
diff --git a/seq2seq_time/models/lstm_seq2seq.py b/seq2seq_time/models/lstm_seq2seq.py
index 893ff08..8a837e0 100644
--- a/seq2seq_time/models/lstm_seq2seq.py
+++ b/seq2seq_time/models/lstm_seq2seq.py
@@ -36,4 +36,4 @@ class LSTMSeq2Seq(nn.Module):
log_sigma = self.std(outputs)
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
y_dist = torch.distributions.Normal(mean, sigma)
- return y_dist
+ return y_dist, {}
diff --git a/seq2seq_time/models/anp.py b/seq2seq_time/models/neural_process.py
similarity index 97%
rename from seq2seq_time/models/anp.py
rename to seq2seq_time/models/neural_process.py
index 526d7e4..d81efdb 100644
--- a/seq2seq_time/models/anp.py
+++ b/seq2seq_time/models/neural_process.py
@@ -7,6 +7,7 @@ import math
class LSTMBlock(nn.Module):
+ """Wrapper to return only lstm output."""
def __init__(
self,
in_channels,
@@ -437,14 +438,21 @@ class RANP(nn.Module):
if self._use_rnn:
# see https://arxiv.org/abs/1910.09323 where x is substituted with h = RNN(x)
# x need to be provided as [B, T, H]
- future_x, _ = self._lstm(future_x)
- past_x, _ = self._lstm(past_x)
+ S = past_x.shape[1]
+ x = torch.cat([past_x, future_x], 1)
+ x, _ = self._lstm(x)
+ past_x = x[:, :S]
+ future_x = x[:, S:]
+ # future_x, _ = self._lstm(future_x)
+ # past_x, _ = self._lstm(past_x)
dist_prior, log_var_prior = self._latent_encoder(past_x, past_y)
- if future_y is not None:
+ if (future_y is not None):
dist_post, log_var_post = self._latent_encoder(future_x, future_y)
- z = dist_post.loc
+
+ if self.training:
+ z = dist_prior.rsample()
else:
z = dist_prior.loc
@@ -471,3 +479,4 @@ class RANP(nn.Module):
].mean()
loss = (kl_loss - log_p).mean()
return dist, {'loss':loss}
+
diff --git a/seq2seq_time/predict.py b/seq2seq_time/predict.py
index dbe979e..41b796c 100644
--- a/seq2seq_time/predict.py
+++ b/seq2seq_time/predict.py
@@ -35,9 +35,10 @@ def predict(model, ds_test, batch_size, device='cpu', scaler=None):
# Make an xarray.Dataset for the data
bs = y_future.shape[0]
- t_source = ds_test.df.index[i:i+bs].values
- t_ahead = pd.timedelta_range(0, periods=ds_test.window_future, freq=freq).values
- t_behind = pd.timedelta_range(end=-pd.Timedelta(freq), periods=ds_test.window_past, freq=freq)
+ wp = ds_test.window_past
+ t_source = ds_test.df.index[wp + i*bs -1:wp+ i*bs+bs -1].values
+ t_ahead = pd.timedelta_range(1, periods=ds_test.window_future, freq=freq).values
+ t_behind = pd.timedelta_range(end=0, periods=ds_test.window_past, freq=freq)
xr_out = xr.Dataset(
{
# Format> name: ([dimensions,...], array),
@@ -77,5 +78,5 @@ def predict_multi(model, datasets, batch_size, device='cpu', scaler=None):
d,
batch_size,
device=device,
- scaler=output_scaler) for d in tqdm(datasets)]
+ scaler=scaler) for d in tqdm(datasets, desc='predict_multi')]
return xr.concat(ds_preds, dim='block')