diff --git a/notebooks/RNN_Timeseries_Seq2Seq.ipynb b/notebooks/RNN_Timeseries_Seq2Seq.ipynb
deleted file mode 100644
index f59bf4e..0000000
--- a/notebooks/RNN_Timeseries_Seq2Seq.ipynb
+++ /dev/null
@@ -1,3264 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-10T01:25:12.788851Z",
- "start_time": "2020-10-10T01:25:12.783398Z"
- }
- },
- "source": [
- "# Sequence to Sequence Models for Timeseries Regression\n",
- "\n",
- "\n",
- "In this notebook we are going to tackle a harder problem: \n",
- "- predicting the future on a timeseries\n",
- "- using an LSTM\n",
- "- with rough uncertainty (uncalibrated)\n",
- "- outputting a sequence of predictions\n",
- "\n",
- "\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "- [ ] TODO mike autocorrelation baseline\n",
- "- [x] TODO mike acorn data\n",
- "- [ ] TODO mike handle multiple houses. Multiindex"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-18T07:27:09.448780Z",
- "start_time": "2020-10-18T07:27:09.018747Z"
- }
- },
- "outputs": [],
- "source": [
- "# OPTIONAL: Load the \"autoreload\" extension so that code can change. But blacklist large modules\n",
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "%aimport -pandas\n",
- "%aimport -torch\n",
- "%aimport -numpy\n",
- "%aimport -matplotlib\n",
- "%aimport -dask\n",
- "%aimport -tqdm\n",
- "%matplotlib inline"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 51,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-18T07:31:56.528476Z",
- "start_time": "2020-10-18T07:31:56.478567Z"
- }
- },
- "outputs": [],
- "source": [
- "# Imports\n",
- "import torch\n",
- "from torch import nn, optim\n",
- "from torch.nn import functional as F\n",
- "from torch.autograd import Variable\n",
- "import torch\n",
- "import torch.utils.data\n",
- "\n",
- "import pandas as pd\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "plt.rcParams['figure.figsize'] = (12.0, 4.0)\n",
- "plt.style.use('ggplot')\n",
- "\n",
- "from pathlib import Path\n",
- "from tqdm.auto import tqdm\n",
- "\n",
- "import pytorch_lightning as pl"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-18T07:27:10.589899Z",
- "start_time": "2020-10-18T07:27:10.503402Z"
- }
- },
- "outputs": [],
- "source": [
- "from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets\n",
- "from seq2seq_time.predict import predict"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-18T07:27:10.619540Z",
- "start_time": "2020-10-18T07:27:10.592972Z"
- }
- },
- "outputs": [],
- "source": [
- "import logging, sys\n",
- "# logging.basicConfig(stream=sys.stdout, level=logging.INFO)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-10T01:28:32.492160Z",
- "start_time": "2020-10-10T01:28:32.488140Z"
- }
- },
- "source": [
- "## Parameters"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-18T07:27:10.661223Z",
- "start_time": "2020-10-18T07:27:10.621920Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "using cuda\n"
- ]
- }
- ],
- "source": [
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "print(f'using {device}')\n",
- "\n",
- "columns_target=['energy(kWh/hh)']\n",
- "window_past = 48*4\n",
- "window_future = 48*4\n",
- "batch_size = 64\n",
- "num_workers = 0\n",
- "freq = '30T'\n",
- "max_rows = 1e5"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Load data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-18T07:27:10.702018Z",
- "start_time": "2020-10-18T07:27:10.663465Z"
- },
- "lines_to_next_cell": 0
- },
- "outputs": [],
- "source": [
- "\n",
- "def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london'), max_files=1):\n",
- " \"\"\"\n",
- " Data loading and cleaning is always messy, so understanding this code is optional.\n",
- " \"\"\"\n",
- " \n",
- " # Load csv files\n",
- " csv_files = sorted((indir/'halfhourly_dataset').glob('*.csv'))[:max_files]\n",
- " \n",
- " # concatenate them\n",
- " df = pd.concat([pd.read_csv(f, parse_dates=[1], na_values=['Null']) for f in csv_files])\n",
- " \n",
- " # Add ACORN categories\n",
- " df_households = pd.read_csv(indir/'informations_households.csv')\n",
- " df_households = df_households[['LCLid', 'stdorToU', 'Acorn_grouped']]\n",
- " df = pd.merge(df, df_households, on='LCLid')\n",
- " \n",
- " df = df.set_index('tstp')\n",
- " \n",
- " # Drop nan and 0's\n",
- " df = df[df['energy(kWh/hh)']!=0]\n",
- " df = df.dropna()\n",
- " \n",
- " # Add time features \n",
- " time = df.index.to_series()\n",
- " df[\"month\"] = time.dt.month\n",
- " df['day'] = time.dt.day\n",
- " df['week'] = time.dt.week\n",
- " df['hour'] = time.dt.hour\n",
- " df['minute'] = time.dt.minute\n",
- " df['dayofweek'] = time.dt.dayofweek\n",
- " \n",
- " # Load weather data\n",
- " df_weather = pd.read_csv(indir/'weather_hourly_darksky.csv', parse_dates=[3])\n",
- " use_cols = ['visibility', 'windBearing', 'temperature', 'time', 'dewPoint',\n",
- " 'pressure', 'apparentTemperature', 'windSpeed', \n",
- " 'humidity']\n",
- " df_weather = df_weather[use_cols].set_index('time')\n",
- " df_weather = df_weather.resample(freq).first().ffill() # Resample to match energy data \n",
- " \n",
- " # Join weather and energy data\n",
- " df = pd.merge(df, df_weather, how='inner', left_index=True, right_index=True, sort=True)\n",
- " \n",
- " # Holidays\n",
- " df_hols = pd.read_csv(indir/'uk_bank_holidays.csv', parse_dates=[0])\n",
- " holidays = set(df_hols['Bank holidays'].dt.round('D')) \n",
- " def is_holiday(dt):\n",
- " return dt in holidays\n",
- " days = df.index.floor('D')\n",
- " holiday_mapping = days.unique().to_series().apply(is_holiday).astype(int).to_dict()\n",
- " df['holiday'] = days.to_series().map(holiday_mapping).values\n",
- "\n",
- " # Loop over houses\n",
- " for name, df_h in df.groupby('LCLid'):\n",
- "\n",
- " yield df_h"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Our dataset is the London smart meter data, at half-hour intervals."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-18T07:27:14.677877Z",
- "start_time": "2020-10-18T07:27:10.704252Z"
- },
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:27: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n"
- ]
- }
- ],
- "source": [
- "dfs = get_smartmeter_df()\n",
- "\n",
- "# Materialise the generator into a list of per-household DataFrames\n",
- "dfs = list(dfs)\n",
- "\n",
- "# df = df.resample(freq).first().dropna() # Where empty we will backfill, this will respect causality, and mostly maintain the mean\n",
- "\n",
- "# df = df.tail(int(max_rows)).copy() # Just use last X rows\n",
- "df = pd.concat(dfs[:6], 0)\n",
- "# df = dfs[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-18T07:27:14.722131Z",
- "start_time": "2020-10-18T07:27:14.681417Z"
- }
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "array(['MAC000002', 'MAC000246', 'MAC000450', 'MAC001074', 'MAC003223',\n",
- " 'MAC003239'], dtype=object)"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df.LCLid.unique()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-18T07:15:20.702022Z",
- "start_time": "2020-10-18T07:15:20.646916Z"
- }
- },
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-18T07:27:14.905468Z",
- "start_time": "2020-10-18T07:27:14.724956Z"
- }
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "
| \n", - " | energy(kWh/hh) | \n", - "month | \n", - "day | \n", - "week | \n", - "hour | \n", - "minute | \n", - "dayofweek | \n", - "visibility | \n", - "windBearing | \n", - "temperature | \n", - "dewPoint | \n", - "pressure | \n", - "apparentTemperature | \n", - "windSpeed | \n", - "humidity | \n", - "holiday | \n", - "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "144861.00000 | \n", - "144861.000000 | \n", - "144861.000000 | \n", - "
| mean | \n", - "0.480475 | \n", - "6.702660 | \n", - "15.803764 | \n", - "27.250626 | \n", - "11.505257 | \n", - "14.999482 | \n", - "2.997335 | \n", - "11.223563 | \n", - "194.106060 | \n", - "10.408324 | \n", - "6.499879 | \n", - "1012.888314 | \n", - "9.123004 | \n", - "3.93198 | \n", - "0.783605 | \n", - "0.021869 | \n", - "
| std | \n", - "0.608599 | \n", - "3.673558 | \n", - "8.776667 | \n", - "15.993510 | \n", - "6.922274 | \n", - "15.000052 | \n", - "2.000109 | \n", - "3.041926 | \n", - "91.403667 | \n", - "5.933314 | \n", - "5.175187 | \n", - "11.071921 | \n", - "7.143647 | \n", - "2.03807 | \n", - "0.140842 | \n", - "0.146257 | \n", - "
| min | \n", - "0.014000 | \n", - "1.000000 | \n", - "1.000000 | \n", - "1.000000 | \n", - "0.000000 | \n", - "0.000000 | \n", - "0.000000 | \n", - "0.270000 | \n", - "0.000000 | \n", - "-5.640000 | \n", - "-9.980000 | \n", - "975.740000 | \n", - "-8.880000 | \n", - "0.04000 | \n", - "0.230000 | \n", - "0.000000 | \n", - "
| 25% | \n", - "0.096000 | \n", - "3.000000 | \n", - "8.000000 | \n", - "13.000000 | \n", - "6.000000 | \n", - "0.000000 | \n", - "1.000000 | \n", - "10.220000 | \n", - "117.000000 | \n", - "6.260000 | \n", - "2.670000 | \n", - "1006.460000 | \n", - "3.590000 | \n", - "2.44000 | \n", - "0.700000 | \n", - "0.000000 | \n", - "
| 50% | \n", - "0.218000 | \n", - "7.000000 | \n", - "16.000000 | \n", - "28.000000 | \n", - "12.000000 | \n", - "0.000000 | \n", - "3.000000 | \n", - "12.260000 | \n", - "215.000000 | \n", - "9.870000 | \n", - "6.610000 | \n", - "1013.660000 | \n", - "8.990000 | \n", - "3.70000 | \n", - "0.810000 | \n", - "0.000000 | \n", - "
| 75% | \n", - "0.640000 | \n", - "10.000000 | \n", - "23.000000 | \n", - "42.000000 | \n", - "18.000000 | \n", - "30.000000 | \n", - "5.000000 | \n", - "13.080000 | \n", - "255.000000 | \n", - "14.440000 | \n", - "10.460000 | \n", - "1020.340000 | \n", - "14.440000 | \n", - "5.13000 | \n", - "0.890000 | \n", - "0.000000 | \n", - "
| max | \n", - "5.250000 | \n", - "12.000000 | \n", - "31.000000 | \n", - "52.000000 | \n", - "23.000000 | \n", - "30.000000 | \n", - "6.000000 | \n", - "16.090000 | \n", - "359.000000 | \n", - "32.400000 | \n", - "19.880000 | \n", - "1043.320000 | \n", - "32.420000 | \n", - "14.80000 | \n", - "1.000000 | \n", - "1.000000 | \n", - "
| \n", - " | energy(kWh/hh) | \n", - "month | \n", - "day | \n", - "week | \n", - "hour | \n", - "minute | \n", - "dayofweek | \n", - "visibility | \n", - "windBearing | \n", - "temperature | \n", - "dewPoint | \n", - "pressure | \n", - "apparentTemperature | \n", - "windSpeed | \n", - "humidity | \n", - "holiday | \n", - "stdorToU | \n", - "LCLid | \n", - "Acorn_grouped | \n", - "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2012-10-12 11:30:00 | \n", - "-0.554513 | \n", - "0.897591 | \n", - "-0.433397 | \n", - "0.859688 | \n", - "-0.072990 | \n", - "1.000035 | \n", - "0.501307 | \n", - "0.754931 | \n", - "0.797497 | \n", - "0.367701 | \n", - "-0.390302 | \n", - "-0.509246 | \n", - "0.485327 | \n", - "1.485734 | \n", - "-1.445630 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "
| 2012-10-12 12:00:00 | \n", - "0.299911 | \n", - "0.897591 | \n", - "-0.433397 | \n", - "0.859688 | \n", - "0.071471 | \n", - "-0.999965 | \n", - "0.501307 | \n", - "0.754931 | \n", - "0.720914 | \n", - "0.581747 | \n", - "-0.421219 | \n", - "-0.498408 | \n", - "0.663108 | \n", - "1.466108 | \n", - "-1.800639 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "
| 2012-10-12 12:30:00 | \n", - "-0.368840 | \n", - "0.897591 | \n", - "-0.433397 | \n", - "0.859688 | \n", - "0.071471 | \n", - "1.000035 | \n", - "0.501307 | \n", - "0.754931 | \n", - "0.720914 | \n", - "0.581747 | \n", - "-0.421219 | \n", - "-0.498408 | \n", - "0.663108 | \n", - "1.466108 | \n", - "-1.800639 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "
| 2012-10-12 13:00:00 | \n", - "-0.534795 | \n", - "0.897591 | \n", - "-0.433397 | \n", - "0.859688 | \n", - "0.215933 | \n", - "-0.999965 | \n", - "0.501307 | \n", - "0.754931 | \n", - "0.644330 | \n", - "0.603657 | \n", - "-0.390302 | \n", - "-0.483054 | \n", - "0.681306 | \n", - "1.505361 | \n", - "-1.800639 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "
| 2012-10-12 13:30:00 | \n", - "-0.462498 | \n", - "0.897591 | \n", - "-0.433397 | \n", - "0.859688 | \n", - "0.215933 | \n", - "1.000035 | \n", - "0.501307 | \n", - "0.754931 | \n", - "0.644330 | \n", - "0.603657 | \n", - "-0.390302 | \n", - "-0.483054 | \n", - "0.681306 | \n", - "1.505361 | \n", - "-1.800639 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "
| ... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "
| 2014-02-27 22:00:00 | \n", - "-0.521650 | \n", - "-1.280142 | \n", - "1.275686 | \n", - "-1.141131 | \n", - "1.516088 | \n", - "-0.999965 | \n", - "0.001332 | \n", - "0.912726 | \n", - "0.239531 | \n", - "-1.063208 | \n", - "-0.939076 | \n", - "-0.651950 | \n", - "-1.079705 | \n", - "-0.447474 | \n", - "0.400416 | \n", - "-0.149527 | \n", - "0.0 | \n", - "5.0 | \n", - "1.0 | \n", - "
| 2014-02-27 22:30:00 | \n", - "-0.120728 | \n", - "-1.280142 | \n", - "1.275686 | \n", - "-1.141131 | \n", - "1.516088 | \n", - "1.000035 | \n", - "0.001332 | \n", - "0.912726 | \n", - "0.239531 | \n", - "-1.063208 | \n", - "-0.939076 | \n", - "-0.651950 | \n", - "-1.079705 | \n", - "-0.447474 | \n", - "0.400416 | \n", - "-0.149527 | \n", - "0.0 | \n", - "5.0 | \n", - "1.0 | \n", - "
| 2014-02-27 23:00:00 | \n", - "-0.653100 | \n", - "-1.280142 | \n", - "1.275686 | \n", - "-1.141131 | \n", - "1.660550 | \n", - "-0.999965 | \n", - "0.001332 | \n", - "0.922589 | \n", - "0.064483 | \n", - "-1.091860 | \n", - "-0.944873 | \n", - "-0.746785 | \n", - "-1.078305 | \n", - "-0.579953 | \n", - "0.471418 | \n", - "-0.149527 | \n", - "0.0 | \n", - "5.0 | \n", - "1.0 | \n", - "
| 2014-02-27 23:30:00 | \n", - "-0.692535 | \n", - "-1.280142 | \n", - "1.275686 | \n", - "-1.141131 | \n", - "1.660550 | \n", - "1.000035 | \n", - "0.001332 | \n", - "0.922589 | \n", - "0.064483 | \n", - "-1.091860 | \n", - "-0.944873 | \n", - "-0.746785 | \n", - "-1.078305 | \n", - "-0.579953 | \n", - "0.471418 | \n", - "-0.149527 | \n", - "0.0 | \n", - "5.0 | \n", - "1.0 | \n", - "
| 2014-02-28 00:00:00 | \n", - "-0.643242 | \n", - "-1.280142 | \n", - "1.389625 | \n", - "-1.141131 | \n", - "-1.662069 | \n", - "-0.999965 | \n", - "0.501307 | \n", - "0.462352 | \n", - "-0.044922 | \n", - "-1.112085 | \n", - "-0.960332 | \n", - "-0.841620 | \n", - "-1.071306 | \n", - "-0.687898 | \n", - "0.471418 | \n", - "-0.149527 | \n", - "0.0 | \n", - "5.0 | \n", - "1.0 | \n", - "
144861 rows × 19 columns
\n", - "| \n", - " | month | \n", - "day | \n", - "week | \n", - "hour | \n", - "minute | \n", - "dayofweek | \n", - "visibility | \n", - "windBearing | \n", - "temperature | \n", - "dewPoint | \n", - "pressure | \n", - "apparentTemperature | \n", - "windSpeed | \n", - "humidity | \n", - "holiday | \n", - "stdorToU | \n", - "LCLid | \n", - "Acorn_grouped | \n", - "tsp_days | \n", - "is_past | \n", - "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2014-02-23 17:00:00 | \n", - "-1.280142 | \n", - "0.819931 | \n", - "-1.203657 | \n", - "0.793780 | \n", - "-0.999965 | \n", - "1.501256 | \n", - "0.297982 | \n", - "0.108245 | \n", - "0.168823 | \n", - "-0.056013 | \n", - "-0.272611 | \n", - "0.320145 | \n", - "2.045089 | \n", - "-0.593609 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "-0.104167 | \n", - "1.0 | \n", - "
| 2014-02-23 17:30:00 | \n", - "-1.280142 | \n", - "0.819931 | \n", - "-1.203657 | \n", - "0.793780 | \n", - "1.000035 | \n", - "1.501256 | \n", - "0.297982 | \n", - "0.108245 | \n", - "0.168823 | \n", - "-0.056013 | \n", - "-0.272611 | \n", - "0.320145 | \n", - "2.045089 | \n", - "-0.593609 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "-0.083333 | \n", - "1.0 | \n", - "
| 2014-02-23 18:00:00 | \n", - "-1.280142 | \n", - "0.819931 | \n", - "-1.203657 | \n", - "0.938242 | \n", - "-0.999965 | \n", - "1.501256 | \n", - "0.297982 | \n", - "0.020721 | \n", - "0.098036 | \n", - "-0.096592 | \n", - "-0.240999 | \n", - "0.261351 | \n", - "1.902797 | \n", - "-0.522607 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "-0.062500 | \n", - "1.0 | \n", - "
| 2014-02-23 18:30:00 | \n", - "-1.280142 | \n", - "0.819931 | \n", - "-1.203657 | \n", - "0.938242 | \n", - "1.000035 | \n", - "1.501256 | \n", - "0.297982 | \n", - "0.020721 | \n", - "0.098036 | \n", - "-0.096592 | \n", - "-0.240999 | \n", - "0.261351 | \n", - "1.902797 | \n", - "-0.522607 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "-0.041667 | \n", - "1.0 | \n", - "
| 2014-02-23 19:00:00 | \n", - "-1.280142 | \n", - "0.819931 | \n", - "-1.203657 | \n", - "1.082703 | \n", - "-0.999965 | \n", - "1.501256 | \n", - "0.590561 | \n", - "-0.077744 | \n", - "0.059272 | \n", - "-0.092727 | \n", - "-0.216613 | \n", - "0.229155 | \n", - "1.201150 | \n", - "-0.451605 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "-0.020833 | \n", - "1.0 | \n", - "
| \n", - " | month | \n", - "day | \n", - "week | \n", - "hour | \n", - "minute | \n", - "dayofweek | \n", - "visibility | \n", - "windBearing | \n", - "temperature | \n", - "dewPoint | \n", - "pressure | \n", - "apparentTemperature | \n", - "windSpeed | \n", - "humidity | \n", - "holiday | \n", - "stdorToU | \n", - "LCLid | \n", - "Acorn_grouped | \n", - "tsp_days | \n", - "is_past | \n", - "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2014-02-27 17:00:00 | \n", - "-1.280142 | \n", - "1.275686 | \n", - "-1.141131 | \n", - "0.793780 | \n", - "-0.999965 | \n", - "0.001332 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "3.895833 | \n", - "0.0 | \n", - "
| 2014-02-27 17:30:00 | \n", - "-1.280142 | \n", - "1.275686 | \n", - "-1.141131 | \n", - "0.793780 | \n", - "1.000035 | \n", - "0.001332 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "3.916667 | \n", - "0.0 | \n", - "
| 2014-02-27 18:00:00 | \n", - "-1.280142 | \n", - "1.275686 | \n", - "-1.141131 | \n", - "0.938242 | \n", - "-0.999965 | \n", - "0.001332 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "3.937500 | \n", - "0.0 | \n", - "
| 2014-02-27 18:30:00 | \n", - "-1.280142 | \n", - "1.275686 | \n", - "-1.141131 | \n", - "0.938242 | \n", - "1.000035 | \n", - "0.001332 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "3.958333 | \n", - "0.0 | \n", - "
| 2014-02-27 19:00:00 | \n", - "-1.280142 | \n", - "1.275686 | \n", - "-1.141131 | \n", - "1.082703 | \n", - "-0.999965 | \n", - "0.001332 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "-0.149527 | \n", - "0.0 | \n", - "0.0 | \n", - "1.0 | \n", - "3.979167 | \n", - "0.0 | \n", - "
<xarray.Dataset>\n", - "Dimensions: (t_ahead: 192, t_behind: 192, t_source: 24943)\n", - "Coordinates:\n", - " * t_source (t_source) datetime64[ns] 2012-09-18T09:00:00 ... 2012-09-...\n", - " * t_ahead (t_ahead) timedelta64[ns] 00:00:00 ... 3 days 23:30:00\n", - " * t_behind (t_behind) timedelta64[ns] -4 days +00:00:00 ... -1 days +...\n", - " t_target (t_source, t_ahead) datetime64[ns] 2012-09-18T09:00:00 ......\n", - " t_past (t_source, t_behind) datetime64[ns] 2012-09-14T09:00:00 .....\n", - " t_ahead_hours (t_ahead) float64 0.0 0.0 1.0 1.0 2.0 ... 94.0 94.0 95.0 95.0\n", - "Data variables:\n", - " y_past (t_source, t_behind) float32 0.061000023 ... 0.12699997\n", - " nll (t_source, t_ahead) float32 0.44390392 ... 0.79414505\n", - " y_pred (t_source, t_ahead) float32 0.33915406 ... 0.37555674\n", - " y_pred_std (t_source, t_ahead) float64 0.3371 0.332 ... 0.3789 0.379\n", - " y_true (t_source, t_ahead) float32 0.17700002 ... 0.05900002
array(['2012-09-18T09:00:00.000000000', '2012-09-18T09:30:00.000000000',\n", - " '2012-09-18T10:00:00.000000000', ..., '2012-09-27T09:30:00.000000000',\n", - " '2012-09-27T10:00:00.000000000', '2012-09-27T10:30:00.000000000'],\n", - " dtype='datetime64[ns]')
array([ 0, 1800000000000, 3600000000000, 5400000000000,\n", - " 7200000000000, 9000000000000, 10800000000000, 12600000000000,\n", - " 14400000000000, 16200000000000, 18000000000000, 19800000000000,\n", - " 21600000000000, 23400000000000, 25200000000000, 27000000000000,\n", - " 28800000000000, 30600000000000, 32400000000000, 34200000000000,\n", - " 36000000000000, 37800000000000, 39600000000000, 41400000000000,\n", - " 43200000000000, 45000000000000, 46800000000000, 48600000000000,\n", - " 50400000000000, 52200000000000, 54000000000000, 55800000000000,\n", - " 57600000000000, 59400000000000, 61200000000000, 63000000000000,\n", - " 64800000000000, 66600000000000, 68400000000000, 70200000000000,\n", - " 72000000000000, 73800000000000, 75600000000000, 77400000000000,\n", - " 79200000000000, 81000000000000, 82800000000000, 84600000000000,\n", - " 86400000000000, 88200000000000, 90000000000000, 91800000000000,\n", - " 93600000000000, 95400000000000, 97200000000000, 99000000000000,\n", - " 100800000000000, 102600000000000, 104400000000000, 106200000000000,\n", - " 108000000000000, 109800000000000, 111600000000000, 113400000000000,\n", - " 115200000000000, 117000000000000, 118800000000000, 120600000000000,\n", - " 122400000000000, 124200000000000, 126000000000000, 127800000000000,\n", - " 129600000000000, 131400000000000, 133200000000000, 135000000000000,\n", - " 136800000000000, 138600000000000, 140400000000000, 142200000000000,\n", - " 144000000000000, 145800000000000, 147600000000000, 149400000000000,\n", - " 151200000000000, 153000000000000, 154800000000000, 156600000000000,\n", - " 158400000000000, 160200000000000, 162000000000000, 163800000000000,\n", - " 165600000000000, 167400000000000, 169200000000000, 171000000000000,\n", - " 172800000000000, 174600000000000, 176400000000000, 178200000000000,\n", - " 180000000000000, 181800000000000, 183600000000000, 185400000000000,\n", - " 187200000000000, 189000000000000, 190800000000000, 192600000000000,\n", - " 
194400000000000, 196200000000000, 198000000000000, 199800000000000,\n", - " 201600000000000, 203400000000000, 205200000000000, 207000000000000,\n", - " 208800000000000, 210600000000000, 212400000000000, 214200000000000,\n", - " 216000000000000, 217800000000000, 219600000000000, 221400000000000,\n", - " 223200000000000, 225000000000000, 226800000000000, 228600000000000,\n", - " 230400000000000, 232200000000000, 234000000000000, 235800000000000,\n", - " 237600000000000, 239400000000000, 241200000000000, 243000000000000,\n", - " 244800000000000, 246600000000000, 248400000000000, 250200000000000,\n", - " 252000000000000, 253800000000000, 255600000000000, 257400000000000,\n", - " 259200000000000, 261000000000000, 262800000000000, 264600000000000,\n", - " 266400000000000, 268200000000000, 270000000000000, 271800000000000,\n", - " 273600000000000, 275400000000000, 277200000000000, 279000000000000,\n", - " 280800000000000, 282600000000000, 284400000000000, 286200000000000,\n", - " 288000000000000, 289800000000000, 291600000000000, 293400000000000,\n", - " 295200000000000, 297000000000000, 298800000000000, 300600000000000,\n", - " 302400000000000, 304200000000000, 306000000000000, 307800000000000,\n", - " 309600000000000, 311400000000000, 313200000000000, 315000000000000,\n", - " 316800000000000, 318600000000000, 320400000000000, 322200000000000,\n", - " 324000000000000, 325800000000000, 327600000000000, 329400000000000,\n", - " 331200000000000, 333000000000000, 334800000000000, 336600000000000,\n", - " 338400000000000, 340200000000000, 342000000000000, 343800000000000],\n", - " dtype='timedelta64[ns]')
array([-345600000000000, -343800000000000, -342000000000000, -340200000000000,\n", - " -338400000000000, -336600000000000, -334800000000000, -333000000000000,\n", - " -331200000000000, -329400000000000, -327600000000000, -325800000000000,\n", - " -324000000000000, -322200000000000, -320400000000000, -318600000000000,\n", - " -316800000000000, -315000000000000, -313200000000000, -311400000000000,\n", - " -309600000000000, -307800000000000, -306000000000000, -304200000000000,\n", - " -302400000000000, -300600000000000, -298800000000000, -297000000000000,\n", - " -295200000000000, -293400000000000, -291600000000000, -289800000000000,\n", - " -288000000000000, -286200000000000, -284400000000000, -282600000000000,\n", - " -280800000000000, -279000000000000, -277200000000000, -275400000000000,\n", - " -273600000000000, -271800000000000, -270000000000000, -268200000000000,\n", - " -266400000000000, -264600000000000, -262800000000000, -261000000000000,\n", - " -259200000000000, -257400000000000, -255600000000000, -253800000000000,\n", - " -252000000000000, -250200000000000, -248400000000000, -246600000000000,\n", - " -244800000000000, -243000000000000, -241200000000000, -239400000000000,\n", - " -237600000000000, -235800000000000, -234000000000000, -232200000000000,\n", - " -230400000000000, -228600000000000, -226800000000000, -225000000000000,\n", - " -223200000000000, -221400000000000, -219600000000000, -217800000000000,\n", - " -216000000000000, -214200000000000, -212400000000000, -210600000000000,\n", - " -208800000000000, -207000000000000, -205200000000000, -203400000000000,\n", - " -201600000000000, -199800000000000, -198000000000000, -196200000000000,\n", - " -194400000000000, -192600000000000, -190800000000000, -189000000000000,\n", - " -187200000000000, -185400000000000, -183600000000000, -181800000000000,\n", - " -180000000000000, -178200000000000, -176400000000000, -174600000000000,\n", - " -172800000000000, -171000000000000, -169200000000000, 
-167400000000000,\n", - " -165600000000000, -163800000000000, -162000000000000, -160200000000000,\n", - " -158400000000000, -156600000000000, -154800000000000, -153000000000000,\n", - " -151200000000000, -149400000000000, -147600000000000, -145800000000000,\n", - " -144000000000000, -142200000000000, -140400000000000, -138600000000000,\n", - " -136800000000000, -135000000000000, -133200000000000, -131400000000000,\n", - " -129600000000000, -127800000000000, -126000000000000, -124200000000000,\n", - " -122400000000000, -120600000000000, -118800000000000, -117000000000000,\n", - " -115200000000000, -113400000000000, -111600000000000, -109800000000000,\n", - " -108000000000000, -106200000000000, -104400000000000, -102600000000000,\n", - " -100800000000000, -99000000000000, -97200000000000, -95400000000000,\n", - " -93600000000000, -91800000000000, -90000000000000, -88200000000000,\n", - " -86400000000000, -84600000000000, -82800000000000, -81000000000000,\n", - " -79200000000000, -77400000000000, -75600000000000, -73800000000000,\n", - " -72000000000000, -70200000000000, -68400000000000, -66600000000000,\n", - " -64800000000000, -63000000000000, -61200000000000, -59400000000000,\n", - " -57600000000000, -55800000000000, -54000000000000, -52200000000000,\n", - " -50400000000000, -48600000000000, -46800000000000, -45000000000000,\n", - " -43200000000000, -41400000000000, -39600000000000, -37800000000000,\n", - " -36000000000000, -34200000000000, -32400000000000, -30600000000000,\n", - " -28800000000000, -27000000000000, -25200000000000, -23400000000000,\n", - " -21600000000000, -19800000000000, -18000000000000, -16200000000000,\n", - " -14400000000000, -12600000000000, -10800000000000, -9000000000000,\n", - " -7200000000000, -5400000000000, -3600000000000, -1800000000000],\n", - " dtype='timedelta64[ns]')
array([['2012-09-18T09:00:00.000000000', '2012-09-18T09:30:00.000000000',\n", - " '2012-09-18T10:00:00.000000000', ...,\n", - " '2012-09-22T07:30:00.000000000', '2012-09-22T08:00:00.000000000',\n", - " '2012-09-22T08:30:00.000000000'],\n", - " ['2012-09-18T09:30:00.000000000', '2012-09-18T10:00:00.000000000',\n", - " '2012-09-18T10:30:00.000000000', ...,\n", - " '2012-09-22T08:00:00.000000000', '2012-09-22T08:30:00.000000000',\n", - " '2012-09-22T09:00:00.000000000'],\n", - " ['2012-09-18T10:00:00.000000000', '2012-09-18T10:30:00.000000000',\n", - " '2012-09-18T11:00:00.000000000', ...,\n", - " '2012-09-22T08:30:00.000000000', '2012-09-22T09:00:00.000000000',\n", - " '2012-09-22T09:30:00.000000000'],\n", - " ...,\n", - " ['2012-09-27T09:30:00.000000000', '2012-09-27T10:00:00.000000000',\n", - " '2012-09-27T10:30:00.000000000', ...,\n", - " '2012-10-01T08:00:00.000000000', '2012-10-01T08:30:00.000000000',\n", - " '2012-10-01T09:00:00.000000000'],\n", - " ['2012-09-27T10:00:00.000000000', '2012-09-27T10:30:00.000000000',\n", - " '2012-09-27T11:00:00.000000000', ...,\n", - " '2012-10-01T08:30:00.000000000', '2012-10-01T09:00:00.000000000',\n", - " '2012-10-01T09:30:00.000000000'],\n", - " ['2012-09-27T10:30:00.000000000', '2012-09-27T11:00:00.000000000',\n", - " '2012-09-27T11:30:00.000000000', ...,\n", - " '2012-10-01T09:00:00.000000000', '2012-10-01T09:30:00.000000000',\n", - " '2012-10-01T10:00:00.000000000']], dtype='datetime64[ns]')
array([['2012-09-14T09:00:00.000000000', '2012-09-14T09:30:00.000000000',\n", - " '2012-09-14T10:00:00.000000000', ...,\n", - " '2012-09-18T07:30:00.000000000', '2012-09-18T08:00:00.000000000',\n", - " '2012-09-18T08:30:00.000000000'],\n", - " ['2012-09-14T09:30:00.000000000', '2012-09-14T10:00:00.000000000',\n", - " '2012-09-14T10:30:00.000000000', ...,\n", - " '2012-09-18T08:00:00.000000000', '2012-09-18T08:30:00.000000000',\n", - " '2012-09-18T09:00:00.000000000'],\n", - " ['2012-09-14T10:00:00.000000000', '2012-09-14T10:30:00.000000000',\n", - " '2012-09-14T11:00:00.000000000', ...,\n", - " '2012-09-18T08:30:00.000000000', '2012-09-18T09:00:00.000000000',\n", - " '2012-09-18T09:30:00.000000000'],\n", - " ...,\n", - " ['2012-09-23T09:30:00.000000000', '2012-09-23T10:00:00.000000000',\n", - " '2012-09-23T10:30:00.000000000', ...,\n", - " '2012-09-27T08:00:00.000000000', '2012-09-27T08:30:00.000000000',\n", - " '2012-09-27T09:00:00.000000000'],\n", - " ['2012-09-23T10:00:00.000000000', '2012-09-23T10:30:00.000000000',\n", - " '2012-09-23T11:00:00.000000000', ...,\n", - " '2012-09-27T08:30:00.000000000', '2012-09-27T09:00:00.000000000',\n", - " '2012-09-27T09:30:00.000000000'],\n", - " ['2012-09-23T10:30:00.000000000', '2012-09-23T11:00:00.000000000',\n", - " '2012-09-23T11:30:00.000000000', ...,\n", - " '2012-09-27T09:00:00.000000000', '2012-09-27T09:30:00.000000000',\n", - " '2012-09-27T10:00:00.000000000']], dtype='datetime64[ns]')
array([ 0., 0., 1., 1., 2., 2., 3., 3., 4., 4., 5., 5., 6.,\n", - " 6., 7., 7., 8., 8., 9., 9., 10., 10., 11., 11., 12., 12.,\n", - " 13., 13., 14., 14., 15., 15., 16., 16., 17., 17., 18., 18., 19.,\n", - " 19., 20., 20., 21., 21., 22., 22., 23., 23., 24., 24., 25., 25.,\n", - " 26., 26., 27., 27., 28., 28., 29., 29., 30., 30., 31., 31., 32.,\n", - " 32., 33., 33., 34., 34., 35., 35., 36., 36., 37., 37., 38., 38.,\n", - " 39., 39., 40., 40., 41., 41., 42., 42., 43., 43., 44., 44., 45.,\n", - " 45., 46., 46., 47., 47., 48., 48., 49., 49., 50., 50., 51., 51.,\n", - " 52., 52., 53., 53., 54., 54., 55., 55., 56., 56., 57., 57., 58.,\n", - " 58., 59., 59., 60., 60., 61., 61., 62., 62., 63., 63., 64., 64.,\n", - " 65., 65., 66., 66., 67., 67., 68., 68., 69., 69., 70., 70., 71.,\n", - " 71., 72., 72., 73., 73., 74., 74., 75., 75., 76., 76., 77., 77.,\n", - " 78., 78., 79., 79., 80., 80., 81., 81., 82., 82., 83., 83., 84.,\n", - " 84., 85., 85., 86., 86., 87., 87., 88., 88., 89., 89., 90., 90.,\n", - " 91., 91., 92., 92., 93., 93., 94., 94., 95., 95.])
array([[0.06100002, 0.04999999, 0.18399999, ..., 0.07999998, 0.083 ,\n", - " 0.10499999],\n", - " [0.04999999, 0.18399999, 0.049 , ..., 0.083 , 0.10499999,\n", - " 0.17700002],\n", - " [0.18399999, 0.049 , 0.07100001, ..., 0.10499999, 0.17700002,\n", - " 0.06799999],\n", - " ...,\n", - " [0.123 , 0.08500001, 0.042 , ..., 0.16300002, 0.141 ,\n", - " 0.14899999],\n", - " [0.08500001, 0.042 , 0.042 , ..., 0.141 , 0.14899999,\n", - " 0.387 ],\n", - " [0.042 , 0.042 , 0.06200001, ..., 0.14899999, 0.387 ,\n", - " 0.12699997]], dtype=float32)
array([[0.44390392, 0.56664157, 0.5809205 , ..., 0.41122854, 0.56711197,\n", - " 0.58973217],\n", - " [0.6743895 , 0.6417631 , 0.5153675 , ..., 0.56559354, 0.588212 ,\n", - " 0.6929046 ],\n", - " [0.7600697 , 0.5740746 , 0.5059802 , ..., 0.58670086, 0.6912286 ,\n", - " 0.66683286],\n", - " ...,\n", - " [0.43764848, 0.706943 , 0.6868087 , ..., 0.6452905 , 0.5894213 ,\n", - " 0.44566935],\n", - " [0.72706467, 0.6857369 , 0.4967181 , ..., 0.5890883 , 0.4456083 ,\n", - " 0.7410833 ],\n", - " [0.7192894 , 0.47428977, 0.37764883, ..., 0.44555452, 0.7405691 ,\n", - " 0.79414505]], dtype=float32)
array([[0.33915406, 0.3045206 , 0.28648722, ..., 0.30625796, 0.3149696 ,\n", - " 0.3200055 ],\n", - " [0.34753686, 0.31381938, 0.2943569 , ..., 0.31429982, 0.31933692,\n", - " 0.32741007],\n", - " [0.35634485, 0.323293 , 0.3040018 , ..., 0.3186717 , 0.3267355 ,\n", - " 0.33097428],\n", - " ...,\n", - " [0.4101906 , 0.394241 , 0.3390931 , ..., 0.36369714, 0.3686615 ,\n", - " 0.369526 ],\n", - " [0.411724 , 0.34518054, 0.28593865, ..., 0.36847636, 0.36931273,\n", - " 0.37449035],\n", - " [0.36108527, 0.28604382, 0.23601256, ..., 0.36910504, 0.37427023,\n", - " 0.37555674]], dtype=float32)
array([[0.33712136, 0.33197249, 0.33133024, ..., 0.3487567 , 0.3522541 ,\n", - " 0.35496513],\n", - " [0.33966592, 0.32810577, 0.33189367, ..., 0.35206261, 0.35478147,\n", - " 0.35801704],\n", - " [0.33683602, 0.3314646 , 0.33128671, ..., 0.35459882, 0.35782855,\n", - " 0.3600166 ],\n", - " ...,\n", - " [0.37538716, 0.38870546, 0.3608304 , ..., 0.37574382, 0.37732049,\n", - " 0.37726724],\n", - " [0.37861461, 0.34732474, 0.33039895, ..., 0.37728526, 0.37722251,\n", - " 0.37900413],\n", - " [0.34116173, 0.31832921, 0.30483951, ..., 0.37718101, 0.37894148,\n", - " 0.3790149 ]])
array([[0.17700002, 0.06799999, 0.04300002, ..., 0.197 , 0.095 ,\n", - " 0.09 ],\n", - " [0.06799999, 0.04300002, 0.083 , ..., 0.095 , 0.09 ,\n", - " 0.04799998],\n", - " [0.04300002, 0.083 , 0.097 , ..., 0.09 , 0.04799998,\n", - " 0.065 ],\n", - " ...,\n", - " [0.387 , 0.12699997, 0.06400001, ..., 0.12099999, 0.16300002,\n", - " 0.407 ],\n", - " [0.12699997, 0.06400001, 0.083 , ..., 0.16300002, 0.407 ,\n", - " 0.083 ],\n", - " [0.06400001, 0.083 , 0.06900001, ..., 0.407 , 0.083 ,\n", - " 0.05900002]], dtype=float32)
-#
-#
-
-#
-# - [ ] TODO mike autocorrelation baseline
-# - [x] TODO mike acorn data
-# - [ ] TODO mike handle multiple houses. Multiindex
-
-# OPTIONAL: Load the "autoreload" extension so that code can change. But blacklist large modules
-# %load_ext autoreload
-# %autoreload 2
-# %aimport -pandas
-# %aimport -torch
-# %aimport -numpy
-# %aimport -matplotlib
-# %aimport -dask
-# %aimport -tqdm
-# %matplotlib inline
-
-# +
-# Imports
-import torch
-from torch import nn, optim
-from torch.nn import functional as F
-from torch.autograd import Variable
-import torch
-import torch.utils.data
-
-import pandas as pd
-import numpy as np
-import matplotlib.pyplot as plt
-plt.rcParams['figure.figsize'] = (12.0, 4.0)
-plt.style.use('ggplot')
-
-from pathlib import Path
-from tqdm.auto import tqdm
-
-import pytorch_lightning as pl
-# -
-
-from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets
-from seq2seq_time.predict import predict
-
-import logging, sys
-# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-
-# ## Parameters
-
-# +
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'using {device}')

