diff --git a/Makefile b/Makefile
index 2170eac..3aade1f 100644
--- a/Makefile
+++ b/Makefile
@@ -51,7 +51,7 @@ lint:
## Set up python interpreter environment
create_environment:
@echo ">>> Detected conda, creating conda environment."
- conda create --name $(PROJECT_NAME) python=3
+ conda create --name $(PROJECT_NAME) python=3.7
@echo ">>> New conda env created. Activate with:\nsource activate $(PROJECT_NAME)"
## Test python environment is setup correctly
diff --git a/notebooks/RNN_Timeseries_Seq2Seq.ipynb b/notebooks/RNN_Timeseries_Seq2Seq.ipynb
index bd15109..32754b1 100644
--- a/notebooks/RNN_Timeseries_Seq2Seq.ipynb
+++ b/notebooks/RNN_Timeseries_Seq2Seq.ipynb
@@ -18,15 +18,17 @@
"- with rough uncertainty (uncalibrated)\n",
"- outputing sequence of predictions\n",
"\n",
- "This gets you close to real world applications\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
"\n",
- "Key learning objectives:\n",
- "- Understand the LSTM's maintain a memory along sequences\n",
- "- A rough overview of sequence to sequence approaches, and the data prep involved\n",
- "- Understand likelihood and outputting a distribution\n",
- "- Evalute a timeseries solution\n",
- "\n",
- "
"
+ "- [ ] TODO mike autocorrelation baseline\n",
+ "- [ ] TODO mike acorn data"
]
},
{
@@ -34,8 +36,31 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-16T04:36:15.235547Z",
- "start_time": "2020-10-16T04:36:14.522810Z"
+ "end_time": "2020-10-18T03:12:41.037540Z",
+ "start_time": "2020-10-18T03:12:40.520045Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# OPTIONAL: Load the \"autoreload\" extension so that code can change. But blacklist large modules\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "%aimport -pandas\n",
+ "%aimport -torch\n",
+ "%aimport -numpy\n",
+ "%aimport -matplotlib\n",
+ "%aimport -dask\n",
+ "%aimport -tqdm\n",
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-10-18T03:12:42.247775Z",
+ "start_time": "2020-10-18T03:12:41.040058Z"
}
},
"outputs": [],
@@ -53,7 +78,39 @@
"import matplotlib.pyplot as plt\n",
"\n",
"from pathlib import Path\n",
- "from tqdm.auto import tqdm"
+ "from tqdm.auto import tqdm\n",
+ "\n",
+ "import pytorch_lightning as pl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-10-18T03:12:42.334734Z",
+ "start_time": "2020-10-18T03:12:42.250766Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from seq2seq_time.data.dataset import Seq2SeqDataSet\n",
+ "from seq2seq_time.predict import predict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-10-18T03:12:42.363653Z",
+ "start_time": "2020-10-18T03:12:42.336896Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import logging, sys\n",
+ "logging.basicConfig(stream=sys.stdout, level=logging.INFO)"
]
},
{
@@ -70,11 +127,11 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 5,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-16T04:36:15.249766Z",
- "start_time": "2020-10-16T04:36:15.237079Z"
+ "end_time": "2020-10-18T03:12:42.426253Z",
+ "start_time": "2020-10-18T03:12:42.366725Z"
}
},
"outputs": [
@@ -90,13 +147,13 @@
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(f'using {device}')\n",
"\n",
+ "columns_target=['energy(kWh/hh)']\n",
"window_past = 48*4\n",
"window_future = 48*4\n",
"batch_size = 64\n",
"num_workers = 0\n",
"freq = '30T'\n",
- "max_rows = 1e5\n",
- "use_future = False # Cheat!"
+ "max_rows = 1e5"
]
},
{
@@ -108,18 +165,18 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 159,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-16T04:36:15.295873Z",
- "start_time": "2020-10-16T04:36:15.251348Z"
+ "end_time": "2020-10-18T05:10:01.335567Z",
+ "start_time": "2020-10-18T05:10:01.272903Z"
},
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"\n",
- "def get_smartmeter_df(indir=Path('../../data/processed/smartmeter')):\n",
+ "def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london')):\n",
" \"\"\"\n",
" Data loading and cleanding is always messy, so understand this code is optional.\n",
" \"\"\"\n",
@@ -131,9 +188,16 @@
" \n",
" # concatendate them\n",
" df = pd.concat([pd.read_csv(f, parse_dates=[1], na_values=['Null']) for f in csv_files])\n",
+ " \n",
+ " # Add ACORN categories\n",
+ " df_households = pd.read_csv(indir/'informations_households.csv')\n",
+ " df_households = df_households[['LCLid', 'stdorToU', 'Acorn_grouped']]\n",
+ " df = pd.merge(df, df_households, on='LCLid')\n",
"\n",
" # Take the mean over all houses\n",
- " df = df.groupby('tstp').mean()\n",
+ " name, df = next(iter(df.groupby('LCLid')))\n",
+ " df = df.set_index('tstp')\n",
+ " print(df)\n",
"\n",
" # Load weather data\n",
" df_weather = pd.read_csv(indir/'weather_hourly_darksky.csv', parse_dates=[3])\n",
@@ -154,6 +218,8 @@
" def is_holiday(dt):\n",
" return dt.floor('D') in holidays\n",
" df['holiday'] = time.apply(is_holiday).astype(int)\n",
+ " \n",
+ " # TODO pd.read_csv('../data/raw/smart-meters-in-london/acorn_details.csv', engine='python')\n",
"\n",
"\n",
" # Add time features \n",
@@ -174,19 +240,6 @@
" return df"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-10-11T01:19:12.463006Z",
- "start_time": "2020-10-11T01:19:12.458426Z"
- },
- "lines_to_next_cell": 2
- },
- "outputs": [],
- "source": []
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -196,14 +249,43 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 160,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-16T04:36:20.888572Z",
- "start_time": "2020-10-16T04:36:15.297303Z"
- }
+ "end_time": "2020-10-18T05:10:07.567408Z",
+ "start_time": "2020-10-18T05:10:01.929712Z"
+ },
+ "scrolled": true
},
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " LCLid energy(kWh/hh) stdorToU Acorn_grouped\n",
+ "tstp \n",
+ "2012-10-12 00:30:00 MAC000002 0.000 Std Affluent\n",
+ "2012-10-12 01:00:00 MAC000002 0.000 Std Affluent\n",
+ "2012-10-12 01:30:00 MAC000002 0.000 Std Affluent\n",
+ "2012-10-12 02:00:00 MAC000002 0.000 Std Affluent\n",
+ "2012-10-12 02:30:00 MAC000002 0.000 Std Affluent\n",
+ "... ... ... ... ...\n",
+ "2014-02-27 22:00:00 MAC000002 0.416 Std Affluent\n",
+ "2014-02-27 22:30:00 MAC000002 1.350 Std Affluent\n",
+ "2014-02-27 23:00:00 MAC000002 1.247 Std Affluent\n",
+ "2014-02-27 23:30:00 MAC000002 1.218 Std Affluent\n",
+ "2014-02-28 00:00:00 MAC000002 1.387 Std Affluent\n",
+ "\n",
+ "[24141 rows x 4 columns]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel_launcher.py:50: FutureWarning: Series.dt.weekofyear and Series.dt.week have been deprecated. Please use Series.dt.isocalendar().week instead.\n"
+ ]
+ },
{
"data": {
"text/html": [
@@ -225,7 +307,10 @@
" \n",
"
\n",
" \n",
" \n",
+ " LCLid \n",
" energy(kWh/hh) \n",
+ " stdorToU \n",
+ " Acorn_grouped \n",
" visibility \n",
" windBearing \n",
" temperature \n",
@@ -245,99 +330,114 @@
"
39247 rows × 16 columns
\n", + "24119 rows × 19 columns
\n", "" ], "text/plain": [ - " energy(kWh/hh) visibility windBearing temperature \\\n", - "2011-12-03 09:00:00 0.149000 13.07 262.0 11.00 \n", - "2011-12-03 09:30:00 0.154000 13.07 262.0 11.00 \n", - "2011-12-03 10:00:00 0.768000 12.76 268.0 11.42 \n", - "2011-12-03 10:30:00 1.179000 12.76 268.0 11.42 \n", - "2011-12-03 11:00:00 0.588000 13.07 274.0 11.41 \n", - "... ... ... ... ... \n", - "2014-02-27 22:00:00 0.678429 14.00 216.0 4.10 \n", - "2014-02-27 22:30:00 0.652071 14.00 216.0 4.10 \n", - "2014-02-27 23:00:00 0.545190 14.03 200.0 3.93 \n", - "2014-02-27 23:30:00 0.442571 14.03 200.0 3.93 \n", - "2014-02-28 00:00:00 0.328810 12.63 190.0 3.81 \n", + " LCLid energy(kWh/hh) stdorToU Acorn_grouped \\\n", + "2012-10-12 11:30:00 MAC000002 0.143 Std Affluent \n", + "2012-10-12 12:00:00 MAC000002 0.663 Std Affluent \n", + "2012-10-12 12:30:00 MAC000002 0.256 Std Affluent \n", + "2012-10-12 13:00:00 MAC000002 0.155 Std Affluent \n", + "2012-10-12 13:30:00 MAC000002 0.199 Std Affluent \n", + "... ... ... ... ... \n", + "2014-02-27 22:00:00 MAC000002 0.416 Std Affluent \n", + "2014-02-27 22:30:00 MAC000002 1.350 Std Affluent \n", + "2014-02-27 23:00:00 MAC000002 1.247 Std Affluent \n", + "2014-02-27 23:30:00 MAC000002 1.218 Std Affluent \n", + "2014-02-28 00:00:00 MAC000002 1.387 Std Affluent \n", "\n", - " dewPoint pressure apparentTemperature windSpeed \\\n", - "2011-12-03 09:00:00 8.84 1002.07 11.00 5.99 \n", - "2011-12-03 09:30:00 8.84 1002.07 11.00 5.99 \n", - "2011-12-03 10:00:00 7.52 1002.76 11.42 6.10 \n", - "2011-12-03 10:30:00 7.52 1002.76 11.42 6.10 \n", - "2011-12-03 11:00:00 6.39 1003.24 11.41 6.20 \n", - "... ... ... ... ... \n", - "2014-02-27 22:00:00 1.64 1005.67 1.41 3.02 \n", - "2014-02-27 22:30:00 1.64 1005.67 1.41 3.02 \n", - "2014-02-27 23:00:00 1.61 1004.62 1.42 2.75 \n", - "2014-02-27 23:30:00 1.61 1004.62 1.42 2.75 \n", - "2014-02-28 00:00:00 1.53 1003.57 1.47 2.53 \n", + " visibility windBearing temperature dewPoint pressure \\\n", + "2012-10-12 11:30:00 13.52 267.0 12.59 4.48 1007.25 \n", + "2012-10-12 12:00:00 13.52 260.0 13.86 4.32 1007.37 \n", + "2012-10-12 12:30:00 13.52 260.0 13.86 4.32 1007.37 \n", + "2012-10-12 13:00:00 13.52 253.0 13.99 4.48 1007.54 \n", + "2012-10-12 13:30:00 13.52 253.0 13.99 4.48 1007.54 \n", + "... ... ... ... ... ... \n", + "2014-02-27 22:00:00 14.00 216.0 4.10 1.64 1005.67 \n", + "2014-02-27 22:30:00 14.00 216.0 4.10 1.64 1005.67 \n", + "2014-02-27 23:00:00 14.03 200.0 3.93 1.61 1004.62 \n", + "2014-02-27 23:30:00 14.03 200.0 3.93 1.61 1004.62 \n", + "2014-02-28 00:00:00 12.63 190.0 3.81 1.53 1003.57 \n", "\n", - " humidity holiday month day week hour minute \\\n", - "2011-12-03 09:00:00 0.87 0 12 3 48 9 0 \n", - "2011-12-03 09:30:00 0.87 0 12 3 48 9 30 \n", - "2011-12-03 10:00:00 0.77 0 12 3 48 10 0 \n", - "2011-12-03 10:30:00 0.77 0 12 3 48 10 30 \n", - "2011-12-03 11:00:00 0.71 0 12 3 48 11 0 \n", - "... ... ... ... ... ... ... ... \n", - "2014-02-27 22:00:00 0.84 0 2 27 9 22 0 \n", - "2014-02-27 22:30:00 0.84 0 2 27 9 22 30 \n", - "2014-02-27 23:00:00 0.85 0 2 27 9 23 0 \n", - "2014-02-27 23:30:00 0.85 0 2 27 9 23 30 \n", - "2014-02-28 00:00:00 0.85 0 2 28 9 0 0 \n", + " apparentTemperature windSpeed humidity holiday month \\\n", + "2012-10-12 11:30:00 12.59 6.96 0.58 0 10 \n", + "2012-10-12 12:00:00 13.86 6.92 0.53 0 10 \n", + "2012-10-12 12:30:00 13.86 6.92 0.53 0 10 \n", + "2012-10-12 13:00:00 13.99 7.00 0.53 0 10 \n", + "2012-10-12 13:30:00 13.99 7.00 0.53 0 10 \n", + "... ... ... ... ... ... \n", + "2014-02-27 22:00:00 1.41 3.02 0.84 0 2 \n", + "2014-02-27 22:30:00 1.41 3.02 0.84 0 2 \n", + "2014-02-27 23:00:00 1.42 2.75 0.85 0 2 \n", + "2014-02-27 23:30:00 1.42 2.75 0.85 0 2 \n", + "2014-02-28 00:00:00 1.47 2.53 0.85 0 2 \n", "\n", - " dayofweek \n", - "2011-12-03 09:00:00 5 \n", - "2011-12-03 09:30:00 5 \n", - "2011-12-03 10:00:00 5 \n", - "2011-12-03 10:30:00 5 \n", - "2011-12-03 11:00:00 5 \n", - "... ... \n", - "2014-02-27 22:00:00 3 \n", - "2014-02-27 22:30:00 3 \n", - "2014-02-27 23:00:00 3 \n", - "2014-02-27 23:30:00 3 \n", - "2014-02-28 00:00:00 4 \n", + " day week hour minute dayofweek \n", + "2012-10-12 11:30:00 12 41 11 30 4 \n", + "2012-10-12 12:00:00 12 41 12 0 4 \n", + "2012-10-12 12:30:00 12 41 12 30 4 \n", + "2012-10-12 13:00:00 12 41 13 0 4 \n", + "2012-10-12 13:30:00 12 41 13 30 4 \n", + "... ... ... ... ... ... \n", + "2014-02-27 22:00:00 27 9 22 0 3 \n", + "2014-02-27 22:30:00 27 9 22 30 3 \n", + "2014-02-27 23:00:00 27 9 23 0 3 \n", + "2014-02-27 23:30:00 27 9 23 30 3 \n", + "2014-02-28 00:00:00 28 9 0 0 4 \n", "\n", - "[39247 rows x 16 columns]" + "[24119 rows x 19 columns]" ] }, - "execution_count": 4, + "execution_count": 160, "metadata": {}, "output_type": "execute_result" } @@ -522,19 +640,19 @@ "source": [ "df = get_smartmeter_df()\n", "\n", - "df = df.resample(freq).first().dropna() # Where empty we will backfill, this will respect causality, and mostly maintain the mean\n", + "# df = df.resample(freq).first().dropna() # Where empty we will backfill, this will respect causality, and mostly maintain the mean\n", "\n", - "df = df.tail(int(max_rows)) # Just use last X rows\n", + "df = df.tail(int(max_rows)).copy() # Just use last X rows\n", "df" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 161, "metadata": { "ExecuteTime": { - "end_time": "2020-10-16T04:36:20.962200Z", - "start_time": "2020-10-16T04:36:20.889920Z" + "end_time": "2020-10-18T05:10:07.699157Z", + "start_time": "2020-10-18T05:10:07.569250Z" } }, "outputs": [ @@ -580,68 +698,68 @@ " \n", "24119 rows × 19 columns
\n", "" ], "text/plain": [ - " energy(kWh/hh) visibility windBearing temperature dewPoint \\\n", - "count 3.924700e+04 3.924700e+04 3.924700e+04 3.924700e+04 3.924700e+04 \n", - "mean 1.216615e-16 -9.110126e-16 3.476042e-17 2.317361e-17 -1.738021e-17 \n", - "std 1.000013e+00 1.000013e+00 1.000013e+00 1.000013e+00 1.000013e+00 \n", - "min -1.946796e+00 -3.702976e+00 -2.173256e+00 -2.730140e+00 -3.206283e+00 \n", - "25% -7.137231e-01 -3.123495e-01 -7.750751e-01 -6.961120e-01 -7.429789e-01 \n", - "50% -5.893824e-02 3.429252e-01 2.377807e-01 -9.349948e-02 1.346874e-02 \n", - "75% 6.147169e-01 6.016747e-01 6.561341e-01 6.914159e-01 7.602183e-01 \n", - "max 6.995269e+00 1.613150e+00 1.779083e+00 3.690974e+00 2.585391e+00 \n", + " energy(kWh/hh) visibility windBearing temperature \\\n", + "2012-10-12 11:30:00 -0.444130 0.759528 0.800636 0.453439 \n", + "2012-10-12 12:00:00 1.660500 0.759528 0.724633 0.667233 \n", + "2012-10-12 12:30:00 0.013222 0.759528 0.724633 0.667233 \n", + "2012-10-12 13:00:00 -0.395562 0.759528 0.648631 0.689117 \n", + "2012-10-12 13:30:00 -0.217478 0.759528 0.648631 0.689117 \n", + "... ... ... ... ... \n", + "2014-02-27 22:00:00 0.660800 0.912217 0.246907 -0.975781 \n", + "2014-02-27 22:30:00 4.441040 0.912217 0.246907 -0.975781 \n", + "2014-02-27 23:00:00 4.024161 0.921760 0.073188 -1.004399 \n", + "2014-02-27 23:30:00 3.906788 0.921760 0.073188 -1.004399 \n", + "2014-02-28 00:00:00 4.590792 0.476418 -0.035387 -1.024600 \n", "\n", - " pressure apparentTemperature windSpeed humidity \\\n", - "count 3.924700e+04 3.924700e+04 3.924700e+04 3.924700e+04 \n", - "mean 5.086608e-15 -6.952083e-17 1.158681e-16 -8.110764e-17 \n", - "std 1.000013e+00 1.000013e+00 1.000013e+00 1.000013e+00 \n", - "min -3.338279e+00 -2.553322e+00 -1.904380e+00 -3.922565e+00 \n", - "25% -5.819161e-01 -7.782148e-01 -7.334959e-01 -5.691148e-01 \n", - "50% 4.476806e-02 6.995454e-02 -1.064117e-01 2.157353e-01 \n", - "75% 6.862901e-01 7.535537e-01 5.794616e-01 7.865354e-01 \n", - "max 2.560233e+00 3.255864e+00 5.326685e+00 1.571385e+00 \n", + " dewPoint pressure apparentTemperature windSpeed \\\n", + "2012-10-12 11:30:00 -0.305918 -0.458899 0.577361 1.406991 \n", + "2012-10-12 12:00:00 -0.336717 -0.448391 0.754344 1.387977 \n", + "2012-10-12 12:30:00 -0.336717 -0.448391 0.754344 1.387977 \n", + "2012-10-12 13:00:00 -0.305918 -0.433505 0.772460 1.426005 \n", + "2012-10-12 13:30:00 -0.305918 -0.433505 0.772460 1.426005 \n", + "... ... ... ... ... \n", + "2014-02-27 22:00:00 -0.852602 -0.597253 -0.980649 -0.465873 \n", + "2014-02-27 22:30:00 -0.852602 -0.597253 -0.980649 -0.465873 \n", + "2014-02-27 23:00:00 -0.858377 -0.689197 -0.979256 -0.594216 \n", + "2014-02-27 23:30:00 -0.858377 -0.689197 -0.979256 -0.594216 \n", + "2014-02-28 00:00:00 -0.873777 -0.781141 -0.972288 -0.698792 \n", "\n", - " holiday month day week hour \\\n", - "count 3.924700e+04 3.924700e+04 3.924700e+04 3.924700e+04 3.924700e+04 \n", - "mean -3.331207e-17 9.848785e-17 -1.303516e-16 5.793403e-17 -6.734831e-17 \n", - "std 1.000013e+00 1.000013e+00 1.000013e+00 1.000013e+00 1.000013e+00 \n", - "min -1.500332e-01 -1.457330e+00 -1.681079e+00 -1.549070e+00 -1.661862e+00 \n", - "25% -1.500332e-01 -9.128440e-01 -8.828511e-01 -9.237337e-01 -7.950405e-01 \n", - "50% -1.500332e-01 -9.611440e-02 2.940958e-02 -4.826216e-02 7.178048e-02 \n", - "75% -1.500332e-01 9.928584e-01 8.276377e-01 8.897430e-01 9.386015e-01 \n", - "max 6.665191e+00 1.537345e+00 1.739898e+00 1.640147e+00 1.660952e+00 \n", + " humidity holiday month day week \\\n", + "2012-10-12 11:30:00 -1.492695 -0.149604 0.857219 -0.445950 0.822749 \n", + "2012-10-12 12:00:00 -1.854084 -0.149604 0.857219 -0.445950 0.822749 \n", + "2012-10-12 12:30:00 -1.854084 -0.149604 0.857219 -0.445950 0.822749 \n", + "2012-10-12 13:00:00 -1.854084 -0.149604 0.857219 -0.445950 0.822749 \n", + "2012-10-12 13:30:00 -1.854084 -0.149604 0.857219 -0.445950 0.822749 \n", + "... ... ... ... ... ... \n", + "2014-02-27 22:00:00 0.386532 -0.149604 -1.203661 1.269408 -1.069273 \n", + "2014-02-27 22:30:00 0.386532 -0.149604 -1.203661 1.269408 -1.069273 \n", + "2014-02-27 23:00:00 0.458810 -0.149604 -1.203661 1.269408 -1.069273 \n", + "2014-02-27 23:30:00 0.458810 -0.149604 -1.203661 1.269408 -1.069273 \n", + "2014-02-28 00:00:00 0.458810 -0.149604 -1.203661 1.383766 -1.069273 \n", "\n", - " minute dayofweek \n", - "count 3.924700e+04 3.924700e+04 \n", - "mean -2.648162e-16 3.584668e-17 \n", - "std 1.000013e+00 1.000013e+00 \n", - "min -9.999745e-01 -1.498271e+00 \n", - "25% -9.999745e-01 -9.984948e-01 \n", - "50% -9.999745e-01 1.056932e-03 \n", - "75% 1.000025e+00 1.000609e+00 \n", - "max 1.000025e+00 1.500385e+00 " + " hour minute dayofweek Acorn_grouped stdorToU \\\n", + "2012-10-12 11:30:00 -0.072904 1.000041 0.499755 0.0 0.0 \n", + "2012-10-12 12:00:00 0.071557 -0.999959 0.499755 0.0 0.0 \n", + "2012-10-12 12:30:00 0.071557 1.000041 0.499755 0.0 0.0 \n", + "2012-10-12 13:00:00 0.216018 -0.999959 0.499755 0.0 0.0 \n", + "2012-10-12 13:30:00 0.216018 1.000041 0.499755 0.0 0.0 \n", + "... ... ... ... ... ... \n", + "2014-02-27 22:00:00 1.516170 -0.999959 0.000435 0.0 0.0 \n", + "2014-02-27 22:30:00 1.516170 1.000041 0.000435 0.0 0.0 \n", + "2014-02-27 23:00:00 1.660631 -0.999959 0.000435 0.0 0.0 \n", + "2014-02-27 23:30:00 1.660631 1.000041 0.000435 0.0 0.0 \n", + "2014-02-28 00:00:00 -1.661979 -0.999959 0.499755 0.0 0.0 \n", + "\n", + " LCLid \n", + "2012-10-12 11:30:00 0.0 \n", + "2012-10-12 12:00:00 0.0 \n", + "2012-10-12 12:30:00 0.0 \n", + "2012-10-12 13:00:00 0.0 \n", + "2012-10-12 13:30:00 0.0 \n", + "... ... \n", + "2014-02-27 22:00:00 0.0 \n", + "2014-02-27 22:30:00 0.0 \n", + "2014-02-27 23:00:00 0.0 \n", + "2014-02-27 23:30:00 0.0 \n", + "2014-02-28 00:00:00 0.0 \n", + "\n", + "[24119 rows x 19 columns]" ] }, - "execution_count": 6, + "execution_count": 162, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Normalise\n", - "from sklearn.preprocessing import StandardScaler\n", - "input_columns = df.columns[1:]\n", - "out_columns = df.columns[:1]\n", - "scaler_input = StandardScaler()\n", - "scaler_output = StandardScaler()\n", - "df_norm = df.copy()\n", - "df_norm[input_columns] = scaler_input.fit_transform(df[input_columns])\n", - "df_norm[out_columns] = scaler_output.fit_transform(df[out_columns])\n", - "df_norm.describe()" + "import sklearn\n", + "from sklearn.preprocessing import StandardScaler, OrdinalEncoder\n", + "from sklearn_pandas import DataFrameMapper\n", + "\n", + "columns_input_numeric = list(df.drop(columns=columns_target)._get_numeric_data().columns)\n", + "columns_categorical = list(set(df.columns)-set(columns_input_numeric)-set(columns_target))\n", + "\n", + "output_scalers = [([n], StandardScaler()) for n in columns_target]\n", + "transformers=output_scalers + \\\n", + "[([n], StandardScaler()) for n in columns_input_numeric] + \\\n", + "[([n], OrdinalEncoder()) for n in columns_categorical]\n", + "scaler = DataFrameMapper(transformers, df_out=True)\n", + "df_norm = scaler.fit_transform(df)\n", + "df_norm" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 163, "metadata": { "ExecuteTime": { - "end_time": "2020-10-16T04:36:21.414354Z", - "start_time": "2020-10-16T04:36:21.411122Z" + "end_time": "2020-10-18T05:10:07.883594Z", + "start_time": "2020-10-18T05:10:07.834002Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "StandardScaler()" + ] + }, + "execution_count": 163, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_scaler = next(filter(lambda r:r[0][0] in columns_target, mapper4.features))[-1]\n", + "output_scaler" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "metadata": { + "ExecuteTime": { + "end_time": "2020-10-18T05:10:07.940549Z", + "start_time": "2020-10-18T05:10:07.885654Z" } }, "outputs": [], "source": [ - "# Resample\n", - "df_norm = df_norm.fillna(0)" + "# # Resample\n", + "df_norm = df_norm.resample(freq).first().fillna(0)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 165, "metadata": { "ExecuteTime": { - "end_time": "2020-10-10T05:04:20.745565Z", - "start_time": "2020-10-10T05:04:20.741168Z" - } - }, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "ExecuteTime": { - "end_time": "2020-10-16T04:36:23.402331Z", - "start_time": "2020-10-16T04:36:21.416369Z" + "end_time": "2020-10-18T05:10:11.052548Z", + "start_time": "2020-10-18T05:10:07.942532Z" }, "lines_to_next_cell": 0 }, @@ -1112,16 +1349,16 @@ { "data": { "text/plain": [ - "| \n", + " | energy(kWh/hh) | \n", + "visibility | \n", + "windBearing | \n", + "temperature | \n", + "dewPoint | \n", + "pressure | \n", + "apparentTemperature | \n", + "windSpeed | \n", + "humidity | \n", + "holiday | \n", + "month | \n", + "day | \n", + "week | \n", + "hour | \n", + "minute | \n", + "dayofweek | \n", + "Acorn_grouped | \n", + "stdorToU | \n", + "LCLid | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2012-10-12 11:30:00 | \n", + "-0.444130 | \n", + "0.759528 | \n", + "0.800636 | \n", + "0.453439 | \n", + "-0.305918 | \n", + "-0.458899 | \n", + "0.577361 | \n", + "1.406991 | \n", + "-1.492695 | \n", + "-0.149604 | \n", + "0.857219 | \n", + "-0.445950 | \n", + "0.822749 | \n", + "-0.072904 | \n", + "1.000041 | \n", + "0.499755 | \n", + "0.0 | \n", + "0.0 | \n", + "0.0 | \n", + "
| 2012-10-12 12:00:00 | \n", + "1.660500 | \n", + "0.759528 | \n", + "0.724633 | \n", + "0.667233 | \n", + "-0.336717 | \n", + "-0.448391 | \n", + "0.754344 | \n", + "1.387977 | \n", + "-1.854084 | \n", + "-0.149604 | \n", + "0.857219 | \n", + "-0.445950 | \n", + "0.822749 | \n", + "0.071557 | \n", + "-0.999959 | \n", + "0.499755 | \n", + "0.0 | \n", + "0.0 | \n", + "0.0 | \n", + "
| 2012-10-12 12:30:00 | \n", + "0.013222 | \n", + "0.759528 | \n", + "0.724633 | \n", + "0.667233 | \n", + "-0.336717 | \n", + "-0.448391 | \n", + "0.754344 | \n", + "1.387977 | \n", + "-1.854084 | \n", + "-0.149604 | \n", + "0.857219 | \n", + "-0.445950 | \n", + "0.822749 | \n", + "0.071557 | \n", + "1.000041 | \n", + "0.499755 | \n", + "0.0 | \n", + "0.0 | \n", + "0.0 | \n", + "
| 2012-10-12 13:00:00 | \n", + "-0.395562 | \n", + "0.759528 | \n", + "0.648631 | \n", + "0.689117 | \n", + "-0.305918 | \n", + "-0.433505 | \n", + "0.772460 | \n", + "1.426005 | \n", + "-1.854084 | \n", + "-0.149604 | \n", + "0.857219 | \n", + "-0.445950 | \n", + "0.822749 | \n", + "0.216018 | \n", + "-0.999959 | \n", + "0.499755 | \n", + "0.0 | \n", + "0.0 | \n", + "0.0 | \n", + "
| 2012-10-12 13:30:00 | \n", + "-0.217478 | \n", + "0.759528 | \n", + "0.648631 | \n", + "0.689117 | \n", + "-0.305918 | \n", + "-0.433505 | \n", + "0.772460 | \n", + "1.426005 | \n", + "-1.854084 | \n", + "-0.149604 | \n", + "0.857219 | \n", + "-0.445950 | \n", + "0.822749 | \n", + "0.216018 | \n", + "1.000041 | \n", + "0.499755 | \n", + "0.0 | \n", + "0.0 | \n", + "0.0 | \n", + "
| ... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "
| 2014-02-27 22:00:00 | \n", + "0.660800 | \n", + "0.912217 | \n", + "0.246907 | \n", + "-0.975781 | \n", + "-0.852602 | \n", + "-0.597253 | \n", + "-0.980649 | \n", + "-0.465873 | \n", + "0.386532 | \n", + "-0.149604 | \n", + "-1.203661 | \n", + "1.269408 | \n", + "-1.069273 | \n", + "1.516170 | \n", + "-0.999959 | \n", + "0.000435 | \n", + "0.0 | \n", + "0.0 | \n", + "0.0 | \n", + "
| 2014-02-27 22:30:00 | \n", + "4.441040 | \n", + "0.912217 | \n", + "0.246907 | \n", + "-0.975781 | \n", + "-0.852602 | \n", + "-0.597253 | \n", + "-0.980649 | \n", + "-0.465873 | \n", + "0.386532 | \n", + "-0.149604 | \n", + "-1.203661 | \n", + "1.269408 | \n", + "-1.069273 | \n", + "1.516170 | \n", + "1.000041 | \n", + "0.000435 | \n", + "0.0 | \n", + "0.0 | \n", + "0.0 | \n", + "
| 2014-02-27 23:00:00 | \n", + "4.024161 | \n", + "0.921760 | \n", + "0.073188 | \n", + "-1.004399 | \n", + "-0.858377 | \n", + "-0.689197 | \n", + "-0.979256 | \n", + "-0.594216 | \n", + "0.458810 | \n", + "-0.149604 | \n", + "-1.203661 | \n", + "1.269408 | \n", + "-1.069273 | \n", + "1.660631 | \n", + "-0.999959 | \n", + "0.000435 | \n", + "0.0 | \n", + "0.0 | \n", + "0.0 | \n", + "
| 2014-02-27 23:30:00 | \n", + "3.906788 | \n", + "0.921760 | \n", + "0.073188 | \n", + "-1.004399 | \n", + "-0.858377 | \n", + "-0.689197 | \n", + "-0.979256 | \n", + "-0.594216 | \n", + "0.458810 | \n", + "-0.149604 | \n", + "-1.203661 | \n", + "1.269408 | \n", + "-1.069273 | \n", + "1.660631 | \n", + "1.000041 | \n", + "0.000435 | \n", + "0.0 | \n", + "0.0 | \n", + "0.0 | \n", + "
| 2014-02-28 00:00:00 | \n", + "4.590792 | \n", + "0.476418 | \n", + "-0.035387 | \n", + "-1.024600 | \n", + "-0.873777 | \n", + "-0.781141 | \n", + "-0.972288 | \n", + "-0.698792 | \n", + "0.458810 | \n", + "-0.149604 | \n", + "-1.203661 | \n", + "1.383766 | \n", + "-1.069273 | \n", + "-1.661979 | \n", + "-0.999959 | \n", + "0.499755 | \n", + "0.0 | \n", + "0.0 | \n", + "0.0 | \n", + "
24170 rows × 19 columns
\n", + "<xarray.Dataset>\n", - "Dimensions: (t_ahead: 192, t_behind: 192, t_source: 7465)\n", - "Coordinates:\n", - " * t_source (t_source) datetime64[ns] 2013-09-17T12:00:00 ... 2013-09-...\n", - " * t_ahead (t_ahead) timedelta64[ns] 00:00:00 ... 3 days 23:30:00\n", - " * t_behind (t_behind) timedelta64[ns] -4 days +00:00:00 ... -1 days +...\n", - " t_target (t_source, t_ahead) datetime64[ns] 2013-09-17T12:00:00 ......\n", - " t_past (t_source, t_behind) datetime64[ns] 2013-09-13T12:00:00 .....\n", - " t_ahead_hours (t_ahead) float64 0.0 0.0 1.0 1.0 2.0 ... 94.0 94.0 95.0 95.0\n", - "Data variables:\n", - " y_past (t_source, t_behind) float32 0.34182978 ... 0.42559525\n", - " nll (t_source, t_ahead) float32 -0.39930427 ... -0.5875193\n", - " y_pred (t_source, t_ahead) float32 0.4066002 ... 0.43110877\n", - " y_pred_std (t_source, t_ahead) float64 0.0581 0.05431 ... 0.04713\n", - " y_true (t_source, t_ahead) float32 0.39902174 ... 0.44257143
array(['2013-09-17T12:00:00.000000000', '2013-09-17T12:30:00.000000000',\n", - " '2013-09-17T13:00:00.000000000', ..., '2013-09-21T08:30:00.000000000',\n", - " '2013-09-21T09:00:00.000000000', '2013-09-21T09:30:00.000000000'],\n", - " dtype='datetime64[ns]')
array([ 0, 1800000000000, 3600000000000, 5400000000000,\n", - " 7200000000000, 9000000000000, 10800000000000, 12600000000000,\n", - " 14400000000000, 16200000000000, 18000000000000, 19800000000000,\n", - " 21600000000000, 23400000000000, 25200000000000, 27000000000000,\n", - " 28800000000000, 30600000000000, 32400000000000, 34200000000000,\n", - " 36000000000000, 37800000000000, 39600000000000, 41400000000000,\n", - " 43200000000000, 45000000000000, 46800000000000, 48600000000000,\n", - " 50400000000000, 52200000000000, 54000000000000, 55800000000000,\n", - " 57600000000000, 59400000000000, 61200000000000, 63000000000000,\n", - " 64800000000000, 66600000000000, 68400000000000, 70200000000000,\n", - " 72000000000000, 73800000000000, 75600000000000, 77400000000000,\n", - " 79200000000000, 81000000000000, 82800000000000, 84600000000000,\n", - " 86400000000000, 88200000000000, 90000000000000, 91800000000000,\n", - " 93600000000000, 95400000000000, 97200000000000, 99000000000000,\n", - " 100800000000000, 102600000000000, 104400000000000, 106200000000000,\n", - " 108000000000000, 109800000000000, 111600000000000, 113400000000000,\n", - " 115200000000000, 117000000000000, 118800000000000, 120600000000000,\n", - " 122400000000000, 124200000000000, 126000000000000, 127800000000000,\n", - " 129600000000000, 131400000000000, 133200000000000, 135000000000000,\n", - " 136800000000000, 138600000000000, 140400000000000, 142200000000000,\n", - " 144000000000000, 145800000000000, 147600000000000, 149400000000000,\n", - " 151200000000000, 153000000000000, 154800000000000, 156600000000000,\n", - " 158400000000000, 160200000000000, 162000000000000, 163800000000000,\n", - " 165600000000000, 167400000000000, 169200000000000, 171000000000000,\n", - " 172800000000000, 174600000000000, 176400000000000, 178200000000000,\n", - " 180000000000000, 181800000000000, 183600000000000, 185400000000000,\n", - " 187200000000000, 189000000000000, 190800000000000, 192600000000000,\n", - " 194400000000000, 196200000000000, 198000000000000, 199800000000000,\n", - " 201600000000000, 203400000000000, 205200000000000, 207000000000000,\n", - " 208800000000000, 210600000000000, 212400000000000, 214200000000000,\n", - " 216000000000000, 217800000000000, 219600000000000, 221400000000000,\n", - " 223200000000000, 225000000000000, 226800000000000, 228600000000000,\n", - " 230400000000000, 232200000000000, 234000000000000, 235800000000000,\n", - " 237600000000000, 239400000000000, 241200000000000, 243000000000000,\n", - " 244800000000000, 246600000000000, 248400000000000, 250200000000000,\n", - " 252000000000000, 253800000000000, 255600000000000, 257400000000000,\n", - " 259200000000000, 261000000000000, 262800000000000, 264600000000000,\n", - " 266400000000000, 268200000000000, 270000000000000, 271800000000000,\n", - " 273600000000000, 275400000000000, 277200000000000, 279000000000000,\n", - " 280800000000000, 282600000000000, 284400000000000, 286200000000000,\n", - " 288000000000000, 289800000000000, 291600000000000, 293400000000000,\n", - " 295200000000000, 297000000000000, 298800000000000, 300600000000000,\n", - " 302400000000000, 304200000000000, 306000000000000, 307800000000000,\n", - " 309600000000000, 311400000000000, 313200000000000, 315000000000000,\n", - " 316800000000000, 318600000000000, 320400000000000, 322200000000000,\n", - " 324000000000000, 325800000000000, 327600000000000, 329400000000000,\n", - " 331200000000000, 333000000000000, 334800000000000, 336600000000000,\n", - " 338400000000000, 340200000000000, 342000000000000, 343800000000000],\n", - " dtype='timedelta64[ns]')
array([-345600000000000, -343800000000000, -342000000000000, -340200000000000,\n", - " -338400000000000, -336600000000000, -334800000000000, -333000000000000,\n", - " -331200000000000, -329400000000000, -327600000000000, -325800000000000,\n", - " -324000000000000, -322200000000000, -320400000000000, -318600000000000,\n", - " -316800000000000, -315000000000000, -313200000000000, -311400000000000,\n", - " -309600000000000, -307800000000000, -306000000000000, -304200000000000,\n", - " -302400000000000, -300600000000000, -298800000000000, -297000000000000,\n", - " -295200000000000, -293400000000000, -291600000000000, -289800000000000,\n", - " -288000000000000, -286200000000000, -284400000000000, -282600000000000,\n", - " -280800000000000, -279000000000000, -277200000000000, -275400000000000,\n", - " -273600000000000, -271800000000000, -270000000000000, -268200000000000,\n", - " -266400000000000, -264600000000000, -262800000000000, -261000000000000,\n", - " -259200000000000, -257400000000000, -255600000000000, -253800000000000,\n", - " -252000000000000, -250200000000000, -248400000000000, -246600000000000,\n", - " -244800000000000, -243000000000000, -241200000000000, -239400000000000,\n", - " -237600000000000, -235800000000000, -234000000000000, -232200000000000,\n", - " -230400000000000, -228600000000000, -226800000000000, -225000000000000,\n", - " -223200000000000, -221400000000000, -219600000000000, -217800000000000,\n", - " -216000000000000, -214200000000000, -212400000000000, -210600000000000,\n", - " -208800000000000, -207000000000000, -205200000000000, -203400000000000,\n", - " -201600000000000, -199800000000000, -198000000000000, -196200000000000,\n", - " -194400000000000, -192600000000000, -190800000000000, -189000000000000,\n", - " -187200000000000, -185400000000000, -183600000000000, -181800000000000,\n", - " -180000000000000, -178200000000000, -176400000000000, -174600000000000,\n", - " -172800000000000, -171000000000000, -169200000000000, -167400000000000,\n", - " -165600000000000, -163800000000000, -162000000000000, -160200000000000,\n", - " -158400000000000, -156600000000000, -154800000000000, -153000000000000,\n", - " -151200000000000, -149400000000000, -147600000000000, -145800000000000,\n", - " -144000000000000, -142200000000000, -140400000000000, -138600000000000,\n", - " -136800000000000, -135000000000000, -133200000000000, -131400000000000,\n", - " -129600000000000, -127800000000000, -126000000000000, -124200000000000,\n", - " -122400000000000, -120600000000000, -118800000000000, -117000000000000,\n", - " -115200000000000, -113400000000000, -111600000000000, -109800000000000,\n", - " -108000000000000, -106200000000000, -104400000000000, -102600000000000,\n", - " -100800000000000, -99000000000000, -97200000000000, -95400000000000,\n", - " -93600000000000, -91800000000000, -90000000000000, -88200000000000,\n", - " -86400000000000, -84600000000000, -82800000000000, -81000000000000,\n", - " -79200000000000, -77400000000000, -75600000000000, -73800000000000,\n", - " -72000000000000, -70200000000000, -68400000000000, -66600000000000,\n", - " -64800000000000, -63000000000000, -61200000000000, -59400000000000,\n", - " -57600000000000, -55800000000000, -54000000000000, -52200000000000,\n", - " -50400000000000, -48600000000000, -46800000000000, -45000000000000,\n", - " -43200000000000, -41400000000000, -39600000000000, -37800000000000,\n", - " -36000000000000, -34200000000000, -32400000000000, -30600000000000,\n", - " -28800000000000, -27000000000000, -25200000000000, -23400000000000,\n", - " -21600000000000, -19800000000000, -18000000000000, -16200000000000,\n", - " -14400000000000, -12600000000000, -10800000000000, -9000000000000,\n", - " -7200000000000, -5400000000000, -3600000000000, -1800000000000],\n", - " dtype='timedelta64[ns]')
array([['2013-09-17T12:00:00.000000000', '2013-09-17T12:30:00.000000000',\n", - " '2013-09-17T13:00:00.000000000', ...,\n", - " '2013-09-21T10:30:00.000000000', '2013-09-21T11:00:00.000000000',\n", - " '2013-09-21T11:30:00.000000000'],\n", - " ['2013-09-17T12:30:00.000000000', '2013-09-17T13:00:00.000000000',\n", - " '2013-09-17T13:30:00.000000000', ...,\n", - " '2013-09-21T11:00:00.000000000', '2013-09-21T11:30:00.000000000',\n", - " '2013-09-21T12:00:00.000000000'],\n", - " ['2013-09-17T13:00:00.000000000', '2013-09-17T13:30:00.000000000',\n", - " '2013-09-17T14:00:00.000000000', ...,\n", - " '2013-09-21T11:30:00.000000000', '2013-09-21T12:00:00.000000000',\n", - " '2013-09-21T12:30:00.000000000'],\n", - " ...,\n", - " ['2013-09-21T08:30:00.000000000', '2013-09-21T09:00:00.000000000',\n", - " '2013-09-21T09:30:00.000000000', ...,\n", - " '2013-09-25T07:00:00.000000000', '2013-09-25T07:30:00.000000000',\n", - " '2013-09-25T08:00:00.000000000'],\n", - " ['2013-09-21T09:00:00.000000000', '2013-09-21T09:30:00.000000000',\n", - " '2013-09-21T10:00:00.000000000', ...,\n", - " '2013-09-25T07:30:00.000000000', '2013-09-25T08:00:00.000000000',\n", - " '2013-09-25T08:30:00.000000000'],\n", - " ['2013-09-21T09:30:00.000000000', '2013-09-21T10:00:00.000000000',\n", - " '2013-09-21T10:30:00.000000000', ...,\n", - " '2013-09-25T08:00:00.000000000', '2013-09-25T08:30:00.000000000',\n", - " '2013-09-25T09:00:00.000000000']], dtype='datetime64[ns]')
array([['2013-09-13T12:00:00.000000000', '2013-09-13T12:30:00.000000000',\n", - " '2013-09-13T13:00:00.000000000', ...,\n", - " '2013-09-17T10:30:00.000000000', '2013-09-17T11:00:00.000000000',\n", - " '2013-09-17T11:30:00.000000000'],\n", - " ['2013-09-13T12:30:00.000000000', '2013-09-13T13:00:00.000000000',\n", - " '2013-09-13T13:30:00.000000000', ...,\n", - " '2013-09-17T11:00:00.000000000', '2013-09-17T11:30:00.000000000',\n", - " '2013-09-17T12:00:00.000000000'],\n", - " ['2013-09-13T13:00:00.000000000', '2013-09-13T13:30:00.000000000',\n", - " '2013-09-13T14:00:00.000000000', ...,\n", - " '2013-09-17T11:30:00.000000000', '2013-09-17T12:00:00.000000000',\n", - " '2013-09-17T12:30:00.000000000'],\n", - " ...,\n", - " ['2013-09-17T08:30:00.000000000', '2013-09-17T09:00:00.000000000',\n", - " '2013-09-17T09:30:00.000000000', ...,\n", - " '2013-09-21T07:00:00.000000000', '2013-09-21T07:30:00.000000000',\n", - " '2013-09-21T08:00:00.000000000'],\n", - " ['2013-09-17T09:00:00.000000000', '2013-09-17T09:30:00.000000000',\n", - " '2013-09-17T10:00:00.000000000', ...,\n", - " '2013-09-21T07:30:00.000000000', '2013-09-21T08:00:00.000000000',\n", - " '2013-09-21T08:30:00.000000000'],\n", - " ['2013-09-17T09:30:00.000000000', '2013-09-17T10:00:00.000000000',\n", - " '2013-09-17T10:30:00.000000000', ...,\n", - " '2013-09-21T08:00:00.000000000', '2013-09-21T08:30:00.000000000',\n", - " '2013-09-21T09:00:00.000000000']], dtype='datetime64[ns]')
array([ 0., 0., 1., 1., 2., 2., 3., 3., 4., 4., 5., 5., 6.,\n", - " 6., 7., 7., 8., 8., 9., 9., 10., 10., 11., 11., 12., 12.,\n", - " 13., 13., 14., 14., 15., 15., 16., 16., 17., 17., 18., 18., 19.,\n", - " 19., 20., 20., 21., 21., 22., 22., 23., 23., 24., 24., 25., 25.,\n", - " 26., 26., 27., 27., 28., 28., 29., 29., 30., 30., 31., 31., 32.,\n", - " 32., 33., 33., 34., 34., 35., 35., 36., 36., 37., 37., 38., 38.,\n", - " 39., 39., 40., 40., 41., 41., 42., 42., 43., 43., 44., 44., 45.,\n", - " 45., 46., 46., 47., 47., 48., 48., 49., 49., 50., 50., 51., 51.,\n", - " 52., 52., 53., 53., 54., 54., 55., 55., 56., 56., 57., 57., 58.,\n", - " 58., 59., 59., 60., 60., 61., 61., 62., 62., 63., 63., 64., 64.,\n", - " 65., 65., 66., 66., 67., 67., 68., 68., 69., 69., 70., 70., 71.,\n", - " 71., 72., 72., 73., 73., 74., 74., 75., 75., 76., 76., 77., 77.,\n", - " 78., 78., 79., 79., 80., 80., 81., 81., 82., 82., 83., 83., 84.,\n", - " 84., 85., 85., 86., 86., 87., 87., 88., 88., 89., 89., 90., 90.,\n", - " 91., 91., 92., 92., 93., 93., 94., 94., 95., 95.])
array([[0.34182978, 0.3743617 , 0.35202128, ..., 0.45023912, 0.37195653,\n", - " 0.41715217],\n", - " [0.3743617 , 0.35202128, 0.341 , ..., 0.37195653, 0.41715217,\n", - " 0.39902174],\n", - " [0.35202128, 0.341 , 0.34478724, ..., 0.41715217, 0.39902174,\n", - " 0.3688913 ],\n", - " ...,\n", - " [0.48439535, 0.4035349 , 0.2892093 , ..., 0.63783336, 0.7064286 ,\n", - " 0.5955476 ],\n", - " [0.4035349 , 0.2892093 , 0.2788372 , ..., 0.7064286 , 0.5955476 ,\n", - " 0.44709525],\n", - " [0.2892093 , 0.2788372 , 0.28755814, ..., 0.5955476 , 0.44709525,\n", - " 0.42559525]], dtype=float32)
array([[-0.39930427, -0.4066378 , -0.5434241 , ..., -0.70005894,\n", - " 0.34137225, 0.25325173],\n", - " [-0.30783594, -0.46890533, -0.53692436, ..., 0.3190887 ,\n", - " 0.23320925, -0.26906586],\n", - " [-0.4261266 , -0.49373317, 1.63804 , ..., 0.21017814,\n", - " -0.29087424, 0.7156215 ],\n", - " ...,\n", - " [ 1.1717722 , 1.0530639 , 0.2793969 , ..., -0.45947552,\n", - " -0.11476445, 0.7189827 ],\n", - " [ 0.9254643 , 0.27463275, -0.58350575, ..., -0.09927058,\n", - " 0.7557026 , -0.08390903],\n", - " [ 0.2534623 , -0.55001223, -0.80816877, ..., 0.7782191 ,\n", - " -0.0669713 , -0.5875193 ]], dtype=float32)
array([[0.4066002 , 0.38901886, 0.3836266 , ..., 0.35888395, 0.35464624,\n", - " 0.35680822],\n", - " [0.3938792 , 0.3774008 , 0.37440446, ..., 0.3537852 , 0.35594296,\n", - " 0.35910505],\n", - " [0.37239552, 0.3650774 , 0.36374143, ..., 0.35499245, 0.35813135,\n", - " 0.36522728],\n", - " ...,\n", - " [0.60592777, 0.5684328 , 0.45174578, ..., 0.67859375, 0.63273644,\n", - " 0.56903 ],\n", - " [0.55782264, 0.45016435, 0.35072467, ..., 0.6316794 , 0.5678074 ,\n", - " 0.49656454],\n", - " [0.44697702, 0.34681633, 0.30443627, ..., 0.5670123 , 0.49567437,\n", - " 0.43110877]], dtype=float32)
array([[0.05809981, 0.05430707, 0.05057289, ..., 0.03990596, 0.04132949,\n", - " 0.04317982],\n", - " [0.05863498, 0.05367276, 0.05093383, ..., 0.04102965, 0.04281238,\n", - " 0.04410928],\n", - " [0.05488988, 0.05166907, 0.04870024, ..., 0.04244705, 0.04368677,\n", - " 0.04446527],\n", - " ...,\n", - " [0.12378779, 0.1136273 , 0.07299946, ..., 0.05473997, 0.05548793,\n", - " 0.05332477],\n", - " [0.11952963, 0.07925848, 0.04873188, ..., 0.05541608, 0.053298 ,\n", - " 0.05082954],\n", - " [0.08220471, 0.05017719, 0.03436997, ..., 0.05333869, 0.05086022,\n", - " 0.04712841]])
array([[0.39902174, 0.3688913 , 0.38763043, ..., 0.34258696, 0.29363042,\n", - " 0.29704347],\n", - " [0.3688913 , 0.38763043, 0.37802175, ..., 0.29363042, 0.29704347,\n", - " 0.31895652],\n", - " [0.38763043, 0.37802175, 0.46641305, ..., 0.29704347, 0.31895652,\n", - " 0.29106522],\n", - " ...,\n", - " [0.44709525, 0.42559525, 0.38180953, ..., 0.6854762 , 0.6784286 ,\n", - " 0.6520714 ],\n", - " [0.42559525, 0.38180953, 0.35152382, ..., 0.6784286 , 0.6520714 ,\n", - " 0.54519045],\n", - " [0.38180953, 0.35152382, 0.32159525, ..., 0.6520714 , 0.54519045,\n", - " 0.44257143]], dtype=float32)
+#
+#
+
+#
+# - [ ] TODO mike autocorrelation baseline
+# - [ ] TODO mike acorn data
+
+# OPTIONAL: Load the "autoreload" extension so that code can change. But blacklist large modules
+# %load_ext autoreload
+# %autoreload 2
+# %aimport -pandas
+# %aimport -torch
+# %aimport -numpy
+# %aimport -matplotlib
+# %aimport -dask
+# %aimport -tqdm
+# %matplotlib inline
+
+# +
+# Imports
+import torch
+from torch import nn, optim
+from torch.nn import functional as F
+from torch.autograd import Variable
+import torch
+import torch.utils.data
+
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+
+from pathlib import Path
+from tqdm.auto import tqdm
+
+import pytorch_lightning as pl
+# -
+
+from seq2seq_time.data.dataset import Seq2SeqDataSet
+from seq2seq_time.predict import predict
+
+import logging, sys
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+# ## Parameters
+
+# +
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f'using {device}')
+
+columns_target=['energy(kWh/hh)']
+window_past = 48*4
+window_future = 48*4
+batch_size = 64
+num_workers = 0
+freq = '30T'
+max_rows = 1e5
+
+
+# -
+
+# ## Load data
+
+# +
+
+def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london')):
+ """
+ Data loading and cleanding is always messy, so understand this code is optional.
+ """
+
+ # Load csv files
+ csv_files = sorted((indir/'halfhourly_dataset').glob('*.csv'))[:1]
+
+# import pdb; pdb.set_trace() # you can use debugging in jupyter to interact with variables inside a function
+
+ # concatendate them
+ df = pd.concat([pd.read_csv(f, parse_dates=[1], na_values=['Null']) for f in csv_files])
+
+ # Add ACORN categories
+ df_households = pd.read_csv(indir/'informations_households.csv')
+ df_households = df_households[['LCLid', 'stdorToU', 'Acorn_grouped']]
+ df = pd.merge(df, df_households, on='LCLid')
+
+ # Take the mean over all houses
+ name, df = next(iter(df.groupby('LCLid')))
+ df = df.set_index('tstp')
+ print(df)
+
+ # Load weather data
+ df_weather = pd.read_csv(indir/'weather_hourly_darksky.csv', parse_dates=[3])
+ use_cols = ['visibility', 'windBearing', 'temperature', 'time', 'dewPoint',
+ 'pressure', 'apparentTemperature', 'windSpeed',
+ 'humidity']
+ df_weather = df_weather[use_cols].set_index('time')
+ df_weather = df_weather.resample(freq).first().ffill() # Resample to match energy data
+
+ # Join weather and energy data
+ df = pd.concat([df, df_weather], 1).dropna()
+
+ # Also find bank holidays
+ df_hols = pd.read_csv(indir/'uk_bank_holidays.csv', parse_dates=[0])
+ holidays = set(df_hols['Bank holidays'].dt.round('D'))
+
+ time = df.index.to_series()
+ def is_holiday(dt):
+ return dt.floor('D') in holidays
+ df['holiday'] = time.apply(is_holiday).astype(int)
+
+ # TODO pd.read_csv('../data/raw/smart-meters-in-london/acorn_details.csv', engine='python')
+
+
+ # Add time features
+ df["month"] = time.dt.month
+ df['day'] = time.dt.day
+ df['week'] = time.dt.week
+ df['hour'] = time.dt.hour
+ df['minute'] = time.dt.minute
+ df['dayofweek'] = time.dt.dayofweek
+
+ # Drop nan and 0's
+ df = df[df['energy(kWh/hh)']!=0]
+ df = df.dropna()
+
+ # sort by time
+ df = df.sort_index()
+
+ return df
+# -
+# Our dataset is the london smartmeter data. But at half hour intervals
+
+# +
+df = get_smartmeter_df()
+
+# df = df.resample(freq).first().dropna() # Where empty we will backfill, this will respect causality, and mostly maintain the mean
+
+df = df.tail(int(max_rows)).copy() # Just use last X rows
+df
+# -
+
+df.describe()
+
+# +
+import sklearn
+from sklearn.preprocessing import StandardScaler, OrdinalEncoder
+from sklearn_pandas import DataFrameMapper
+
+columns_input_numeric = list(df.drop(columns=columns_target)._get_numeric_data().columns)
+columns_categorical = list(set(df.columns)-set(columns_input_numeric)-set(columns_target))
+
+output_scalers = [([n], StandardScaler()) for n in columns_target]
+transformers=output_scalers + \
+[([n], StandardScaler()) for n in columns_input_numeric] + \
+[([n], OrdinalEncoder()) for n in columns_categorical]
+scaler = DataFrameMapper(transformers, df_out=True)
+df_norm = scaler.fit_transform(df)
+df_norm
+# -
+
+output_scaler = next(filter(lambda r:r[0][0] in columns_target, mapper4.features))[-1]
+output_scaler
+
+# # Resample
+df_norm = df_norm.resample(freq).first().fillna(0)
+
+# +
+# split data, with the test in the future
+n_split = -int(len(df)*0.2)
+df_train = df_norm[:n_split]
+df_test = df_norm[n_split:]
+
+# Show split
+df_train['energy(kWh/hh)'].plot(label='train')
+df_test['energy(kWh/hh)'].plot(label='test')
+plt.ylabel('energy(kWh/hh)')
+plt.legend()
+# -
+df_norm
+
+
+columns_blank=['visibility',
+ 'windBearing', 'temperature', 'dewPoint', 'pressure',
+ 'apparentTemperature', 'windSpeed', 'humidity']
+
+ds_train = Seq2SeqDataSet(df_train,
+ window_past=window_past,
+ window_future=window_future,
+ columns_blank=columns_blank)
+ds_test = Seq2SeqDataSet(df_test,
+ window_past=window_past,
+ window_future=window_future,
+ columns_blank=columns_blank)
+print(ds_train)
+print(ds_test)
+
+# %%timeit
+for i in range(100):
+ ds_train[i]
+
+# we can treat it like an array
+ds_train[0]
+len(ds_train)
+ds_train[0][2][-2]
+
+# +
+# We can get rows
+x_past, y_past, x_future, y_future = ds_train.get_rows(10)
+
+# Plot one instance, this is what the model sees
+y_past['energy(kWh/hh)'].plot(label='past')
+y_future['energy(kWh/hh)'].plot(ax=plt.gca(), label='future')
+plt.legend()
+plt.ylabel('energy(kWh/hh)')
+
+# Notice we've added on two new columns tsp (time since present) and is_past
+x_past.tail()
+# -
+
+# Notice we've hidden some future columns to prevent cheating
+x_future.tail()
+
+
+# ## Model
+
+# +
+
+class Seq2SeqNet(nn.Module):
+ def __init__(self, input_size, input_size_decoder, output_size, hidden_size=32, lstm_layers=2, lstm_dropout=0, _min_std = 0.05):
+ super().__init__()
+ self._min_std = _min_std
+
+ self.encoder = nn.LSTM(
+ input_size=input_size + output_size,
+ hidden_size=hidden_size,
+ batch_first=True,
+ num_layers=lstm_layers,
+ dropout=lstm_dropout,
+ )
+ self.decoder = nn.LSTM(
+ input_size=input_size_decoder,
+ hidden_size=hidden_size,
+ batch_first=True,
+ num_layers=lstm_layers,
+ dropout=lstm_dropout,
+ )
+ self.mean = nn.Linear(hidden_size, output_size)
+ self.std = nn.Linear(hidden_size, output_size)
+
+ def forward(self, context_x, context_y, target_x, target_y=None):
+ x = torch.cat([context_x, context_y], -1)
+ _, (h_out, cell) = self.encoder(x)
+
+ ## Shape
+ # hidden = [batch size, n layers * n directions, hid dim]
+ # cell = [batch size, n layers * n directions, hid dim]
+ # output = [batch size, seq len, hid dim * n directions]
+ outputs, (_, _) = self.decoder(target_x, (h_out, cell))
+
+
+ # outputs: [B, T, num_direction * H]
+ mean = self.mean(outputs)
+ log_sigma = self.std(outputs)
+ log_sigma = torch.clamp(log_sigma, np.log(self._min_std), -np.log(self._min_std))
+
+ sigma = torch.exp(log_sigma)
+ y_dist = torch.distributions.Normal(mean, sigma)
+ return y_dist
+
+
+# -
+
+
+
+# +
+input_size = x_past.shape[-1]
+output_size = y_future.shape[-1]
+
+model = Seq2SeqNet(input_size, input_size, output_size,
+ hidden_size=32,
+ lstm_layers=2,
+ lstm_dropout=0).to(device)
+model
+# -
+# Init the optimiser
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+# +
+
+past_x = torch.rand((batch_size, window_past, input_size)).to(device)
+future_x = torch.rand((batch_size, window_future, input_size)).to(device)
+past_y = torch.rand((batch_size, window_past, output_size)).to(device)
+future_y = torch.rand((batch_size, window_future, output_size)).to(device)
+output = model(past_x, past_y, future_x, future_y)
+print(output)
+
+from torchsummaryX import summary
+summary(model, past_x, past_y, future_x, future_y )
+1
+# -
+
+# ## Training
+
+
+
+
+
+# +
+def train_epoch(ds, model, bs=128):
+ model.train()
+
+ training_loss = []
+
+ # Put data into a torch loader
+ load_train = torch.utils.data.dataloader.DataLoader(
+ ds,
+ batch_size=bs,
+ pin_memory=False,
+ num_workers=num_workers,
+ shuffle=True,
+ )
+
+ for batch in tqdm(load_train, leave=False, desc='train'):
+ # Send data to gpu
+ x_past, y_past, x_future, y_future = [d.to(device) for d in batch]
+
+ # Discard previous gradients
+ optimizer.zero_grad()
+
+ # Run model
+ y_dist = model(x_past, y_past, x_future, y_future)
+
+ # Get loss, it's Negative Log Likelihood
+ loss = -y_dist.log_prob(y_future).mean()
+
+ # Backprop
+ loss.backward()
+ optimizer.step()
+
+ # Record stats
+ training_loss.append(loss.item())
+
+ return np.mean(training_loss)
+
+
+def test_epoch(ds, model, bs=512):
+ model.eval()
+
+ test_loss = []
+ load_test = torch.utils.data.dataloader.DataLoader(ds,
+ batch_size=bs,
+ pin_memory=False,
+ num_workers=num_workers)
+ for batch in tqdm(load_test, leave=False, desc='test'):
+ # Send data to gpu
+ x_past, y_past, x_future, y_future = [d.to(device) for d in batch]
+ with torch.no_grad():
+ # Run model
+ y_dist = model(x_past, y_past, x_future, y_future)
+ # Get loss, it's Negative Log Likelihood
+ loss = -y_dist.log_prob(y_future).mean()
+
+ test_loss.append(loss.item())
+
+ return np.mean(test_loss)
+
+
+def training_loop(ds_train, ds_test, model, epochs=1, bs=128):
+ all_losses = []
+ try:
+ test_loss = test_epoch(ds_test, model)
+ print(f"Start: Test Loss = {test_loss:.2f}")
+ for epoch in tqdm(range(epochs), desc='epochs'):
+ loss = train_epoch(ds_train, model, bs=bs)
+ print(f"Epoch {epoch+1}/{epochs}: Training Loss = {loss:.2f}")
+
+ test_loss = test_epoch(ds_test, model)
+ print(f"Epoch {epoch+1}/{epochs}: Test Loss = {test_loss:.2f}")
+ print("-" * 50)
+
+ all_losses.append([loss, test_loss])
+
+ except KeyboardInterrupt:
+ # This lets you stop manually. and still get the results
+ pass
+
+ # Visualising the results
+ all_losses = np.array(all_losses)
+ plt.plot(all_losses[:, 0], label="Training")
+ plt.plot(all_losses[:, 1], label="Test")
+ plt.title("Loss")
+ plt.legend()
+
+ return all_losses
+
+
+# -
+
+# this might take 1 minute per epoch on a gpu
+training_loop(ds_train, ds_test, model, epochs=8, bs=batch_size)
+1
+
+# ## Predict
+#
+
+# TODO get working
+output_scaler = scaler.transformers[-4][1]
+ds_preds = predict(model, ds_test, batch_size*6, device=device, scaler=output_scaler)
+
+
+
+# +
+# TODO Metrics... smape etc
+
+# +
+def plot_prediction(ds_preds, i):
+ """Plot a prediction into the future, at a single point in time."""
+ d = ds_preds.isel(t_source=i)
+
+ # Get arrays
+ xf = d.t_target
+ yp = d.y_pred
+ s = d.y_pred_std
+ yt = d.y_true
+ now = d.t_source.squeeze()
+
+ # plot prediction
+ plt.fill_between(xf, yp-2*s, yp+2*s, alpha=0.25,
+ facecolor="b",
+ interpolate=True,
+ label="2 std",)
+ plt.plot(xf, yp, label='pred', c='b')
+
+ # plot true
+ plt.scatter(
+ d.t_past,
+ d.y_past,
+ c='k',
+ s=6
+ )
+ plt.scatter(xf, yt, label='true', c='k', s=6)
+
+ # plot a red line for now
+ plt.vlines(x=now, ymin=0, ymax=1, label='now', color='r')
+
+ now=pd.Timestamp(now.values)
+ plt.title(f'Prediction NLL={d.nll.mean().item():2.2g}')
+ plt.xlabel(f'{now.date()}')
+ plt.ylabel('energy(kWh/hh)')
+ plt.legend()
+ plt.xticks(rotation=45)
+ plt.show()
+
+# plot_prediction(ds_preds, 0)
+# plot_prediction(ds_preds, 12) # 6 hours later
+plot_prediction(ds_preds, 24) # 12 hours later
+plot_prediction(ds_preds, 48) # 12 hours later
+# -
+
+# ## Error vs time ahead
+
+
+
+# +
+d = ds_preds.mean('t_source') # Mean over all predictions
+
+# Plot with xarray, it has a pandas like interface
+d.plot.scatter('t_ahead_hours', 'nll')
+
+# Tidy the graph
+n = len(ds_preds.t_source)
+plt.ylabel('Negative Log Likelihood (lower is better)')
+plt.xlabel('Hours ahead')
+plt.title(f'NLL vs time (no. samples={n})')
+# -
+
+d = ds_preds.mean('t_source') # Mean over all predictions
+d['likelihood'] = np.exp(-d.nll) # get likelihood, after taking mean in log domain
+d.plot.scatter('t_ahead_hours', 'likelihood')
+
+
+
+# Make a plot of the NLL over time. Does this solution get worse with time?
+# this is hard because we need to take the mean over t_ahead
+# then group by t_source
+d = ds_preds.mean('t_ahead').groupby('t_source').mean()
+# And even then it's clearer with smoothing
+d.plot.scatter('t_source', 'nll')
+plt.xticks(rotation=45)
+plt.title('NLL over time (lower is better)')
+1
+
+# A scatter plot is easy with xarray
+ds_preds.plot.scatter('y_true', 'y_pred', s=.01)
+
+
diff --git a/requirements/environment.max.yaml b/requirements/environment.max.yaml
new file mode 100644
index 0000000..a485d98
--- /dev/null
+++ b/requirements/environment.max.yaml
@@ -0,0 +1,221 @@
+name: seq2seq-time
+channels:
+ - pytorch
+ - conda-forge
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=1_gnu
+ - absl-py=0.10.0=py37hc8dfbb8_1
+ - aiohttp=3.6.3=py37h7b6447c_0
+ - appdirs=1.4.4=py_0
+ - argon2-cffi=20.1.0=py37h8f50634_2
+ - async-timeout=3.0.1=py_1000
+ - async_generator=1.10=py_0
+ - attrs=20.2.0=pyh9f0ad1d_0
+ - awscli=1.18.159=py37hc8dfbb8_0
+ - backcall=0.2.0=pyh9f0ad1d_0
+ - backports=1.0=py_2
+ - backports.functools_lru_cache=1.6.1=py_0
+ - black=20.8b1=py_1
+ - blas=1.0=mkl
+ - bleach=3.2.1=pyh9f0ad1d_0
+ - blinker=1.4=py_1
+ - botocore=1.18.18=pyh9f0ad1d_0
+ - brotlipy=0.7.0=py37hb5d75c8_1001
+ - c-ares=1.16.1=h516909a_3
+ - ca-certificates=2020.10.14=0
+ - cachetools=4.1.1=py_0
+ - certifi=2020.6.20=py37he5f6b98_2
+ - cffi=1.14.3=py37h00ebd2e_1
+ - chardet=3.0.4=py37he5f6b98_1008
+ - click=7.1.2=pyh9f0ad1d_0
+ - colorama=0.4.3=py_0
+ - cryptography=3.1.1=py37hff6837a_1
+ - cudatoolkit=10.2.89=hfd86e86_1
+ - cycler=0.10.0=py_2
+ - dataclasses=0.7=py37_0
+ - dbus=1.13.18=hb2f20db_0
+ - decorator=4.4.2=py_0
+ - defusedxml=0.6.0=py_0
+ - docutils=0.15.2=py37_0
+ - entrypoints=0.3=py37hc8dfbb8_1002
+ - expat=2.2.10=he6710b0_2
+ - fontconfig=2.13.1=h1056068_1002
+ - freetype=2.10.3=he06d7ca_0
+ - fsspec=0.8.4=py_0
+ - future=0.18.2=py37hc8dfbb8_2
+ - gettext=0.19.8.1=hf34092f_1003
+ - glib=2.66.1=he1b5a44_1
+ - google-auth=1.22.1=py_0
+ - google-auth-oauthlib=0.4.1=py_2
+ - grpcio=1.31.0=py37hb0870dc_0
+ - gst-plugins-base=1.14.5=h0935bb2_2
+ - gstreamer=1.14.5=h36ae1b5_2
+ - icu=67.1=he1b5a44_0
+ - idna=2.10=pyh9f0ad1d_0
+ - importlib-metadata=2.0.0=py37hc8dfbb8_0
+ - importlib_metadata=2.0.0=1
+ - iniconfig=1.1.1=py_0
+ - intel-openmp=2020.2=254
+ - ipykernel=5.3.4=py37hc6149b9_1
+ - ipython=7.18.1=py37hc6149b9_1
+ - ipython_genutils=0.2.0=py_1
+ - ipywidgets=7.5.1=pyh9f0ad1d_1
+ - jedi=0.17.2=py37hc8dfbb8_1
+ - jinja2=2.11.2=pyh9f0ad1d_0
+ - jmespath=0.10.0=pyh9f0ad1d_0
+ - joblib=0.17.0=py_0
+ - jpeg=9d=h516909a_0
+ - jsonschema=3.2.0=py37hc8dfbb8_1
+ - jupyter_client=6.1.7=py_0
+ - jupyter_core=4.6.3=py37hc8dfbb8_2
+ - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0
+ - kiwisolver=1.2.0=py37h99015e2_1
+ - krb5=1.17.1=hfafb76e_3
+ - lcms2=2.11=hbd6801e_0
+ - ld_impl_linux-64=2.35=h769bd43_9
+ - libblas=3.8.0=17_openblas
+ - libcblas=3.8.0=17_openblas
+ - libclang=10.0.1=default_hde54327_1
+ - libedit=3.1.20191231=he28a2e2_2
+ - libevent=2.1.10=hcdb4288_3
+ - libffi=3.2.1=he1b5a44_1007
+ - libgcc-ng=9.3.0=h5dbcf3e_17
+ - libgfortran-ng=7.5.0=hae1eefd_17
+ - libgfortran4=7.5.0=hae1eefd_17
+ - libglib=2.66.1=h0dae87d_1
+ - libgomp=9.3.0=h5dbcf3e_17
+ - libiconv=1.16=h516909a_0
+ - liblapack=3.8.0=17_openblas
+ - libllvm10=10.0.1=he513fc3_3
+ - libopenblas=0.3.10=pthreads_hb3c22a3_5
+ - libpng=1.6.37=hed695b0_2
+ - libpq=12.3=h1281834_2
+ - libprotobuf=3.13.0.1=h8b12597_0
+ - libsodium=1.0.18=h516909a_1
+ - libstdcxx-ng=9.3.0=h2ae2ef3_17
+ - libtiff=4.1.0=hc7e4089_6
+ - libuuid=2.32.1=h14c3975_1000
+ - libwebp-base=1.1.0=h516909a_3
+ - libxcb=1.14=h7b6447c_0
+ - libxkbcommon=0.10.0=he1b5a44_0
+ - libxml2=2.9.10=h68273f3_2
+ - lz4-c=1.9.2=he1b5a44_3
+ - markdown=3.3.1=pyh9f0ad1d_0
+ - markupsafe=1.1.1=py37hb5d75c8_2
+ - matplotlib=3.3.2=py37hc8dfbb8_1
+ - matplotlib-base=3.3.2=py37hc9afd2a_1
+ - mccabe=0.6.1=py_1
+ - mistune=0.8.4=py37h8f50634_1002
+ - mkl=2020.2=256
+ - more-itertools=8.5.0=py_0
+ - multidict=4.7.6=py37h7b6447c_1
+ - mypy=0.790=py_0
+ - mypy_extensions=0.4.3=py37hc8dfbb8_1
+ - mysql-common=8.0.21=2
+ - mysql-libs=8.0.21=hf3661c5_2
+ - nbclient=0.5.1=py_0
+ - nbconvert=6.0.7=py37hc8dfbb8_1
+ - nbformat=5.0.8=py_0
+ - ncurses=6.2=he1b5a44_2
+ - nest-asyncio=1.4.1=py_0
+ - ninja=1.10.1=hfc4b9b4_2
+ - notebook=6.1.4=py37hc8dfbb8_1
+ - nspr=4.29=he1b5a44_1
+ - nss=3.58=h27285de_1
+ - numpy=1.19.2=py37h7ea13bd_1
+ - oauthlib=3.1.0=py_0
+ - olefile=0.46=pyh9f0ad1d_1
+ - openssl=1.1.1h=h516909a_0
+ - packaging=20.4=pyh9f0ad1d_0
+ - pandas=1.1.3=py37h9fdb41a_2
+ - pandoc=2.11.0.2=hd18ef5c_0
+ - pandocfilters=1.4.2=py_1
+ - parso=0.7.1=pyh9f0ad1d_0
+ - pathspec=0.8.0=pyh9f0ad1d_0
+ - pcre=8.44=he1b5a44_0
+ - pexpect=4.8.0=py37hc8dfbb8_1
+ - pickleshare=0.7.5=py37hc8dfbb8_1002
+ - pillow=8.0.0=py37h718be6c_0
+ - pip=20.2.4=py_0
+ - pluggy=0.13.1=py37hc8dfbb8_3
+ - prometheus_client=0.8.0=pyh9f0ad1d_0
+ - prompt-toolkit=3.0.8=py_0
+ - protobuf=3.13.0.1=py37h3340039_1
+ - psutil=5.7.2=py37hb5d75c8_1
+ - ptyprocess=0.6.0=py37_1000
+ - py=1.9.0=pyh9f0ad1d_0
+ - pyasn1=0.4.8=py_0
+ - pyasn1-modules=0.2.8=py_0
+ - pycodestyle=2.6.0=pyh9f0ad1d_0
+ - pycparser=2.20=pyh9f0ad1d_2
+ - pydocstyle=5.1.1=py_0
+ - pyflakes=2.2.0=pyh9f0ad1d_0
+ - pygments=2.7.1=py_0
+ - pyjwt=1.7.1=py_0
+ - pylama=7.7.1=py_0
+ - pyopenssl=19.1.0=py37_0
+ - pyparsing=2.4.7=pyh9f0ad1d_0
+ - pyqt=5.12.3=py37h8685d9f_4
+ - pyrsistent=0.17.3=py37h8f50634_1
+ - pysocks=1.7.1=py37he5f6b98_2
+ - pytest=6.1.1=py37hc8dfbb8_1
+ - python=3.7.8=h425cb1d_1_cpython
+ - python-dateutil=2.8.1=py_0
+ - python_abi=3.7=1_cp37m
+ - pytorch=1.6.0=py3.7_cuda10.2.89_cudnn7.6.5_0
+ - pytorch-lightning=1.0.2=py_0
+ - pytz=2020.1=pyh9f0ad1d_0
+ - pyyaml=5.3.1=py37hb5d75c8_1
+ - pyzmq=19.0.2=py37hac76be4_2
+ - qt=5.12.9=h1f2b2cb_0
+ - readline=8.0=he28a2e2_2
+ - regex=2020.10.15=py37h8f50634_0
+ - requests=2.24.0=pyh9f0ad1d_0
+ - requests-oauthlib=1.3.0=pyh9f0ad1d_0
+ - rsa=4.4.1=pyh9f0ad1d_0
+ - s3transfer=0.3.3=py37hc8dfbb8_2
+ - scikit-learn=0.23.2=py37h6785257_0
+ - scipy=1.5.2=py37hb14ef9d_2
+ - send2trash=1.5.0=py_0
+ - setuptools=49.6.0=py37he5f6b98_2
+ - six=1.15.0=pyh9f0ad1d_0
+ - snowballstemmer=2.0.0=py_0
+ - sqlite=3.33.0=h4cf870e_1
+ - tensorboard=2.3.0=py_0
+ - tensorboard-plugin-wit=1.6.0=pyh9f0ad1d_0
+ - terminado=0.9.1=py37hc8dfbb8_1
+ - testpath=0.4.4=py_0
+ - threadpoolctl=2.1.0=pyh5ca1d4c_0
+ - tk=8.6.10=hed695b0_1
+ - toml=0.10.1=pyh9f0ad1d_0
+ - torchvision=0.7.0=py37_cu102
+ - tornado=6.0.4=py37h8f50634_2
+ - tqdm=4.50.2=pyh9f0ad1d_0
+ - traitlets=5.0.5=py_0
+ - typed-ast=1.4.1=py37h516909a_0
+ - typing-extensions=3.7.4.3=0
+ - typing_extensions=3.7.4.3=py_0
+ - urllib3=1.25.10=py_0
+ - wcwidth=0.2.5=pyh9f0ad1d_2
+ - webencodings=0.5.1=py_1
+ - werkzeug=1.0.1=pyh9f0ad1d_0
+ - wheel=0.35.1=pyh9f0ad1d_0
+ - widgetsnbextension=3.5.1=py37hc8dfbb8_2
+ - xarray=0.16.1=py_0
+ - xz=5.2.5=h516909a_1
+ - yaml=0.2.5=h516909a_0
+ - yapf=0.30.0=pyh9f0ad1d_0
+ - yarl=1.6.2=py37h8f50634_0
+ - zeromq=4.3.3=he1b5a44_2
+ - zipp=3.3.1=py_0
+ - zlib=1.2.11=h516909a_1010
+ - zstd=1.4.5=h6597ccf_2
+ - pip:
+ - pyqt5-sip==4.19.18
+ - pyqtchart==5.12
+ - pyqtwebengine==5.12.1
+ - sklearn-pandas==2.0.2
+ - torchsummaryx==1.3.0
+prefix: /home/wassname/anaconda/envs/seq2seq-time
diff --git a/requirements/environment.min.yaml b/requirements/environment.min.yaml
new file mode 100644
index 0000000..fdbc036
--- /dev/null
+++ b/requirements/environment.min.yaml
@@ -0,0 +1,26 @@
+name: seq2seq-time
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - python==3.7
+ - pip
+ - awscli
+ - ipykernel
+ - tqdm
+ - xarray
+ - pandas
+ - pytorch
+ - torchvision
+ - cudatoolkit==10.2
+ - black
+ - pylama
+ - mypy
+ - pytest
+ - numpy
+ - matplotlib
+ - scikit-learn
+ - pytorch-lightning
+ - yapf
+ - ipywidgets
+prefix: /home/wassname/anaconda/envs/seq2seq-time
diff --git a/requirements/environment.yaml b/requirements/environment.yaml
index 718a9de..4dc84cd 100644
--- a/requirements/environment.yaml
+++ b/requirements/environment.yaml
@@ -11,5 +11,5 @@ dependencies:
- awscli
- pip:
# local package
- - -e .
+ # - -e .
diff --git a/requirements/readme.md b/requirements/readme.md
index 340249b..d022f01 100644
--- a/requirements/readme.md
+++ b/requirements/readme.md
@@ -4,3 +4,13 @@ This project has multiple ways of documenting requirements
- environment.min.yaml - This is the minimum requirements, use it to install a new test or dev environment
- environment.max.yaml - This pins all conda packages, use for production or finding vunrebilities
- requirements.txt - For people or bots not using conda
+
+```
+# Install requirements
+conda create --name seq2seq-time python=3.7 -f ./requirements/environment.yaml
+conda activate seq2seq-time
+# Install this package in editable mode
+python -m pip install -e .
+# Install kernel
+python -m ipykernel install --user --name seq2seq-time --display-name seq2seq-time
+```
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
new file mode 100644
index 0000000..e69de29
diff --git a/seq2seq_time/data/dataset.py b/seq2seq_time/data/dataset.py
new file mode 100644
index 0000000..88b7a3d
--- /dev/null
+++ b/seq2seq_time/data/dataset.py
@@ -0,0 +1,100 @@
+import pandas as pd
+import torch.utils.data
+import numpy as np
+
+def assert_normalized(df):
+ stats = df.describe().T
+ np.testing.assert_allclose(stats['mean'].values, 0, atol=0.1), 'means should be normalized to ~0'
+ np.testing.assert_allclose(stats['std'].values, 1, atol=0.1), 'standard deviations should be normalized to ~0'
+
+def assert_no_objects(df):
+ for name, dtype in df.dtypes.iteritems():
+ assert dtype.name!='object', f'all objects should be pd.categories. {name} is not'
+
+
+class Seq2SeqDataSet(torch.utils.data.Dataset):
+ """
+ Takes in dataframe and returns sequences through time.
+
+ Returns x_past, y_past, x_future, etc.
+ """
+
+ def __init__(self, df: pd.DataFrame, window_past=40, window_future=10, columns_target=['energy(kWh/hh)'], columns_blank=[],):
+ """
+ Args:
+ - df: DataFrame with time index, already scaled
+ - columns_blank: The columns we will blank, in the future
+ """
+ super().__init__()
+ # TODO auto categorical columns
+ # TODO specify blank future columns
+ assert isinstance(df.index, pd.DatetimeIndex), 'should have a datetime index'
+ assert df.index.freq is not None, 'should have freq'
+ # assert_normalized(df)
+ assert_no_objects(df)
+
+ # Use numpy instead of pandas, for speed
+ self.x = df.drop(columns=columns_target).copy().values
+ self.y = df[columns_target].copy().values
+ self.t = df.index.copy()
+ self.columns = list(df.columns)
+ self.icol_blank = [df.drop(columns=columns_target).columns.tolist().index(n) for n in columns_blank]
+
+ self.window_past = window_past
+ self.window_future = window_future
+ self.columns_target = columns_target
+
+ def get_components(self, i):
+ """Get past and future rows."""
+ x = self.x[i : i + (self.window_past + self.window_future)].copy()
+ y = self.y[i:i + (self.window_past + self.window_future)].copy()
+ t = self.t[i:i + (self.window_past + self.window_future)].copy()
+ t = t.astype(int) * 1e-9 / 60 / 60 / 24 # days
+ t = t.values
+ now = t[self.window_past]
+
+ # Add a features: relative hours since present time, is future
+ tstp = (t - now)[:, None]
+ is_past = tstp < 0
+ x = np.concatenate([x, tstp, is_past], -1)
+
+ # Split into future and past
+ x_past = x[:self.window_past]
+ y_past = y[:self.window_past]
+ x_future = x[self.window_past:]
+ y_future = y[self.window_past:]
+
+ # Stop it cheating by using future weather measurements
+ x_future[:, self.icol_blank] = 0
+ return x_past, y_past, x_future, y_future
+
+
+ def __getitem__(self, i):
+ """This is how python implements square brackets"""
+ if i<0:
+ # Handle negative integers
+ i = len(self)+i
+ data = self.get_components(i)
+ # From dataframe to torch
+ return [d.astype(np.float32) for d in data]
+
+
+ def get_rows(self, i):
+ """
+ Output pandas dataframes for display purposes.
+ """
+ x_cols = list(self.columns)[1:] + ['tsp_days', 'is_past']
+ x_past, y_past, x_future, y_future = self.get_components(i)
+ t_past = self.t[i:i+self.window_past]
+ t_future = self.t[i+self.window_past:i+self.window_past + self.window_future]
+ x_past = pd.DataFrame(x_past, columns=x_cols, index=t_past)
+ x_future = pd.DataFrame(x_future, columns=x_cols, index=t_future)
+ y_past = pd.DataFrame(y_past, columns=self.columns_target, index=t_past)
+ y_future = pd.DataFrame(y_future, columns=self.columns_target, index=t_future)
+ return x_past, y_past, x_future, y_future
+
+ def __len__(self):
+ return len(self.x) - (self.window_past + self.window_future)
+
+ def __repr__(self):
+ return f'<{type(self).__name__}(shape={self.x.shape}, times={self.t[0]} to {self.t[1]} at {self.t.freq.freqstr})>'
diff --git a/seq2seq_time/predict.py b/seq2seq_time/predict.py
new file mode 100644
index 0000000..a3acad5
--- /dev/null
+++ b/seq2seq_time/predict.py
@@ -0,0 +1,72 @@
+import xarray as xr
+import torch
+from tqdm.auto import tqdm
+import pandas as pd
+
+from .util import to_numpy
+
+def predict(model, ds_test, batch_size, device='cpu', scaler=None):
+ """
+ Gather all predictions into xarray.
+
+ When we generate prediction in a sequence to sequence model we start at a time then predict
+ N steps into the future. So we have 2 dimensions: source time, target time.
+
+ But we also care about how far we were predicting into the future, so we have 3 dimensions: source time, target time, time ahead.
+
+ It's hard to use pandas for data with virtual dimensions so we will use xarray. Xarray has an interface similar to pandas but also allows coordinates which are virtual dimensions.
+ """
+ load_test = torch.utils.data.dataloader.DataLoader(ds_test, batch_size=batch_size)
+ freq = ds_test.t.freq
+ xrs = []
+ for i, batch in enumerate(tqdm(load_test, desc='predict')):
+ model.eval()
+ with torch.no_grad():
+ x_past, y_past, x_future, y_future = [d.to(device) for d in batch]
+ y_dist = model(x_past, y_past, x_future, y_future)
+ nll = -y_dist.log_prob(y_future)
+
+ # Convert to numpy
+ mean = to_numpy(y_dist.loc.squeeze(-1))
+ std = to_numpy(y_dist.scale.squeeze(-1))
+ nll = to_numpy(nll.squeeze(-1))
+ y_future = to_numpy(y_future.squeeze(-1))
+ y_past = to_numpy(y_past.squeeze(-1))
+
+ # Make an xarray.Dataset for the data
+ bs = y_future.shape[0]
+ t_source = ds_test.t[i:i+bs].values
+ t_ahead = pd.timedelta_range(0, periods=ds_test.window_future, freq=freq).values
+ t_behind = pd.timedelta_range(end=-pd.Timedelta(freq), periods=ds_test.window_past, freq=freq)
+ xr_out = xr.Dataset(
+ {
+ # Format> name: ([dimensions,...], array),
+ "y_past": (["t_source", "t_behind",], y_past),
+ "nll": (["t_source", "t_ahead",], nll),
+ "y_pred": (["t_source", "t_ahead",], mean),
+ "y_pred_std": (["t_source", "t_ahead",], std),
+ "y_true": (["t_source", "t_ahead",], y_future),
+ },
+ coords={"t_source": t_source, "t_ahead": t_ahead, "t_behind": t_behind},
+ )
+ xrs.append(xr_out)
+
+ # Join all batches
+ ds_preds = xr.concat(xrs, dim="t_source")
+
+ # undo scaling on y
+ if scaler:
+ ds_preds['y_pred_std'].values = ds_preds.y_pred_std * scaler.scale_
+ ds_preds['y_past'].values = scaler.inverse_transform(ds_preds.y_past)
+ ds_preds['y_pred'].values = scaler.inverse_transform(ds_preds.y_pred)
+ ds_preds['y_true'].values = scaler.inverse_transform(ds_preds.y_true)
+
+ # Add some derived coordinates, they will be the ones not in bold
+ # The target time, is a function of the source time, and how far we predict ahead
+ ds_preds = ds_preds.assign_coords(t_target=ds_preds.t_source+ds_preds.t_ahead)
+
+ ds_preds = ds_preds.assign_coords(t_past=ds_preds.t_source+ds_preds.t_behind)
+
+ # Some plots don't like timedeltas, so lets make a coordinate for time ahead in hours
+ ds_preds = ds_preds.assign_coords(t_ahead_hours=(ds_preds.t_ahead*1.0e-9/60/60).astype(float))
+ return ds_preds
diff --git a/seq2seq_time/util.py b/seq2seq_time/util.py
new file mode 100644
index 0000000..ba1324a
--- /dev/null
+++ b/seq2seq_time/util.py
@@ -0,0 +1,10 @@
+from pathlib import Path
+import torch
+
+project_dir = Path(__file__).parent.parent
+
+def to_numpy(x):
+ """Helper function to avoid repeating code"""
+ if isinstance(x, torch.Tensor):
+ x = x.cpu().detach().numpy()
+ return x