diff --git a/notebooks/02.0-mike-RNN_Timeseries_Seq2Seq.ipynb b/notebooks/02.0-mike-RNN_Timeseries_Seq2Seq.ipynb
deleted file mode 100644
index b03d877..0000000
--- a/notebooks/02.0-mike-RNN_Timeseries_Seq2Seq.ipynb
+++ /dev/null
@@ -1,5618 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-10T01:25:12.788851Z",
- "start_time": "2020-10-10T01:25:12.783398Z"
- }
- },
- "source": [
- "# Sequence to Sequence Models for Timeseries Regression\n",
- "\n",
- "\n",
- "In this notebook we are going to tackle a harder problem: \n",
- "- predicting the future on a timeseries\n",
- "- using an LSTM\n",
- "- with rough uncertainty (uncalibrated)\n",
- "- outputing sequence of predictions\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "https://medium.com/@boitemailjeanmid/smart-meters-in-london-part1-description-and-first-insights-jean-michel-d-db97af2de71b\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-19T13:24:29.474642Z",
- "start_time": "2020-10-19T13:24:29.025860Z"
- }
- },
- "outputs": [],
- "source": [
- "# OPTIONAL: Load the \"autoreload\" extension so that code can change. But blacklist large modules\n",
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "%aimport -pandas\n",
- "%aimport -torch\n",
- "%aimport -numpy\n",
- "%aimport -matplotlib\n",
- "%aimport -dask\n",
- "%aimport -tqdm\n",
- "%matplotlib inline"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-19T13:24:30.583721Z",
- "start_time": "2020-10-19T13:24:29.478450Z"
- }
- },
- "outputs": [],
- "source": [
- "# Imports\n",
- "import torch\n",
- "from torch import nn, optim\n",
- "from torch.nn import functional as F\n",
- "from torch.autograd import Variable\n",
- "import torch\n",
- "import torch.utils.data\n",
- "\n",
- "import pandas as pd\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "plt.rcParams['figure.figsize'] = (12.0, 3.0)\n",
- "plt.style.use('ggplot')\n",
- "\n",
- "from pathlib import Path\n",
- "from tqdm.auto import tqdm\n",
- "\n",
- "import pytorch_lightning as pl"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-19T13:24:30.616617Z",
- "start_time": "2020-10-19T13:24:30.588643Z"
- }
- },
- "outputs": [],
- "source": [
- "import warnings\n",
- "warnings.simplefilter('once')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-19T13:24:31.183770Z",
- "start_time": "2020-10-19T13:24:30.622322Z"
- }
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
- " and should_run_async(code)\n"
- ]
- }
- ],
- "source": [
- "from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets\n",
- "from seq2seq_time.predict import predict, predict_multi"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-19T13:24:31.218369Z",
- "start_time": "2020-10-19T13:24:31.187185Z"
- }
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
- " and should_run_async(code)\n"
- ]
- }
- ],
- "source": [
- "import logging, sys\n",
- "# logging.basicConfig(stream=sys.stdout, level=logging.INFO)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-10T01:28:32.492160Z",
- "start_time": "2020-10-10T01:28:32.488140Z"
- }
- },
- "source": [
- "## Parameters"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-19T13:24:31.286536Z",
- "start_time": "2020-10-19T13:24:31.222958Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "using cuda\n"
- ]
- }
- ],
- "source": [
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "print(f'using {device}')\n",
- "\n",
- "columns_target=['energy(kWh/hh)']\n",
- "window_past = 48*2\n",
- "window_future = 48*2\n",
- "batch_size = 256\n",
- "num_workers = 5\n",
- "freq = '30T'\n",
- "max_rows = 5e5"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Load data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-19T13:24:31.338216Z",
- "start_time": "2020-10-19T13:24:31.290344Z"
- },
- "lines_to_next_cell": 0
- },
- "outputs": [],
- "source": [
- "\n",
- "def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london'), max_files=8):\n",
- " \"\"\"\n",
- " Data loading and cleanding is always messy, so understand this code is optional.\n",
- " \"\"\"\n",
- " \n",
- " # Load csv files\n",
- " csv_files = sorted((indir/'halfhourly_dataset').glob('*.csv'))[:max_files]\n",
- " \n",
- " dfs = []\n",
- " for f in csv_files:\n",
- " df = (pd.read_csv(f, parse_dates=[1], na_values=['Null'])\n",
- " .groupby('tstp')\n",
- " .sum()\n",
- " .sort_index()\n",
- " )\n",
- " df['block'] = f.stem\n",
- "\n",
- " # Drop nan and 0's\n",
- " df = df[df['energy(kWh/hh)']!=0]\n",
- " df = df.dropna()\n",
- " \n",
- " # Add time features \n",
- " time = df.index.to_series()\n",
- " df[\"month\"] = time.dt.month\n",
- " df['day'] = time.dt.day\n",
- " df['week'] = time.dt.week\n",
- " df['hour'] = time.dt.hour\n",
- " df['minute'] = time.dt.minute\n",
- " df['dayofweek'] = time.dt.dayofweek\n",
- "\n",
- " # Load weather data\n",
- " df_weather = pd.read_csv(indir/'weather_hourly_darksky.csv', parse_dates=[3])\n",
- " use_cols = ['visibility', 'windBearing', 'temperature', 'time', 'dewPoint',\n",
- " 'pressure', 'apparentTemperature', 'windSpeed', \n",
- " 'humidity']\n",
- " df_weather = df_weather[use_cols].set_index('time')\n",
- " \n",
- " # Resample to match energy data \n",
- " # Use first, since we have bearing, and you can't take mean\n",
- " df_weather = df_weather.resample(freq).first().ffill() \n",
- "\n",
- " # Join weather and energy data\n",
- " df = pd.merge(df, df_weather, how='inner', left_index=True, right_index=True, sort=True)\n",
- "\n",
- " # Holidays\n",
- " df_hols = pd.read_csv(indir/'uk_bank_holidays.csv', parse_dates=[0])\n",
- " holidays = set(df_hols['Bank holidays'].dt.round('D')) \n",
- " def is_holiday(dt):\n",
- " return dt in holidays\n",
- " days = df.index.floor('D')\n",
- " holiday_mapping = days.unique().to_series().apply(is_holiday).astype(int).to_dict()\n",
- " df['holiday'] = days.to_series().map(holiday_mapping).values\n",
- "\n",
- " # sort\n",
- " df.index.name = 'Date'\n",
- " df = df.loc['2012-09':] # Weird value before this\n",
- " \n",
- " dfs.append(df)\n",
- " \n",
- " return pd.concat(dfs)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-19T06:24:38.318999Z",
- "start_time": "2020-10-19T06:24:35.452722Z"
- }
- },
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Our dataset is the london smartmeter data. But at half hour intervals"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-19T13:24:49.503564Z",
- "start_time": "2020-10-19T13:24:31.344006Z"
- },
- "lines_to_next_cell": 0,
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "block_107 26161\n",
- "block_100 26161\n",
- "block_10 26161\n",
- "block_1 26161\n",
- "block_105 26161\n",
- "block_0 26161\n",
- "block_102 26161\n",
- "block_103 26161\n",
- "block_108 26161\n",
- "block_106 26161\n",
- "block_101 26161\n",
- "block_104 26161\n",
- "Name: block, dtype: int64\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "
| \n", - " | energy(kWh/hh) | \n", - "block | \n", - "month | \n", - "day | \n", - "week | \n", - "hour | \n", - "minute | \n", - "dayofweek | \n", - "visibility | \n", - "windBearing | \n", - "temperature | \n", - "dewPoint | \n", - "pressure | \n", - "apparentTemperature | \n", - "windSpeed | \n", - "humidity | \n", - "holiday | \n", - "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Date | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
| 2012-09-01 00:00:00 | \n", - "5.013 | \n", - "block_0 | \n", - "9 | \n", - "1 | \n", - "35 | \n", - "0 | \n", - "0 | \n", - "5 | \n", - "13.36 | \n", - "302.0 | \n", - "14.08 | \n", - "9.74 | \n", - "1028.27 | \n", - "14.08 | \n", - "1.89 | \n", - "0.75 | \n", - "0 | \n", - "
| 2012-09-01 00:30:00 | \n", - "5.157 | \n", - "block_0 | \n", - "9 | \n", - "1 | \n", - "35 | \n", - "0 | \n", - "30 | \n", - "5 | \n", - "13.36 | \n", - "302.0 | \n", - "14.08 | \n", - "9.74 | \n", - "1028.27 | \n", - "14.08 | \n", - "1.89 | \n", - "0.75 | \n", - "0 | \n", - "
| 2012-09-01 01:00:00 | \n", - "6.360 | \n", - "block_0 | \n", - "9 | \n", - "1 | \n", - "35 | \n", - "1 | \n", - "0 | \n", - "5 | \n", - "13.50 | \n", - "298.0 | \n", - "13.93 | \n", - "9.81 | \n", - "1027.96 | \n", - "13.93 | \n", - "1.59 | \n", - "0.76 | \n", - "0 | \n", - "
| 2012-09-01 01:30:00 | \n", - "5.511 | \n", - "block_0 | \n", - "9 | \n", - "1 | \n", - "35 | \n", - "1 | \n", - "30 | \n", - "5 | \n", - "13.50 | \n", - "298.0 | \n", - "13.93 | \n", - "9.81 | \n", - "1027.96 | \n", - "13.93 | \n", - "1.59 | \n", - "0.76 | \n", - "0 | \n", - "
| 2012-09-01 02:00:00 | \n", - "4.922 | \n", - "block_0 | \n", - "9 | \n", - "1 | \n", - "35 | \n", - "2 | \n", - "0 | \n", - "5 | \n", - "13.21 | \n", - "274.0 | \n", - "13.52 | \n", - "9.94 | \n", - "1028.04 | \n", - "13.52 | \n", - "0.82 | \n", - "0.79 | \n", - "0 | \n", - "
| ... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "
| 2014-02-27 22:00:00 | \n", - "9.819 | \n", - "block_108 | \n", - "2 | \n", - "27 | \n", - "9 | \n", - "22 | \n", - "0 | \n", - "3 | \n", - "14.00 | \n", - "216.0 | \n", - "4.10 | \n", - "1.64 | \n", - "1005.67 | \n", - "1.41 | \n", - "3.02 | \n", - "0.84 | \n", - "0 | \n", - "
| 2014-02-27 22:30:00 | \n", - "8.792 | \n", - "block_108 | \n", - "2 | \n", - "27 | \n", - "9 | \n", - "22 | \n", - "30 | \n", - "3 | \n", - "14.00 | \n", - "216.0 | \n", - "4.10 | \n", - "1.64 | \n", - "1005.67 | \n", - "1.41 | \n", - "3.02 | \n", - "0.84 | \n", - "0 | \n", - "
| 2014-02-27 23:00:00 | \n", - "8.087 | \n", - "block_108 | \n", - "2 | \n", - "27 | \n", - "9 | \n", - "23 | \n", - "0 | \n", - "3 | \n", - "14.03 | \n", - "200.0 | \n", - "3.93 | \n", - "1.61 | \n", - "1004.62 | \n", - "1.42 | \n", - "2.75 | \n", - "0.85 | \n", - "0 | \n", - "
| 2014-02-27 23:30:00 | \n", - "7.114 | \n", - "block_108 | \n", - "2 | \n", - "27 | \n", - "9 | \n", - "23 | \n", - "30 | \n", - "3 | \n", - "14.03 | \n", - "200.0 | \n", - "3.93 | \n", - "1.61 | \n", - "1004.62 | \n", - "1.42 | \n", - "2.75 | \n", - "0.85 | \n", - "0 | \n", - "
| 2014-02-28 00:00:00 | \n", - "7.287 | \n", - "block_108 | \n", - "2 | \n", - "28 | \n", - "9 | \n", - "0 | \n", - "0 | \n", - "4 | \n", - "12.63 | \n", - "190.0 | \n", - "3.81 | \n", - "1.53 | \n", - "1003.57 | \n", - "1.47 | \n", - "2.53 | \n", - "0.85 | \n", - "0 | \n", - "
313932 rows × 17 columns
\n", - "0&&ye(a,!u&&ve(e,\"script\")),s},cleanData:function(e){for(var t,n,r,i=b.event.special,o=0;void 0!==(n=e[o]);o++)if(X(n)){if(t=n[G.expando]){if(t.events)for(r in t.events)i[r]?b.event.remove(n,r):b.removeEvent(n,r,t.handle);n[G.expando]=void 0}n[Y.expando]&&(n[Y.expando]=void 0)}}}),b.fn.extend({detach:function(e){return Me(this,e,!0)},remove:function(e){return Me(this,e)},text:function(e){return B(this,(function(e){return void 0===e?b.text(this):this.empty().each((function(){1!==this.nodeType&&11!==this.nodeType&&9!==this.nodeType||(this.textContent=e)}))}),null,e,arguments.length)},append:function(){return Re(this,arguments,(function(e){1!==this.nodeType&&11!==this.nodeType&&9!==this.nodeType||qe(this,e).appendChild(e)}))},prepend:function(){return Re(this,arguments,(function(e){if(1===this.nodeType||11===this.nodeType||9===this.nodeType){var t=qe(this,e);t.insertBefore(e,t.firstChild)}}))},before:function(){return Re(this,arguments,(function(e){this.parentNode&&this.parentNode.insertBefore(e,this)}))},after:function(){return Re(this,arguments,(function(e){this.parentNode&&this.parentNode.insertBefore(e,this.nextSibling)}))},empty:function(){for(var e,t=0;null!=(e=this[t]);t++)1===e.nodeType&&(b.cleanData(ve(e,!1)),e.textContent=\"\");return this},clone:function(e,t){return e=null!=e&&e,t=null==t?e:t,this.map((function(){return b.clone(this,e,t)}))},html:function(e){return B(this,(function(e){var t=this[0]||{},n=0,r=this.length;if(void 0===e&&1===t.nodeType)return t.innerHTML;if(\"string\"==typeof e&&!Ne.test(e)&&!ge[(de.exec(e)||[\"\",\"\"])[1].toLowerCase()]){e=b.htmlPrefilter(e);try{for(;n 313932 rows × 17 columns\n",
- " \n",
- "
\n",
- "\n",
- " \n",
- " \n",
- " month \n",
- " day \n",
- " week \n",
- " hour \n",
- " minute \n",
- " dayofweek \n",
- " visibility \n",
- " windBearing \n",
- " temperature \n",
- " dewPoint \n",
- " pressure \n",
- " apparentTemperature \n",
- " windSpeed \n",
- " humidity \n",
- " holiday \n",
- " block \n",
- " tsp_days \n",
- " is_past \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " Date \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " 2013-11-08 17:00:00 \n",
- " 1.085349 \n",
- " -0.87348 \n",
- " 1.030777 \n",
- " 0.794583 \n",
- " -0.999962 \n",
- " 0.500561 \n",
- " -0.205197 \n",
- " 0.700950 \n",
- " -0.350809 \n",
- " 0.160852 \n",
- " -0.928594 \n",
- " -0.347889 \n",
- " -0.539344 \n",
- " 1.037935 \n",
- " -0.150044 \n",
- " 0.0 \n",
- " -0.104167 \n",
- " 1.0 \n",
- " \n",
- " \n",
- " 2013-11-08 17:30:00 \n",
- " 1.085349 \n",
- " -0.87348 \n",
- " 1.030777 \n",
- " 0.794583 \n",
- " 1.000038 \n",
- " 0.500561 \n",
- " -0.205197 \n",
- " 0.700950 \n",
- " -0.350809 \n",
- " 0.160852 \n",
- " -0.928594 \n",
- " -0.347889 \n",
- " -0.539344 \n",
- " 1.037935 \n",
- " -0.150044 \n",
- " 0.0 \n",
- " -0.083333 \n",
- " 1.0 \n",
- " \n",
- " \n",
- " 2013-11-08 18:00:00 \n",
- " 1.085349 \n",
- " -0.87348 \n",
- " 1.030777 \n",
- " 0.939042 \n",
- " -0.999962 \n",
- " 0.500561 \n",
- " 0.176935 \n",
- " 0.919905 \n",
- " -0.349112 \n",
- " 0.139362 \n",
- " -0.896675 \n",
- " -0.442038 \n",
- " 0.034563 \n",
- " 1.037935 \n",
- " -0.150044 \n",
- " 0.0 \n",
- " -0.062500 \n",
- " 1.0 \n",
- " \n",
- " \n",
- " 2013-11-08 18:30:00 \n",
- " 1.085349 \n",
- " -0.87348 \n",
- " 1.030777 \n",
- " 0.939042 \n",
- " 1.000038 \n",
- " 0.500561 \n",
- " 0.176935 \n",
- " 0.919905 \n",
- " -0.349112 \n",
- " 0.139362 \n",
- " -0.896675 \n",
- " -0.442038 \n",
- " 0.034563 \n",
- " 1.037935 \n",
- " -0.150044 \n",
- " 0.0 \n",
- " -0.041667 \n",
- " 1.0 \n",
- " \n",
- " \n",
- " \n",
- "2013-11-08 19:00:00 \n",
- " 1.085349 \n",
- " -0.87348 \n",
- " 1.030777 \n",
- " 1.083500 \n",
- " -0.999962 \n",
- " 0.500561 \n",
- " 0.118644 \n",
- " 0.952749 \n",
- " -0.391543 \n",
- " -0.048187 \n",
- " -0.821312 \n",
- " -0.517919 \n",
- " 0.270877 \n",
- " 0.680856 \n",
- " -0.150044 \n",
- " 0.0 \n",
- " -0.020833 \n",
- " 1.0 \n",
- " \n",
- " \n",
- "
\n",
- "\n",
- " \n",
- " \n",
- " month \n",
- " day \n",
- " week \n",
- " hour \n",
- " minute \n",
- " dayofweek \n",
- " visibility \n",
- " windBearing \n",
- " temperature \n",
- " dewPoint \n",
- " pressure \n",
- " apparentTemperature \n",
- " windSpeed \n",
- " humidity \n",
- " holiday \n",
- " block \n",
- " tsp_days \n",
- " is_past \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " Date \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " 2013-11-10 17:00:00 \n",
- " 1.085349 \n",
- " -0.645393 \n",
- " 1.030777 \n",
- " 0.794583 \n",
- " -0.999962 \n",
- " 1.499889 \n",
- " -0.328257 \n",
- " 0.36157 \n",
- " 0.77956 \n",
- " 1.225586 \n",
- " -1.236252 \n",
- " 0.835296 \n",
- " 2.180685 \n",
- " 0.609441 \n",
- " -0.150044 \n",
- " 0.0 \n",
- " 1.895833 \n",
- " 0.0 \n",
- " \n",
- " \n",
- " 2013-11-10 17:30:00 \n",
- " 1.085349 \n",
- " -0.645393 \n",
- " 1.030777 \n",
- " 0.794583 \n",
- " 1.000038 \n",
- " 1.499889 \n",
- " -0.328257 \n",
- " 0.36157 \n",
- " 0.77956 \n",
- " 1.225586 \n",
- " -1.236252 \n",
- " 0.835296 \n",
- " 2.180685 \n",
- " 0.609441 \n",
- " -0.150044 \n",
- " 0.0 \n",
- " 1.916667 \n",
- " 0.0 \n",
- " \n",
- " \n",
- " 2013-11-10 18:00:00 \n",
- " 1.085349 \n",
- " -0.645393 \n",
- " 1.030777 \n",
- " 0.939042 \n",
- " -0.999962 \n",
- " 1.499889 \n",
- " -0.328257 \n",
- " 0.36157 \n",
- " 0.77956 \n",
- " 1.225586 \n",
- " -1.236252 \n",
- " 0.835296 \n",
- " 2.180685 \n",
- " 0.609441 \n",
- " -0.150044 \n",
- " 0.0 \n",
- " 1.937500 \n",
- " 0.0 \n",
- " \n",
- " \n",
- " 2013-11-10 18:30:00 \n",
- " 1.085349 \n",
- " -0.645393 \n",
- " 1.030777 \n",
- " 0.939042 \n",
- " 1.000038 \n",
- " 1.499889 \n",
- " -0.328257 \n",
- " 0.36157 \n",
- " 0.77956 \n",
- " 1.225586 \n",
- " -1.236252 \n",
- " 0.835296 \n",
- " 2.180685 \n",
- " 0.609441 \n",
- " -0.150044 \n",
- " 0.0 \n",
- " 1.958333 \n",
- " 0.0 \n",
- " \n",
- " \n",
- " \n",
- "2013-11-10 19:00:00 \n",
- " 1.085349 \n",
- " -0.645393 \n",
- " 1.030777 \n",
- " 1.083500 \n",
- " -0.999962 \n",
- " 1.499889 \n",
- " -0.328257 \n",
- " 0.36157 \n",
- " 0.77956 \n",
- " 1.225586 \n",
- " -1.236252 \n",
- " 0.835296 \n",
- " 2.180685 \n",
- " 0.609441 \n",
- " -0.150044 \n",
- " 0.0 \n",
- " 1.979167 \n",
- " 0.0 \n",
- "
-#
-#
-# https://medium.com/@boitemailjeanmid/smart-meters-in-london-part1-description-and-first-insights-jean-michel-d-db97af2de71b
-#
-
-# OPTIONAL: Load the "autoreload" extension so that code can change. But blacklist large modules
-# %load_ext autoreload
-# %autoreload 2
-# %aimport -pandas
-# %aimport -torch
-# %aimport -numpy
-# %aimport -matplotlib
-# %aimport -dask
-# %aimport -tqdm
-# %matplotlib inline
-
-# +
-# Imports
-import torch
-from torch import nn, optim
-from torch.nn import functional as F
-from torch.autograd import Variable
-import torch
-import torch.utils.data
-
-import pandas as pd
-import numpy as np
-import matplotlib.pyplot as plt
-plt.rcParams['figure.figsize'] = (12.0, 3.0)
-plt.style.use('ggplot')
-
-from pathlib import Path
-from tqdm.auto import tqdm
-
-import pytorch_lightning as pl
-# -
-
-import warnings
-warnings.simplefilter('once')
-
-from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets
-from seq2seq_time.predict import predict, predict_multi
-
-import logging, sys
-# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-
-# ## Parameters
-
-# +
-device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f'using {device}')
-
-columns_target=['energy(kWh/hh)']
-window_past = 48*2
-window_future = 48*2
-batch_size = 256
-num_workers = 5
-freq = '30T'
-max_rows = 5e5
-
-
-# -
-
-# ## Load data
-
-# +
-
-def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london'), max_files=8):
- """
- Data loading and cleanding is always messy, so understand this code is optional.
- """
-
- # Load csv files
- csv_files = sorted((indir/'halfhourly_dataset').glob('*.csv'))[:max_files]
-
- dfs = []
- for f in csv_files:
- df = (pd.read_csv(f, parse_dates=[1], na_values=['Null'])
- .groupby('tstp')
- .sum()
- .sort_index()
- )
- df['block'] = f.stem
-
- # Drop nan and 0's
- df = df[df['energy(kWh/hh)']!=0]
- df = df.dropna()
-
- # Add time features
- time = df.index.to_series()
- df["month"] = time.dt.month
- df['day'] = time.dt.day
- df['week'] = time.dt.week
- df['hour'] = time.dt.hour
- df['minute'] = time.dt.minute
- df['dayofweek'] = time.dt.dayofweek
-
- # Load weather data
- df_weather = pd.read_csv(indir/'weather_hourly_darksky.csv', parse_dates=[3])
- use_cols = ['visibility', 'windBearing', 'temperature', 'time', 'dewPoint',
- 'pressure', 'apparentTemperature', 'windSpeed',
- 'humidity']
- df_weather = df_weather[use_cols].set_index('time')
-
- # Resample to match energy data
- # Use first, since we have bearing, and you can't take mean
- df_weather = df_weather.resample(freq).first().ffill()
-
- # Join weather and energy data
- df = pd.merge(df, df_weather, how='inner', left_index=True, right_index=True, sort=True)
-
- # Holidays
- df_hols = pd.read_csv(indir/'uk_bank_holidays.csv', parse_dates=[0])
- holidays = set(df_hols['Bank holidays'].dt.round('D'))
- def is_holiday(dt):
- return dt in holidays
- days = df.index.floor('D')
- holiday_mapping = days.unique().to_series().apply(is_holiday).astype(int).to_dict()
- df['holiday'] = days.to_series().map(holiday_mapping).values
-
- # sort
- df.index.name = 'Date'
- df = df.loc['2012-09':] # Weird value before this
-
- dfs.append(df)
-
- return pd.concat(dfs)
-# -
-
-
-# Our dataset is the london smartmeter data. But at half hour intervals
-
-# +
-df = get_smartmeter_df(max_files=12)
-
-# # Just get the first one for now
-# dfs = list(dfs)
-
-# # df = df.resample(freq).first().dropna() # Where empty we will backfill, this will respect causality, and mostly maintain the mean
-
-df = df.tail(int(max_rows)).copy() # Just use last X rows
-# df = pd.concat(dfs[:6], 0)
-# # df = dfs[0]
-print(df.block.value_counts())
-df
-# -
-
-
-
-# ### Plot/explore
-
-
-
-
-
-# +
-import holoviews as hv
-from holoviews import opts
-
-from holoviews.plotting.links import RangeToolLink
-
-import datashader as ds
-
-from holoviews.operation.datashader import datashade, shade, dynspread, rasterize
-from holoviews.operation import decimate
-
-hv.extension('bokeh')
-
-
-# def house_curve(Name=None):
-# if isinstance(Name, int):
-# name = df.block.unique()[Name]
-# d = df[df.block == Name]
-# d_curve = hv.Curve(d, 'Date', 'energy(kWh/hh)', label=Name).opts(framewise=True)
-# return d_curve
-
-
-# dmap = hv.DynamicMap(house_curve, kdims=['Name'])
-# dmap = dmap.redim.values(Name=list(df.block.unique()))
-# dynspread(datashade(dmap).opts(width=800,
-# height=300,
-# tools=['xwheel_zoom', 'pan'],
-# active_tools=['xwheel_zoom', 'pan'],
-# default_tools=['reset', 'save', 'hover']
-# ))
-# -
-
-
-
-
-# ### Profiling
-
-# +
-# from pandas_profiling import ProfileReport
-# profile = ProfileReport(df, title="Pandas Profiling Report", minimal=True)
-# profile
-# -
-
-# ### Norm
-
-df.describe()
-
-# +
-import sklearn
-from sklearn.preprocessing import StandardScaler, OrdinalEncoder
-from sklearn_pandas import DataFrameMapper
-
-columns_input_numeric = list(df.drop(columns=columns_target)._get_numeric_data().columns)
-columns_categorical = list(set(df.columns)-set(columns_input_numeric)-set(columns_target))
-
-output_scalers = [([n], StandardScaler()) for n in columns_target]
-transformers=output_scalers + \
-[([n], StandardScaler()) for n in columns_input_numeric] + \
-[([n], OrdinalEncoder()) for n in columns_categorical]
-scaler = DataFrameMapper(transformers, df_out=True)
-df_norm = scaler.fit_transform(df)
-df_norm
-# -
-
-output_scaler = next(filter(lambda r:r[0][0] in columns_target, scaler.features))[-1]
-output_scaler
-
-# ### Split
-
-# +
-# split data, with the test in the future
-
-d0 =df_norm.index.min()
-d1 = df_norm.index.max()
-split_time = d0+(d1-d0)*0.8
-split_time = split_time.round('1D')
-print(split_time)
-df_train = df_norm.groupby('block').apply(lambda d:d.loc[:split_time]).reset_index(level=0, drop=True)
-df_test = df_norm.groupby('block').apply(lambda d:d.loc[split_time:]).reset_index(level=0, drop=True)
-# df_test
-
-# +
-# # Show split
-# df_train['energy(kWh/hh)'].plot(label='train')
-# df_test['energy(kWh/hh)'].plot(label='test')
-# plt.ylabel('energy(kWh/hh)')
-# plt.legend()
-# -
-
-# # Show split
-scatter = dynspread(datashade(hv.Curve(df_train, kdims=['Date'], vdims=['energy(kWh/hh)', 'block']).groupby('block'), cmap='blue'))
-scatter *= dynspread(datashade(hv.Curve(df_test, kdims=['Date'], vdims=['energy(kWh/hh)', 'block']).groupby('block'), cmap='red'))
-scatter = scatter.opts(plot=dict(width=800))
-scatter
-
-# ### Dataset
-
-# +
-
-# ### Dataset
-# These are the columns that we wont know in the future
-# We need to blank them out in x_future
-columns_blank=['visibility',
- 'windBearing', 'temperature', 'dewPoint', 'pressure',
- 'apparentTemperature', 'windSpeed', 'humidity']
-df_trains = [d.resample(freq).first().ffill().dropna() for _,d in df_train.groupby('block')]
-df_tests = [d.resample(freq).first().ffill().dropna() for _,d in df_test.groupby('block')]
-ds_train = Seq2SeqDataSets(df_trains,
- window_past=window_past,
- window_future=window_future,
- columns_blank=columns_blank)
-ds_test = Seq2SeqDataSets(df_tests,
- window_past=window_past,
- window_future=window_future,
- columns_blank=columns_blank)
-print(ds_train)
-print(ds_test)
-# -
-# we can treat it like an array
-ds_train[0]
-len(ds_train)
-ds_train[-1]
-
-# +
-# We can get rows
-x_past, y_past, x_future, y_future = ds_train.get_rows(10)
-
-# Plot one instance, this is what the model sees
-y_past['energy(kWh/hh)'].plot(label='past')
-y_future['energy(kWh/hh)'].plot(ax=plt.gca(), label='future')
-plt.legend()
-plt.ylabel('energy(kWh/hh)')
-
-# Notice we've added on two new columns tsp (time since present) and is_past
-x_past.tail()
-# -
-
-# Notice we've hidden some future columns to prevent cheating
-x_future.tail()
-
-
-# ## Plot helpers
-
-# +
-def plot_prediction(ds_preds, i):
- """Plot a prediction into the future, at a single point in time."""
- d = ds_preds.isel(t_source=i)
-
- # Get arrays
- xf = d.t_target
- yp = d.y_pred
- s = d.y_pred_std
- yt = d.y_true
- now = d.t_source.squeeze()
-
-
- plt.figure(figsize=(12, 4))
-
- plt.scatter(xf, yt, label='true', c='k', s=6)
- ylim = plt.ylim()
-
- # plot prediction
- plt.fill_between(xf, yp-2*s, yp+2*s, alpha=0.25,
- facecolor="b",
- interpolate=True,
- label="2 std",)
- plt.plot(xf, yp, label='pred', c='b')
-
- # plot true
- plt.scatter(
- d.t_past,
- d.y_past,
- c='k',
- s=6
- )
-
- # plot a red line for now
- plt.vlines(x=now, ymin=0, ymax=1, label='now', color='r')
- plt.ylim(*ylim)
-
- now=pd.Timestamp(now.values)
- plt.title(f'Prediction NLL={d.nll.mean().item():2.2g}')
- plt.xlabel(f'{now.date()}')
- plt.ylabel('energy(kWh/hh)')
- plt.legend()
- plt.xticks(rotation=45)
- plt.show()
-
-def plot_performance(ds_preds, full=False):
- """Multiple plots using xr_preds"""
- plot_prediction(ds_preds, 24)
-
- ds_preds.mean('t_source').plot.scatter('t_ahead_hours', 'nll') # Mean over all predictions
- n = len(ds_preds.t_source)
- plt.ylabel('Negative Log Likelihood (lower is better)')
- plt.xlabel('Hours ahead')
- plt.title(f'NLL vs time ahead (no. samples={n})')
- plt.show()
-
- # Make a plot of the NLL over time. Does this solution get worse with time?
- if full:
- d = ds_preds.mean('t_ahead').groupby('t_source').mean().plot.scatter('t_source', 'nll')
- plt.xticks(rotation=45)
- plt.title('NLL over source time (lower is better)')
- plt.show()
-
- # A scatter plot is easy with xarray
- if full:
- plt.figure(figsize=(5, 5))
- ds_preds.plot.scatter('y_true', 'y_pred', s=.01)
- plt.show()
-
-
-# -
-
-
-
-def plot_hist(trainer):
- try:
- df_hist = pd.read_csv(trainer.logger.experiment.metrics_file_path)
- df_hist['epoch'] = df_hist['epoch'].ffill()
- df_histe = df_hist.set_index('epoch').groupby('epoch').mean()
- if len(df_histe)>1:
- df_histe[['loss/train', 'loss/val']].plot(title='history')
- return df_histe
- except Exception:
- pass
-
-
-# ## Lightning
-
-# +
-import pytorch_lightning as pl
-
-class PL_MODEL(pl.LightningModule):
- def __init__(self, model, lr=3e-4, patience=2):
- super().__init__()
- self._model = model
- self.lr = lr
- self.patience = patience
-
- def forward(self, x_past, y_past, x_future, y_future=None):
- """Eval/Predict"""
- y_dist, extra = self._model(x_past, y_past, x_future, y_future)
- return y_dist, extra
-
- def training_step(self, batch, batch_idx, phase='train'):
- x_past, y_past, x_future, y_future = batch
- y_dist, extra = self.forward(*batch)
- loss = -y_dist.log_prob(y_future).mean()
- self.log_dict({f'loss/{phase}':loss})
- if ('loss' in extra) and (phase=='train'):
- # some models have a special loss
- loss = extra['loss']
- self.log_dict({f'model_loss/{phase}':loss})
- return loss
-
- def validation_step(self, batch, batch_idx):
- return self.training_step(batch, batch_idx, phase='val')
-
- def configure_optimizers(self):
- optim = torch.optim.Adam(self.parameters(), lr=self.lr)
- scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
- optim,
- patience=self.patience,
- verbose=True,
- min_lr=1e-7,
- )
- return {'optimizer': optim, 'lr_scheduler': scheduler, 'monitor': 'loss/val'}
-
-
-# -
-
-# # Run
-from torch.utils.data import DataLoader
-from pytorch_lightning.loggers import CSVLogger
-from pytorch_lightning.callbacks.early_stopping import EarlyStopping
-
-
-# +
-# Init data
-x_past, y_past, x_future, y_future = ds_train.get_rows(10)
-input_size = x_past.shape[-1]
-output_size = y_future.shape[-1]
-
-dl_train = DataLoader(ds_train,
- batch_size=batch_size,
- shuffle=True,
- pin_memory=num_workers==0,
- num_workers=num_workers)
-dl_test = DataLoader(ds_test, batch_size=batch_size, num_workers=num_workers)
-# -
-
-from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq
-from seq2seq_time.models.lstm_seq import LSTMSeq
-from seq2seq_time.models.lstm import LSTM
-from seq2seq_time.models.baseline import BaselineLast
-from seq2seq_time.models.transformer import Transformer
-from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq
-from seq2seq_time.models.transformer_seq import TransformerSeq
-from seq2seq_time.models.neural_process import RANP
-# ## Plots
-# +
-models = [
- RANP(input_size,
- output_size),
- LSTM(input_size,
- output_size,
- hidden_size=80,
- lstm_layers=3,
- lstm_dropout=0.3),
-
- LSTMSeq2Seq(input_size,
- output_size,
- hidden_size=64,
- lstm_layers=2,
- lstm_dropout=0.25),
- TransformerSeq2Seq(input_size,
- output_size,
- hidden_size=64,
- nhead=8,
- nlayers=4,
- attention_dropout=0.3),
- Transformer(input_size,
- output_size,
- attention_dropout=0.3,
- nhead=8,
- nlayers=6,
- hidden_size=64),
- TransformerSeq(input_size,
- output_size),
- LSTMSeq(input_size,
- output_size),
-
-]
-# -
-
-# Baseline model
-pt_model = BaselineLast()
-model = PL_MODEL(pt_model).to(device)
-trainer = pl.Trainer(gpus=1,
- max_epochs=1,
- limit_train_batches=0.01,
- logger=CSVLogger("logs",
- name=type(pt_model).__name__),
- )
-trainer.fit(model, dl_train, dl_test)
-print(plot_hist(trainer))
-ds_predss = predict_multi(model.to(device),
- ds_test.datasets,
- batch_size*8,
- device=device,
- scaler=output_scaler)
-print(f'baseline nll: {ds_preds.nll.mean().item():2.2g}')
-
-for pt_model in models:
- name = type(pt_model).__name__
- print(name)
-
- # Wrap in lightning
- patience = 2
- model = PL_MODEL(pt_model, patience=patience, lr=3e-4).to(device)
-
- # Trainer
- trainer = pl.Trainer(gpus=1,
- min_epochs=2,
- max_epochs=10,
- amp_level='O1',
- precision=16,
- gradient_clip_val=1,
- logger=CSVLogger("logs",
- name=type(pt_model).__name__),
- callbacks=[
- EarlyStopping(monitor='loss/val', patience=patience*2),
-# PrintTableMetricsCallback2()
- ],
- )
-
- # Train
- trainer.fit(model, dl_train, dl_test)
-
-
-
- ds_predss = predict_multi(model.to(device),
- ds_test.datasets,
- batch_size*8,
- device=device,
- scaler=output_scaler)
-
- print(name)
- print(f'mean_NLL {ds_predss.nll.mean().item():2.2f}')
-
- # Performance
- ds_preds = ds_predss.isel(block=0)
- print(plot_hist(trainer))
- plot_performance(ds_preds)
-
-# +
-# ds_preds = predict(model.to(device),
-# ds_test.datasets[0],
-# batch_size*8,
-# device=device,
-# scaler=output_scaler)
-# -
-
-ds_predss = predict_multi(model.to(device),
- ds_test.datasets,
- batch_size*8,
- device=device,
- scaler=output_scaler)
-
-ds_pred_block = ds_predss.isel(block=1)
-
-# # holoviews pred
-
-import holoviews as hv
-from holoviews import opts
-
-
-# +
-def plot_prediction_now(t_source):
- """Plot predictions with holoviews"""
-
- # Let us pass in an int
- if isinstance(t_source, int):
- t_source = ds_pred_block.t_source[t_source].to_pandas()
-
- d = ds_pred_block.sel(t_source=t_source)
-
- # Sometimes there are duplicate times, take the first
- if len(d.t_source.shape) and d.t_source.shape[0] > 0:
- d = d.isel(t_source=0)
- if len(d.t_source.shape) and d.t_source.shape[0] == 0:
- return None
-
- now = d.t_source.to_pandas()
-
- # Plot true
- x = np.concatenate([d.t_past, d.t_target])
- yt = np.concatenate([d.y_past, d.y_true])
- p = hv.Scatter({
- 'x': x,
- 'y': yt
- }, label='true').opts(color='black')
-
- # Get arrays
- xf = d.t_target.values
- yp = d.y_pred
- s = d.y_pred_std
- p *= hv.Curve({
- 'x': xf,
- 'y': yp
- }, label='pred').opts(color='blue')
- p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),
- vdims=['y', 'y2'],
- label='2*std').opts(alpha=0.5, line_width=0)
-
- # plot now line
- p *= hv.VLine(now, label='now').opts(color='red', framewise=True)
- return p.opts(title=f'Prediction at {now}. NLL={d.nll.mean().item():2.2f}')
-
-
-dmap_pred = (hv.DynamicMap(plot_prediction_now, kdims=['t_source'])
- .redim.values(t_source=ds_pred_block.t_source.to_pandas())
- .opts(width=800,
- height=300,
- ))
-dmap_pred
-# -
-
-d = ds_preds.mean(['t_source', 'block'])['nll'].groupby('t_ahead_hours').mean()
-nll_vs_tahead = hv.Curve((d.t_ahead_hours, d)).redim(x='hours ahead', y='nll').opts(width=800)
-nll_vs_tahead
-
-# +
-# def plot_predictions_vs_time(it_ahead):
-# """Plot predictions vs time with holoviews"""
-
-# d = ds_pred_block.isel(t_ahead=it_ahead).groupby('t_source').first()
-# # print(d)
-
-# p = hv.Scatter({
-# 'x': d.t_source,
-# 'y': d.y_true
-# }, label='true').opts(color='black')
-
-# # Get arrays
-# xf = d.t_source.values
-# yp = d.y_pred
-# s = d.y_pred_std
-# p *= hv.Curve({
-# 'x': xf,
-# 'y': yp
-# }, label='pred').opts(color='blue')
-# p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),
-# vdims=['y', 'y2'],
-# label='2*std').opts(alpha=0.5, line_width=0)
-
-
-# return p.opts(title=f'Prediction at {it_ahead * pd.Timedelta(freq)} ahead. NLL={d.nll.mean().item():2.2f}')
-
-
-# dmap_preds = (hv.DynamicMap(plot_predictions_vs_time, kdims=['it_ahead'])
-# .redim.values(it_ahead=range(ds_pred_block.t_ahead.shape[0]))
-# .opts(width=800,
-# height=300,
-# ))
-# dmap_preds
-# # TODO fixme
-# -
-
-
-
-# +
-# d = ds_preds.mean(['t_ahead', 'block'])['nll'].groupby('t_source').mean()
-# nll_vs_time = hv.Curve(d).opts(width=800)
-# nll_vs_time
-
-# +
-# true_vs_pred = hv.Scatter((ds_preds.y_true, ds_preds.y_pred))
-# dynspread(datashade(true_vs_pred))
-# -
-
-# # Summarize experiments
-
-# # LR finder
-
-# +
-
-# # Run learning rate finder
-# lr_finder = trainer.tuner.lr_find(model)
-
-# # Results can be found in
-# lr_finder.results
-
-# # Plot with
-# fig = lr_finder.plot(suggest=True)
-# fig.show()
-
-# # Pick point based on plot, or get suggestion
-# new_lr = lr_finder.suggestion()
-# -
-
-
diff --git a/notebooks/03.0-mike-RNN_Timeseries_Seq2Seq.ipynb b/notebooks/03.0-mike-RNN_Timeseries_Seq2Seq.ipynb
deleted file mode 100644
index 4d23b1a..0000000
--- a/notebooks/03.0-mike-RNN_Timeseries_Seq2Seq.ipynb
+++ /dev/null
@@ -1,8715 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-10T01:25:12.788851Z",
- "start_time": "2020-10-10T01:25:12.783398Z"
- }
- },
- "source": [
- "# Sequence to Sequence Models for Timeseries Regression\n",
- "\n",
- "\n",
- "In this notebook we are going to tackle a harder problem: \n",
- "- predicting the future on a timeseries\n",
- "- using an LSTM\n",
- "- with rough uncertainty (uncalibrated)\n",
- "- outputting a sequence of predictions\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "https://medium.com/@boitemailjeanmid/smart-meters-in-london-part1-description-and-first-insights-jean-michel-d-db97af2de71b\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-24T01:46:25.617948Z",
- "start_time": "2020-10-24T01:46:25.197663Z"
- }
- },
- "outputs": [],
- "source": [
- "# OPTIONAL: Load the \"autoreload\" extension so that code can change. But blacklist large modules\n",
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "%aimport -pandas\n",
- "%aimport -torch\n",
- "%aimport -numpy\n",
- "%aimport -matplotlib\n",
- "%aimport -dask\n",
- "%aimport -tqdm\n",
- "%matplotlib inline"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-24T01:46:26.786046Z",
- "start_time": "2020-10-24T01:46:25.622254Z"
- }
- },
- "outputs": [],
- "source": [
- "# Imports\n",
- "import torch\n",
- "from torch import nn, optim\n",
- "from torch.nn import functional as F\n",
- "from torch.autograd import Variable\n",
- "import torch\n",
- "import torch.utils.data\n",
- "\n",
- "import pandas as pd\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "plt.rcParams['figure.figsize'] = (12.0, 3.0)\n",
- "plt.style.use('ggplot')\n",
- "\n",
- "from pathlib import Path\n",
- "from tqdm.auto import tqdm\n",
- "\n",
- "import pytorch_lightning as pl"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-24T01:46:26.817188Z",
- "start_time": "2020-10-24T01:46:26.789800Z"
- }
- },
- "outputs": [],
- "source": [
- "import warnings\n",
- "warnings.simplefilter('once')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-24T01:46:27.392050Z",
- "start_time": "2020-10-24T01:46:26.822497Z"
- }
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
- " and should_run_async(code)\n"
- ]
- }
- ],
- "source": [
- "from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets\n",
- "from seq2seq_time.predict import predict, predict_multi"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-24T01:46:27.445533Z",
- "start_time": "2020-10-24T01:46:27.396912Z"
- }
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
- " and should_run_async(code)\n"
- ]
- }
- ],
- "source": [
- "import logging, sys\n",
- "# logging.basicConfig(stream=sys.stdout, level=logging.INFO)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-10T01:28:32.492160Z",
- "start_time": "2020-10-10T01:28:32.488140Z"
- }
- },
- "source": [
- "## Parameters"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-24T01:46:27.534753Z",
- "start_time": "2020-10-24T01:46:27.451107Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "using cuda\n"
- ]
- }
- ],
- "source": [
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "print(f'using {device}')\n",
- "\n",
- "columns_target=['energy(kWh/hh)']\n",
- "window_past = 48*2\n",
- "window_future = 48*2\n",
- "batch_size = 128\n",
- "num_workers = 5\n",
- "freq = '30T'\n",
- "max_rows = 5e5"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Load data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-24T01:46:27.595828Z",
- "start_time": "2020-10-24T01:46:27.538900Z"
- },
- "lines_to_next_cell": 0
- },
- "outputs": [],
- "source": [
- "\n",
- "def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london'), max_files=8):\n",
- " \"\"\"\n",
- " Data loading and cleaning is always messy, so understanding this code is optional.\n",
- " \"\"\"\n",
- " \n",
- " # Load csv files\n",
- " csv_files = sorted((indir/'halfhourly_dataset').glob('*.csv'))[:max_files]\n",
- " \n",
- " dfs = []\n",
- " for f in csv_files:\n",
- " df = (pd.read_csv(f, parse_dates=[1], na_values=['Null'])\n",
- " .groupby('tstp')\n",
- " .sum()\n",
- " .sort_index()\n",
- " )\n",
- " df['block'] = f.stem\n",
- "\n",
- " # Drop nan and 0's\n",
- " df = df[df['energy(kWh/hh)']!=0]\n",
- " df = df.dropna()\n",
- " \n",
- " # Add time features \n",
- " time = df.index.to_series()\n",
- " df[\"month\"] = time.dt.month\n",
- " df['day'] = time.dt.day\n",
- " df['week'] = time.dt.week\n",
- " df['hour'] = time.dt.hour\n",
- " df['minute'] = time.dt.minute\n",
- " df['dayofweek'] = time.dt.dayofweek\n",
- "\n",
- " # Load weather data\n",
- " df_weather = pd.read_csv(indir/'weather_hourly_darksky.csv', parse_dates=[3])\n",
- " use_cols = ['visibility', 'windBearing', 'temperature', 'time', 'dewPoint',\n",
- " 'pressure', 'apparentTemperature', 'windSpeed', \n",
- " 'humidity']\n",
- " df_weather = df_weather[use_cols].set_index('time')\n",
- " \n",
- " # Resample to match energy data \n",
- " # Use first, since we have bearing, and you can't take mean\n",
- " df_weather = df_weather.resample(freq).first().ffill() \n",
- "\n",
- " # Join weather and energy data\n",
- " df = pd.merge(df, df_weather, how='inner', left_index=True, right_index=True, sort=True)\n",
- "\n",
- " # Holidays\n",
- " df_hols = pd.read_csv(indir/'uk_bank_holidays.csv', parse_dates=[0])\n",
- " holidays = set(df_hols['Bank holidays'].dt.round('D')) \n",
- " def is_holiday(dt):\n",
- " return dt in holidays\n",
- " days = df.index.floor('D')\n",
- " holiday_mapping = days.unique().to_series().apply(is_holiday).astype(int).to_dict()\n",
- " df['holiday'] = days.to_series().map(holiday_mapping).values\n",
- "\n",
- " # sort\n",
- " df.index.name = 'Date'\n",
- " df = df.loc['2012-09':] # Weird value before this\n",
- " \n",
- " dfs.append(df)\n",
- " \n",
- " return pd.concat(dfs)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Our dataset is the London smart meter data, at half-hour intervals."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-24T01:46:46.307630Z",
- "start_time": "2020-10-24T01:46:27.600463Z"
- },
- "lines_to_next_cell": 0,
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n",
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:26: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "block_0 26161\n",
- "block_107 26161\n",
- "block_106 26161\n",
- "block_1 26161\n",
- "block_105 26161\n",
- "block_102 26161\n",
- "block_108 26161\n",
- "block_100 26161\n",
- "block_10 26161\n",
- "block_101 26161\n",
- "block_104 26161\n",
- "block_103 26161\n",
- "Name: block, dtype: int64\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "
\n",
- " \n",
- "
\n",
- "\n",
- " \n",
- " \n",
- " energy(kWh/hh) \n",
- " block \n",
- " month \n",
- " day \n",
- " week \n",
- " hour \n",
- " minute \n",
- " dayofweek \n",
- " visibility \n",
- " windBearing \n",
- " temperature \n",
- " dewPoint \n",
- " pressure \n",
- " apparentTemperature \n",
- " windSpeed \n",
- " humidity \n",
- " holiday \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " Date \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- " 2012-09-01 00:00:00 \n",
- " 5.013 \n",
- " block_0 \n",
- " 9 \n",
- " 1 \n",
- " 35 \n",
- " 0 \n",
- " 0 \n",
- " 5 \n",
- " 13.36 \n",
- " 302.0 \n",
- " 14.08 \n",
- " 9.74 \n",
- " 1028.27 \n",
- " 14.08 \n",
- " 1.89 \n",
- " 0.75 \n",
- " 0 \n",
- " \n",
- " \n",
- " 2012-09-01 00:30:00 \n",
- " 5.157 \n",
- " block_0 \n",
- " 9 \n",
- " 1 \n",
- " 35 \n",
- " 0 \n",
- " 30 \n",
- " 5 \n",
- " 13.36 \n",
- " 302.0 \n",
- " 14.08 \n",
- " 9.74 \n",
- " 1028.27 \n",
- " 14.08 \n",
- " 1.89 \n",
- " 0.75 \n",
- " 0 \n",
- " \n",
- " \n",
- " 2012-09-01 01:00:00 \n",
- " 6.360 \n",
- " block_0 \n",
- " 9 \n",
- " 1 \n",
- " 35 \n",
- " 1 \n",
- " 0 \n",
- " 5 \n",
- " 13.50 \n",
- " 298.0 \n",
- " 13.93 \n",
- " 9.81 \n",
- " 1027.96 \n",
- " 13.93 \n",
- " 1.59 \n",
- " 0.76 \n",
- " 0 \n",
- " \n",
- " \n",
- " 2012-09-01 01:30:00 \n",
- " 5.511 \n",
- " block_0 \n",
- " 9 \n",
- " 1 \n",
- " 35 \n",
- " 1 \n",
- " 30 \n",
- " 5 \n",
- " 13.50 \n",
- " 298.0 \n",
- " 13.93 \n",
- " 9.81 \n",
- " 1027.96 \n",
- " 13.93 \n",
- " 1.59 \n",
- " 0.76 \n",
- " 0 \n",
- " \n",
- " \n",
- " 2012-09-01 02:00:00 \n",
- " 4.922 \n",
- " block_0 \n",
- " 9 \n",
- " 1 \n",
- " 35 \n",
- " 2 \n",
- " 0 \n",
- " 5 \n",
- " 13.21 \n",
- " 274.0 \n",
- " 13.52 \n",
- " 9.94 \n",
- " 1028.04 \n",
- " 13.52 \n",
- " 0.82 \n",
- " 0.79 \n",
- " 0 \n",
- " \n",
- " \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " ... \n",
- " \n",
- " \n",
- " 2014-02-27 22:00:00 \n",
- " 9.819 \n",
- " block_108 \n",
- " 2 \n",
- " 27 \n",
- " 9 \n",
- " 22 \n",
- " 0 \n",
- " 3 \n",
- " 14.00 \n",
- " 216.0 \n",
- " 4.10 \n",
- " 1.64 \n",
- " 1005.67 \n",
- " 1.41 \n",
- " 3.02 \n",
- " 0.84 \n",
- " 0 \n",
- " \n",
- " \n",
- " 2014-02-27 22:30:00 \n",
- " 8.792 \n",
- " block_108 \n",
- " 2 \n",
- " 27 \n",
- " 9 \n",
- " 22 \n",
- " 30 \n",
- " 3 \n",
- " 14.00 \n",
- " 216.0 \n",
- " 4.10 \n",
- " 1.64 \n",
- " 1005.67 \n",
- " 1.41 \n",
- " 3.02 \n",
- " 0.84 \n",
- " 0 \n",
- " \n",
- " \n",
- " 2014-02-27 23:00:00 \n",
- " 8.087 \n",
- " block_108 \n",
- " 2 \n",
- " 27 \n",
- " 9 \n",
- " 23 \n",
- " 0 \n",
- " 3 \n",
- " 14.03 \n",
- " 200.0 \n",
- " 3.93 \n",
- " 1.61 \n",
- " 1004.62 \n",
- " 1.42 \n",
- " 2.75 \n",
- " 0.85 \n",
- " 0 \n",
- " \n",
- " \n",
- " 2014-02-27 23:30:00 \n",
- " 7.114 \n",
- " block_108 \n",
- " 2 \n",
- " 27 \n",
- " 9 \n",
- " 23 \n",
- " 30 \n",
- " 3 \n",
- " 14.03 \n",
- " 200.0 \n",
- " 3.93 \n",
- " 1.61 \n",
- " 1004.62 \n",
- " 1.42 \n",
- " 2.75 \n",
- " 0.85 \n",
- " 0 \n",
- " \n",
- " \n",
- " \n",
- "2014-02-28 00:00:00 \n",
- " 7.287 \n",
- " block_108 \n",
- " 2 \n",
- " 28 \n",
- " 9 \n",
- " 0 \n",
- " 0 \n",
- " 4 \n",
- " 12.63 \n",
- " 190.0 \n",
- " 3.81 \n",
- " 1.53 \n",
- " 1003.57 \n",
- " 1.47 \n",
- " 2.53 \n",
- " 0.85 \n",
- " 0 \n",
- " >>t&1&&(yield e);else e+=32}}*zeros(){const{_array:t,_nwords:s,size:r}=this;for(let e=0,i=0;i>>t&1||(yield e);else e+=32}}_check_size(t){e.assert(this.size==t.size,\"Size mismatch\")}add(t){this._check_size(t);for(let s=0;s>2];t>1;s[h]>t?e=h:i=h+1}return s[i]}function a(t,s,i,e,h){const n=t[e];t[e]=t[h],t[h]=n;const o=4*e,r=4*h,a=s[o],_=s[o+1],d=s[o+2],x=s[o+3];s[o]=s[r],s[o+1]=s[r+1],s[o+2]=s[r+2],s[o+3]=s[r+3],s[r]=a,s[r+1]=_,s[r+2]=d,s[r+3]=x;const l=i[e];i[e]=i[h],i[h]=l}function _(t,s){let i=t^s,e=65535^i,h=65535^(t|s),n=t&(65535^s),o=i|e>>1,r=i>>1^i,a=h>>1^e&n>>1^h,_=i&h>>1^n>>1^n;i=o,e=r,h=a,n=_,o=i&i>>2^e&e>>2,r=i&e>>2^e&(i^e)>>2,a^=i&h>>2^e&n>>2,_^=e&h>>2^(i^e)&n>>2,i=o,e=r,h=a,n=_,o=i&i>>4^e&e>>4,r=i&e>>4^e&(i^e)>>4,a^=i&h>>4^e&n>>4,_^=e&h>>4^(i^e)&n>>4,i=o,e=r,h=a,n=_,a^=i&h>>8^e&n>>8,_^=e&h>>8^(i^e)&n>>8,i=a^a>>1,e=_^_>>1;let d=t^s,x=e|65535^(d|i);return d=16711935&(d|d<<8),d=252645135&(d|d<<4),d=858993459&(d|d<<2),d=1431655765&(d|d<<1),x=16711935&(x|x<<8),x=252645135&(x|x<<4),x=858993459&(x|x<<2),x=1431655765&(x|x<<1),(x<<1|d)>>>0}i.default=n},\n",
- " function _(s,t,i){Object.defineProperty(i,\"__esModule\",{value:!0});i.default=class{constructor(){this.ids=[],this.values=[],this.length=0}clear(){this.length=0}push(s,t){let i=this.length++;for(this.ids[i]=s,this.values[i]=t;i>0;){const s=i-1>>1,h=this.values[s];if(t>=h)break;this.ids[i]=this.ids[s],this.values[i]=h,i=s}this.ids[i]=s,this.values[i]=t}pop(){if(0===this.length)return;const s=this.ids[0];if(this.length--,this.length>0){const s=this.ids[0]=this.ids[this.length],t=this.values[0]=this.values[this.length],i=this.length>>1;let h=0;for(;h=t)break;this.ids[h]=e,this.values[h]=l,h=s}this.ids[h]=s,this.values[h]=t}return s}peek(){if(0!==this.length)return this.ids[0]}peekValue(){if(0!==this.length)return this.values[0]}}},\n",
- " function _(t,e,n){Object.defineProperty(n,\"__esModule\",{value:!0});const s=t(1),i=t(99),r=s.__importStar(t(18)),a=t(24),o=t(9),p=t(8),g=t(11);function c(t,e,n=0){const s=new Map;for(let i=0;i