# Column(s) the model is trained to predict.
columns_target = ['energy(kWh/hh)']
# Window lengths in timesteps: 48 half-hour steps per day * 4 days.
window_past = 48 * 4
window_future = 48 * 4
batch_size = 64
num_workers = 0
# Sampling interval. 'min' is the long-standing spelling of the minute alias;
# the single-letter 'T' form is deprecated in recent pandas releases.
freq = '30min'
# Row cap, kept as an int so it can be used directly as a row count.
max_rows = 100_000
-
-
-# -
-
-# ## Load data
-
-# +
-
def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london'), max_files=1):
    """
    Load the London smart-meter data and yield one DataFrame per household.

    Data loading and cleaning is always messy, so understanding this code is optional.

    This is a generator: no file is read until the first item is requested.

    Parameters
    ----------
    indir : Path
        Directory containing `halfhourly_dataset/*.csv`,
        `informations_households.csv`, `weather_hourly_darksky.csv` and
        `uk_bank_holidays.csv`.
    max_files : int
        How many of the half-hourly csv files (in sorted order) to load.

    Yields
    ------
    pandas.DataFrame
        The rows of a single `LCLid`, indexed by timestamp, holding the energy
        reading, household categories, calendar features, weather columns and
        a 0/1 `holiday` column.

    Notes
    -----
    Reads the module-level `freq` when resampling the weather data.
    """

    # Load csv files
    csv_files = sorted((indir/'halfhourly_dataset').glob('*.csv'))[:max_files]

    # Concatenate them; the literal string 'Null' is read as a missing value
    df = pd.concat([pd.read_csv(f, parse_dates=[1], na_values=['Null']) for f in csv_files])

    # Add ACORN categories
    df_households = pd.read_csv(indir/'informations_households.csv')
    df_households = df_households[['LCLid', 'stdorToU', 'Acorn_grouped']]
    df = pd.merge(df, df_households, on='LCLid')

    df = df.set_index('tstp')

    # Drop nan and 0's
    df = df[df['energy(kWh/hh)'] != 0]
    df = df.dropna()

    # Add time features
    time = df.index.to_series()
    df["month"] = time.dt.month
    df['day'] = time.dt.day
    # `Series.dt.week` was removed in pandas 2.0; isocalendar() provides the
    # same ISO week number. Cast its UInt32 result back to a plain integer.
    df['week'] = time.dt.isocalendar().week.astype('int64')
    df['hour'] = time.dt.hour
    df['minute'] = time.dt.minute
    df['dayofweek'] = time.dt.dayofweek

    # Load weather data
    df_weather = pd.read_csv(indir/'weather_hourly_darksky.csv', parse_dates=[3])
    use_cols = ['visibility', 'windBearing', 'temperature', 'time', 'dewPoint',
                'pressure', 'apparentTemperature', 'windSpeed',
                'humidity']
    df_weather = df_weather[use_cols].set_index('time')
    df_weather = df_weather.resample(freq).first().ffill()  # Resample to match energy data

    # Join weather and energy data
    df = pd.merge(df, df_weather, how='inner', left_index=True, right_index=True, sort=True)

    # Holidays: flag every row whose calendar day is a bank holiday (1) or not (0)
    df_hols = pd.read_csv(indir/'uk_bank_holidays.csv', parse_dates=[0])
    holiday_days = df_hols['Bank holidays'].dt.round('D')
    days = df.index.floor('D')
    df['holiday'] = days.isin(holiday_days).astype(int)

    # Loop over houses
    for _, df_h in df.groupby('LCLid'):
        yield df_h
