diff --git a/README.md b/README.md index 0bc0c3a..af5d7f4 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,10 @@ +Fork note: + +Interested by the results on the Exchange data I'm digging deeper, plotting, extending, and replicating the exchange reults. See `scratch.ipynb` + + +---- + # DeepTime: Deep Time-Index Meta-Learning for Non-Stationary Time-Series Forecasting

@@ -109,4 +116,4 @@ Please consider citing if you find this code useful to your research. author={Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven C. H. Hoi}, year={2022}, url={https://arxiv.org/abs/2207.06046}, -} \ No newline at end of file +} diff --git a/scratch.ipynb b/scratch.ipynb new file mode 100644 index 0000000..aec61dd --- /dev/null +++ b/scratch.ipynb @@ -0,0 +1,332 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "4e09086b", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "\n", + "import os\n", + "from os.path import join\n", + "import math\n", + "import logging\n", + "from typing import Callable, Optional, Union, Dict, Tuple\n", + "\n", + "from matplotlib import pyplot as plt\n", + "\n", + "import gin\n", + "from fire import Fire\n", + "import numpy as np\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "from torch import optim\n", + "from torch import nn\n", + "\n", + "from experiments.base import Experiment\n", + "from data.datasets import ForecastDataset\n", + "from models import get_model\n", + "from utils.checkpoint import Checkpoint\n", + "from utils.ops import default_device, to_tensor\n", + "from utils.losses import get_loss_fn\n", + "from utils.metrics import calc_metrics\n", + "\n", + "from experiments.forecast import get_data\n", + "gin.enter_interactive_mode()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "516c4869", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "gin.clear_config()\n", + "# gin.parse_config(open(\"storage/experiments/Exchange/96M/repeat=0/config.gin\"))\n", + "gin.parse_config(open(\"storage/experiments/Exchange/96Mplus/repeat=0/config.gin\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e056c6f", + "metadata": {}, + "outputs": [], + "source": [ + "train_set, train_loader = get_data(flag='train', batch_size=16)\n", + "# x, _, _, _ =train_set[0]\n", + "# x = x * 1.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07f093f2", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# x -= x[0]\n", + "# x /= x.std()\n", + "# plt.plot(x)\n", + "# # %%" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7aeee9df", + "metadata": {}, + "outputs": [], + "source": [ + "model = get_model(\"deeptime2\",\n", + " dim_size=train_set.data_x.shape[1],\n", + " datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n", + "model.load_state_dict(torch.load('storage/experiments/Exchange/96Mplus/repeat=0/model.pth'))\n", + "model = model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6e8ca1f", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "b = train_set[1]\n", + "b = [bb[None, :] for bb in b]\n", + "x, y, x_time, y_time = map(to_tensor, b)\n", + "with torch.no_grad():\n", + " forecast = model(x, x_time, y_time)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "638da5c2", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66d7f095", + "metadata": {}, + "outputs": [], + "source": [ + "plt.title('inception inr')\n", + "import matplotlib.colors as mcolors\n", + "colors = list(mcolors.BASE_COLORS.keys())\n", + "l = x.shape[1]\n", + "forecast2 = forecast[0].detach().cpu().numpy()\n", + "x2 = x[0].cpu()\n", + "y2 = y[0].cpu()\n", + "i_past = list(range(l))\n", + "i_future = list(range(l, l*2))\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(range(l), x2[:, i], c=colors[i])\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(range(l, l*2), y2[:, i], c=colors[i])\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(range(l, l*2), forecast2[:, i], c=colors[i], linestyle='--')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47bf8d52", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "gin.clear_config()\n", + "gin.parse_config(open(\"storage/experiments/Exchange/96M/repeat=0/config.gin\"))\n", + "\n", + "train_set, train_loader = get_data(flag='train', batch_size=16)\n", + "\n", + "model = get_model(\"deeptime\",\n", + " dim_size=train_set.data_x.shape[1],\n", + " datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n", + "model.load_state_dict(torch.load('storage/experiments/Exchange/96M/repeat=0/model.pth'))\n", + "model = model.eval()\n", + "\n", + "\n", + "b = train_set[1]\n", + "b = [bb[None, :] for bb in b]\n", + "x, y, x_time, y_time = map(to_tensor, b)\n", + "with torch.no_grad():\n", + " forecast = model(x, x_time, y_time)\n", + "\n", + "\n", + "plt.title('mlp inr')\n", + "import matplotlib.colors as mcolors\n", + "colors = list(mcolors.BASE_COLORS.keys())\n", + "l = x.shape[1]\n", + "forecast2 = forecast[0].detach().cpu().numpy()\n", + "x2 = x[0].cpu()\n", + "y2 = y[0].cpu()\n", + "i_past = list(range(l))\n", + "i_future = list(range(l, l*2))\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(range(l), x2[:, i], c=colors[i])\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(range(l, l*2), y2[:, i], c=colors[i])\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(range(l, l*2), forecast2[:, i], c=colors[i], linestyle='--')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28eca411", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "gin.clear_config()\n", + "gin.parse_config(open(\"storage/experiments/Exchange/96Mplus2/repeat=0/config.gin\"))\n", + "\n", + "train_set, train_loader = get_data(flag='train', batch_size=16)\n", + "\n", + "model = get_model(\"deeptime2\",\n", + " dim_size=train_set.data_x.shape[1],\n", + " datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n", + "model.load_state_dict(torch.load('storage/experiments/Exchange/96Mplus2/repeat=0/model.pth'))\n", + "model = model.eval()\n", + "\n", + "\n", + "b = train_set[1]\n", + "b = [bb[None, :] for bb in b]\n", + "x, y, x_time, y_time = map(to_tensor, b)\n", + "with torch.no_grad():\n", + " forecast = model(x, x_time, y_time)\n", + "\n", + "\n", + "plt.title('inception inr')\n", + "import matplotlib.colors as mcolors\n", + "colors = list(mcolors.BASE_COLORS.keys())\n", + "l = x.shape[1]\n", + "forecast2 = forecast[0].detach().cpu().numpy()\n", + "x2 = x[0].cpu()\n", + "y2 = y[0].cpu()\n", + "l2 = y.shape[1]\n", + "i_past = list(range(l))\n", + "i_future = list(range(l, l+l2))\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(i_past, x2[:, i], c=colors[i])\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(i_future, y2[:, i], c=colors[i])\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(i_future, forecast2[:, i], c=colors[i], linestyle='--')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ca906be", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "gin.clear_config()\n", + "gin.parse_config(open(\"storage/experiments/Exchange/96M2/repeat=0/config.gin\"))\n", + "\n", + "train_set, train_loader = get_data(flag='train', batch_size=16)\n", + "\n", + "model = get_model(\"deeptime\",\n", + " dim_size=train_set.data_x.shape[1],\n", + " datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n", + "model.load_state_dict(torch.load('storage/experiments/Exchange/96M2/repeat=0/model.pth'))\n", + "model = model.eval()\n", + "\n", + "\n", + "b = train_set[1]\n", + "b = [bb[None, :] for bb in b]\n", + "x, y, x_time, y_time = map(to_tensor, b)\n", + "with torch.no_grad():\n", + " forecast = model(x, x_time, y_time)\n", + "\n", + "\n", + "plt.title('mlp inr2')\n", + "import matplotlib.colors as mcolors\n", + "colors = list(mcolors.BASE_COLORS.keys())\n", + "l = x.shape[1]\n", + "forecast2 = forecast[0].detach().cpu().numpy()\n", + "x2 = x[0].cpu()\n", + "y2 = y[0].cpu()\n", + "l2 = y.shape[1]\n", + "i_past = list(range(l))\n", + "i_future = list(range(l, l+l2))\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(i_past, x2[:, i], c=colors[i])\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(i_future, y2[:, i], c=colors[i])\n", + "for i in range(x.shape[-1]):\n", + " plt.plot(i_future, forecast2[:, i], c=colors[i], linestyle='--')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f8c075b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "auto:percent,ipynb", + "main_language": "python", + "notebook_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.9" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}