This commit is contained in:
wassname
2022-11-20 07:53:55 +08:00
parent 90ee9c3298
commit c8631eb40d
2 changed files with 340 additions and 1 deletions
+332
View File
@@ -0,0 +1,332 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "4e09086b",
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"\n",
"import os\n",
"from os.path import join\n",
"import math\n",
"import logging\n",
"from typing import Callable, Optional, Union, Dict, Tuple\n",
"\n",
"from matplotlib import pyplot as plt\n",
"\n",
"import gin\n",
"from fire import Fire\n",
"import numpy as np\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from torch import optim\n",
"from torch import nn\n",
"\n",
"from experiments.base import Experiment\n",
"from data.datasets import ForecastDataset\n",
"from models import get_model\n",
"from utils.checkpoint import Checkpoint\n",
"from utils.ops import default_device, to_tensor\n",
"from utils.losses import get_loss_fn\n",
"from utils.metrics import calc_metrics\n",
"\n",
"from experiments.forecast import get_data\n",
"gin.enter_interactive_mode()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "516c4869",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"gin.clear_config()\n",
"# gin.parse_config(open(\"storage/experiments/Exchange/96M/repeat=0/config.gin\"))\n",
"gin.parse_config(open(\"storage/experiments/Exchange/96Mplus/repeat=0/config.gin\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e056c6f",
"metadata": {},
"outputs": [],
"source": [
"train_set, train_loader = get_data(flag='train', batch_size=16)\n",
"# x, _, _, _ =train_set[0]\n",
"# x = x * 1.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07f093f2",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# x -= x[0]\n",
"# x /= x.std()\n",
"# plt.plot(x)\n",
"# # %%"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7aeee9df",
"metadata": {},
"outputs": [],
"source": [
"model = get_model(\"deeptime2\",\n",
" dim_size=train_set.data_x.shape[1],\n",
" datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
"model.load_state_dict(torch.load('storage/experiments/Exchange/96Mplus/repeat=0/model.pth'))\n",
"model = model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6e8ca1f",
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"b = train_set[1]\n",
"b = [bb[None, :] for bb in b]\n",
"x, y, x_time, y_time = map(to_tensor, b)\n",
"with torch.no_grad():\n",
" forecast = model(x, x_time, y_time)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "638da5c2",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "66d7f095",
"metadata": {},
"outputs": [],
"source": [
"plt.title('inception inr')\n",
"import matplotlib.colors as mcolors\n",
"colors = list(mcolors.BASE_COLORS.keys())\n",
"l = x.shape[1]\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l*2))\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(range(l), x2[:, i], c=colors[i])\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(range(l, l*2), y2[:, i], c=colors[i])\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(range(l, l*2), forecast2[:, i], c=colors[i], linestyle='--')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47bf8d52",
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"gin.clear_config()\n",
"gin.parse_config(open(\"storage/experiments/Exchange/96M/repeat=0/config.gin\"))\n",
"\n",
"train_set, train_loader = get_data(flag='train', batch_size=16)\n",
"\n",
"model = get_model(\"deeptime\",\n",
" dim_size=train_set.data_x.shape[1],\n",
" datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
"model.load_state_dict(torch.load('storage/experiments/Exchange/96M/repeat=0/model.pth'))\n",
"model = model.eval()\n",
"\n",
"\n",
"b = train_set[1]\n",
"b = [bb[None, :] for bb in b]\n",
"x, y, x_time, y_time = map(to_tensor, b)\n",
"with torch.no_grad():\n",
" forecast = model(x, x_time, y_time)\n",
"\n",
"\n",
"plt.title('mlp inr')\n",
"import matplotlib.colors as mcolors\n",
"colors = list(mcolors.BASE_COLORS.keys())\n",
"l = x.shape[1]\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l*2))\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(range(l), x2[:, i], c=colors[i])\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(range(l, l*2), y2[:, i], c=colors[i])\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(range(l, l*2), forecast2[:, i], c=colors[i], linestyle='--')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28eca411",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"gin.clear_config()\n",
"gin.parse_config(open(\"storage/experiments/Exchange/96Mplus2/repeat=0/config.gin\"))\n",
"\n",
"train_set, train_loader = get_data(flag='train', batch_size=16)\n",
"\n",
"model = get_model(\"deeptime2\",\n",
" dim_size=train_set.data_x.shape[1],\n",
" datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
"model.load_state_dict(torch.load('storage/experiments/Exchange/96Mplus2/repeat=0/model.pth'))\n",
"model = model.eval()\n",
"\n",
"\n",
"b = train_set[1]\n",
"b = [bb[None, :] for bb in b]\n",
"x, y, x_time, y_time = map(to_tensor, b)\n",
"with torch.no_grad():\n",
" forecast = model(x, x_time, y_time)\n",
"\n",
"\n",
"plt.title('inception inr')\n",
"import matplotlib.colors as mcolors\n",
"colors = list(mcolors.BASE_COLORS.keys())\n",
"l = x.shape[1]\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"l2 = y.shape[1]\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l+l2))\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(i_past, x2[:, i], c=colors[i])\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(i_future, y2[:, i], c=colors[i])\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(i_future, forecast2[:, i], c=colors[i], linestyle='--')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ca906be",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"gin.clear_config()\n",
"gin.parse_config(open(\"storage/experiments/Exchange/96M2/repeat=0/config.gin\"))\n",
"\n",
"train_set, train_loader = get_data(flag='train', batch_size=16)\n",
"\n",
"model = get_model(\"deeptime\",\n",
" dim_size=train_set.data_x.shape[1],\n",
" datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
"model.load_state_dict(torch.load('storage/experiments/Exchange/96M2/repeat=0/model.pth'))\n",
"model = model.eval()\n",
"\n",
"\n",
"b = train_set[1]\n",
"b = [bb[None, :] for bb in b]\n",
"x, y, x_time, y_time = map(to_tensor, b)\n",
"with torch.no_grad():\n",
" forecast = model(x, x_time, y_time)\n",
"\n",
"\n",
"plt.title('mlp inr2')\n",
"import matplotlib.colors as mcolors\n",
"colors = list(mcolors.BASE_COLORS.keys())\n",
"l = x.shape[1]\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"l2 = y.shape[1]\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l+l2))\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(i_past, x2[:, i], c=colors[i])\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(i_future, y2[:, i], c=colors[i])\n",
"for i in range(x.shape[-1]):\n",
" plt.plot(i_future, forecast2[:, i], c=colors[i], linestyle='--')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f8c075b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"formats": "auto:percent,ipynb",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}