-
-
-# -
# Our dataset is the London smart-meter data, sampled at half-hour intervals.
-
-# +
# `get_smartmeter_df` is a generator; materialise it so the houses can be sliced.
dfs = list(get_smartmeter_df())

# df = df.resample(freq).first().dropna() # Where empty we will backfill, this will respect causality, and mostly maintain the mean

# df = df.tail(int(max_rows)).copy() # Just use last X rows
# Stack the first 6 houses row-wise. `axis` is passed by keyword because
# pandas >= 2.0 only accepts `objs` positionally.
df = pd.concat(dfs[:6], axis=0)
# df = dfs[0]
# -

df.LCLid.unique()



df.describe()
-
-# +
import sklearn
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from sklearn_pandas import DataFrameMapper

# Numeric input columns get standardised; every remaining non-target column
# is treated as categorical and ordinal-encoded.
columns_input_numeric = list(df.drop(columns=columns_target)._get_numeric_data().columns)
# Keep the DataFrame's own column order. A set difference yields an ordering
# that can change between interpreter runs (string hashing is randomised),
# which would silently reorder the model's input features.
_non_categorical = set(columns_input_numeric) | set(columns_target)
columns_categorical = [c for c in df.columns if c not in _non_categorical]

output_scalers = [([n], StandardScaler()) for n in columns_target]
transformers = (
    output_scalers
    + [([n], StandardScaler()) for n in columns_input_numeric]
    + [([n], OrdinalEncoder()) for n in columns_categorical]
)
scaler = DataFrameMapper(transformers, df_out=True)
df_norm = scaler.fit_transform(df)
df_norm
# -

# Transformer registered for the target column; it is handed to `predict` later on.
output_scaler = next(filter(lambda r: r[0][0] in columns_target, scaler.features))[-1]
output_scaler

# One frame per house: resample to `freq` taking the first value of each bin,
# forward-fill the gaps, then drop rows that still contain NaN.
dfs_norm = [d.resample(freq).first().ffill().dropna() for _, d in df_norm.groupby('LCLid')]
len(dfs_norm)
-
-# +
# Split by house: the last ~20% of the houses are held out for testing.
# Each house keeps its whole time range, so this is not a split in time.
# Always hold out at least one house: with fewer than 5 houses
# int(0.2 * n) is 0, and slicing with -0 would leave the training set empty.
n_test = max(1, int(len(dfs_norm) * 0.2))
n_split = -n_test
df_train = dfs_norm[:n_split]
df_test = dfs_norm[n_split:]
assert df_train and df_test, "need at least two houses to build a train/test split"

# Show split
pd.concat(df_train)['energy(kWh/hh)'].plot(label='train')
pd.concat(df_test)['energy(kWh/hh)'].plot(label='test')
plt.ylabel('energy(kWh/hh)')
plt.legend()
-# -
-
-
-
-# ### Dataset
-
# Columns whose values we won't know in the future.
# They have to be blanked out in x_future.
columns_blank = [
    'visibility', 'windBearing', 'temperature', 'dewPoint', 'pressure',
    'apparentTemperature', 'windSpeed', 'humidity',
]

# Build the train and test datasets with identical windowing settings.
ds_train, ds_test = (
    Seq2SeqDataSets(
        frames,
        window_past=window_past,
        window_future=window_future,
        columns_blank=columns_blank,
    )
    for frames in (df_train, df_test)
)
print(ds_train)
print(ds_test)

# The dataset supports indexing and len(), like an array
ds_train[0]
len(ds_train)
ds_train[-1]
-
-# +
# Fetch the rows behind example number 10
x_past, y_past, x_future, y_future = ds_train.get_rows(10)

# Draw that single instance on one set of axes: this is what the model sees
y_past['energy(kWh/hh)'].plot(ax=plt.gca(), label='past')
y_future['energy(kWh/hh)'].plot(ax=plt.gca(), label='future')
plt.ylabel('energy(kWh/hh)')
plt.legend()

# Two extra columns have been added: tsp (time since present) and is_past
x_past.tail()
# -

# Some of the future columns are hidden, to prevent cheating
x_future.tail()
-
-
-# ## Model
-
-# +
-
class Seq2SeqLSTMDecoder(nn.Module):
    """
    LSTM encoder-decoder that outputs a Normal distribution per future step.

    The encoder LSTM reads the past features concatenated with the past
    targets. Its final hidden and cell states initialise the decoder LSTM,
    which reads the future features. Two linear heads turn every decoder
    output into a mean and a clamped log standard deviation.

    Parameters
    ----------
    input_size : int
        Number of features per past timestep (targets not included).
    input_size_decoder : int
        Number of features per future timestep.
    output_size : int
        Number of target variables.
    hidden_size : int
        Hidden size shared by both LSTMs.
    lstm_layers : int
        Number of stacked layers in both LSTMs.
    lstm_dropout : float
        Dropout value passed to both LSTMs.
    _min_std : float
        Smallest standard deviation the model may predict, in (0, 1].
        The largest is its reciprocal.

    Raises
    ------
    ValueError
        If `_min_std` is not in (0, 1].
    """

    def __init__(self, input_size, input_size_decoder, output_size, hidden_size=32, lstm_layers=2, lstm_dropout=0, _min_std=0.05):
        super().__init__()
        # log(_min_std) and -log(_min_std) are used as clamp bounds. A value
        # <= 0 makes them NaN/inf and a value > 1 swaps lower and upper bound,
        # so reject both up front.
        if not 0 < _min_std <= 1:
            raise ValueError(f"_min_std must be in (0, 1], got {_min_std!r}")
        self._min_std = _min_std

        self.encoder = nn.LSTM(
            input_size=input_size + output_size,
            hidden_size=hidden_size,
            batch_first=True,
            num_layers=lstm_layers,
            dropout=lstm_dropout,
        )
        self.decoder = nn.LSTM(
            input_size=input_size_decoder,
            hidden_size=hidden_size,
            batch_first=True,
            num_layers=lstm_layers,
            dropout=lstm_dropout,
        )
        self.mean = nn.Linear(hidden_size, output_size)
        self.std = nn.Linear(hidden_size, output_size)

    def forward(self, context_x, context_y, target_x, target_y=None):
        """
        Predict the distribution of the future targets.

        Parameters
        ----------
        context_x : Tensor [B, T_past, input_size]
            Past features.
        context_y : Tensor [B, T_past, output_size]
            Past targets.
        target_x : Tensor [B, T_future, input_size_decoder]
            Future features.
        target_y : optional
            Accepted for call compatibility; not used.

        Returns
        -------
        torch.distributions.Normal
            Mean and standard deviation of shape [B, T_future, output_size].
        """
        x = torch.cat([context_x, context_y], -1)
        _, (h_out, cell) = self.encoder(x)

        # outputs: [B, T_future, H]; the decoder starts from the encoder's final state
        outputs, _ = self.decoder(target_x, (h_out, cell))

        mean = self.mean(outputs)
        # Keep sigma within [_min_std, 1/_min_std] by clamping in log space
        log_std_min = float(np.log(self._min_std))
        log_sigma = torch.clamp(self.std(outputs), log_std_min, -log_std_min)

        sigma = torch.exp(log_sigma)
        return torch.distributions.Normal(mean, sigma)
-
-
-# -
-# ## Lightning
-
-# +
-import pytorch_lightning as pl
-
class PL_Seq2Seq(pl.LightningModule):
    """
    Lightning wrapper that trains a `Seq2SeqLSTMDecoder` on the negative
    log-likelihood of the future targets.

    Parameters
    ----------
    lr : float, keyword-only
        Learning rate for the Adam optimizer.
    **hparams
        Forwarded unchanged to `Seq2SeqLSTMDecoder`.
    """

    def __init__(self, *, lr=1e-4, **hparams):
        super().__init__()
        self.lr = lr
        self._model = Seq2SeqLSTMDecoder(**hparams)

    def forward(self, x_past, y_past, x_future, y_future=None):
        """Eval/Predict: return the predicted distribution. `y_future` is ignored."""
        return self._model(x_past, y_past, x_future)

    def _nll(self, batch):
        """Mean negative log-likelihood of `y_future` for one batch."""
        x_past, y_past, x_future, y_future = batch
        # Call the module rather than `.forward` so that module hooks run.
        y_dist = self(x_past, y_past, x_future)
        return -y_dist.log_prob(y_future).mean()

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss."""
        loss = self._nll(batch)
        self.log_dict({'loss/train': loss})
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss."""
        loss = self._nll(batch)
        self.log_dict({'loss/val': loss})
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
-
-
-# -
-
-from torch.utils.data import DataLoader, random_split
-from pytorch_lightning.loggers import CSVLogger
-from pl_bolts.callbacks import PrintTableMetricsCallback
-
-
-
-# +
# Feature counts are read from the example rows fetched earlier.
input_size = x_past.shape[-1]
output_size = y_future.shape[-1]

model = PL_Seq2Seq(input_size=input_size,
                   input_size_decoder=input_size,
                   output_size=output_size,
                   hidden_size=32,
                   lstm_layers=2,
                   lstm_dropout=0.25).to(device)

logger = CSVLogger("logs", name="seq2seq")
# Only ask for a GPU when one is available, so the notebook also runs on
# CPU-only machines instead of failing at Trainer construction.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0,
                     logger=logger)
dl_train = DataLoader(ds_train,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=8)
dl_test = DataLoader(ds_test, batch_size=batch_size, num_workers=4)
trainer.fit(model, dl_train, dl_test)
-# -
# Load the metrics file written by the CSV logger. Rows with a missing epoch
# take the previous row's value, then every metric is averaged per epoch.
df_hist = pd.read_csv(trainer.logger.experiment.metrics_file_path)
df_hist = df_hist.assign(epoch=df_hist['epoch'].ffill())
df_histe = df_hist.groupby('epoch').mean()
df_histe.loc[:, ['loss/train', 'loss/val']].plot()
df_histe

# ## Predict
#

# Run the model over the first test dataset only
ds_preds = predict(
    model.to(device),
    ds_test.datasets[0],
    batch_size,
    device=device,
    scaler=output_scaler,
)
ds_preds
-
-
-# +
-# TODO Metrics... smape etc
-
-# +
def plot_prediction(ds_preds, i):
    """
    Plot a prediction into the future, at a single point in time.

    Parameters
    ----------
    ds_preds
        Predictions as returned by `predict`; the variables `t_target`,
        `y_pred`, `y_pred_std`, `y_true`, `t_past`, `y_past`, `nll` and the
        `t_source` dimension are read from it.
    i : int
        Position along `t_source` of the prediction to plot.
    """
    d = ds_preds.isel(t_source=i)

    # Get arrays
    xf = d.t_target
    yp = d.y_pred
    s = d.y_pred_std
    yt = d.y_true
    now = pd.Timestamp(d.t_source.squeeze().values)

    plt.figure(figsize=(12, 4))

    plt.scatter(xf, yt, label='true', c='k', s=6)
    # Remember the y range chosen for the true future values; it is restored
    # after the remaining artists have been added.
    ylim = plt.ylim()

    # plot prediction
    plt.fill_between(xf, yp-2*s, yp+2*s, alpha=0.25,
                     facecolor="b",
                     interpolate=True,
                     label="2 std",)
    plt.plot(xf, yp, label='pred', c='b')

    # plot true (past values)
    plt.scatter(
        d.t_past,
        d.y_past,
        c='k',
        s=6
    )

    # plot a red line for now; axvline spans the full height of the axes
    # whatever the y limits are (vlines from 0 to 1 is in data units)
    plt.axvline(x=now, label='now', color='r')
    plt.ylim(*ylim)

    plt.title(f'Prediction NLL={d.nll.mean().item():2.2g}')
    plt.xlabel(f'{now.date()}')
    plt.ylabel('energy(kWh/hh)')
    plt.legend()
    plt.xticks(rotation=45)
    plt.show()

# plot_prediction(ds_preds, 0)
# plot_prediction(ds_preds, 12) # 6 hours later
plot_prediction(ds_preds, 24) # 12 hours later
plot_prediction(ds_preds, 48) # 24 hours later
-# -
-
-# ## Error vs time ahead
-
-
-
-# +
# Average over every prediction start, then show NLL against the horizon
(
    ds_preds
    .mean('t_source')
    .plot.scatter('t_ahead_hours', 'nll')
)

# Label the graph
n = len(ds_preds.t_source)
plt.xlabel('Hours ahead')
plt.ylabel('Negative Log Likelihood (lower is better)')
plt.title(f'NLL vs time (no. samples={n})')
# -



# NLL against the prediction start time. Does this solution get worse with time?
d = (
    ds_preds
    .mean('t_ahead')
    .groupby('t_source')
    .mean()
    .plot.scatter('t_source', 'nll')
)
plt.xticks(rotation=45)
plt.title('NLL over time (lower is better)')
1

# xarray makes a true-vs-predicted scatter plot a one-liner
ds_preds.plot.scatter('y_true', 'y_pred', s=.01)
-
-
-
-
-
-
-
-