Files
DeepTime/scratch.ipynb
T
wassname c8631eb40d readme
2022-11-20 07:53:55 +08:00

333 lines
9.4 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "4e09086b",
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"\n",
"import os\n",
"from os.path import join\n",
"import math\n",
"import logging\n",
"from typing import Callable, Optional, Union, Dict, Tuple\n",
"\n",
"from matplotlib import pyplot as plt\n",
"\n",
"import gin\n",
"from fire import Fire\n",
"import numpy as np\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from torch import optim\n",
"from torch import nn\n",
"\n",
"from experiments.base import Experiment\n",
"from data.datasets import ForecastDataset\n",
"from models import get_model\n",
"from utils.checkpoint import Checkpoint\n",
"from utils.ops import default_device, to_tensor\n",
"from utils.losses import get_loss_fn\n",
"from utils.metrics import calc_metrics\n",
"\n",
"from experiments.forecast import get_data\n",
"gin.enter_interactive_mode()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "516c4869",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# Clear any bindings parsed earlier in this kernel before loading this run's config.\n",
"gin.clear_config()\n",
"# gin.parse_config(open(\"storage/experiments/Exchange/96M/repeat=0/config.gin\"))\n",
"# Parse inside a context manager so the config file handle is closed afterwards.\n",
"with open(\"storage/experiments/Exchange/96Mplus/repeat=0/config.gin\") as config_file:\n",
"    gin.parse_config(config_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e056c6f",
"metadata": {},
"outputs": [],
"source": [
"# Training dataset and its loader; later cells index train_set directly.\n",
"train_set, train_loader = get_data(batch_size=16, flag='train')\n",
"# x, _, _, _ =train_set[0]\n",
"# x = x * 1.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07f093f2",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# x -= x[0]\n",
"# x /= x.std()\n",
"# plt.plot(x)\n",
"# # %%"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7aeee9df",
"metadata": {},
"outputs": [],
"source": [
"# Build the \"deeptime2\" model sized from the training set, restore its trained\n",
"# weights, and switch it to evaluation mode.\n",
"model = get_model(\"deeptime2\",\n",
"                  dim_size=train_set.data_x.shape[1],\n",
"                  datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
"# map_location remaps the stored tensors onto the current device, so a checkpoint\n",
"# saved on a GPU still loads on a CPU-only machine.\n",
"model.load_state_dict(torch.load('storage/experiments/Exchange/96Mplus/repeat=0/model.pth',\n",
"                                 map_location=default_device()))\n",
"model = model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6e8ca1f",
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"# Take the training sample at index 1 and prepend a batch axis to each of its fields.\n",
"sample = train_set[1]\n",
"batched = [field[None, :] for field in sample]\n",
"x, y, x_time, y_time = (to_tensor(field) for field in batched)\n",
"# Inference only, so gradient tracking is switched off.\n",
"with torch.no_grad():\n",
"    forecast = model(x, x_time, y_time)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "638da5c2",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "66d7f095",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.colors as mcolors\n",
"\n",
"# Plot every series of the sample: past input and true future as solid lines,\n",
"# the model forecast as a dashed line in the same colour.\n",
"plt.title('inception inr')\n",
"# TABLEAU_COLORS holds ten visible colours; BASE_COLORS has only eight and its\n",
"# last entry is white, which cannot be seen on a white background.\n",
"colors = list(mcolors.TABLEAU_COLORS)\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"l = x.shape[1]   # number of past steps plotted\n",
"l2 = y.shape[1]  # number of future steps; need not equal the number of past steps\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l + l2))\n",
"for i in range(x.shape[-1]):\n",
"    # Wrap around the palette so more series than colours cannot raise IndexError.\n",
"    c = colors[i % len(colors)]\n",
"    plt.plot(i_past, x2[:, i], c=c)\n",
"    plt.plot(i_future, y2[:, i], c=c)\n",
"    plt.plot(i_future, forecast2[:, i], c=c, linestyle='--')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47bf8d52",
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"import matplotlib.colors as mcolors\n",
"\n",
"# Load the gin bindings of the 96M run; the context manager closes the file.\n",
"gin.clear_config()\n",
"with open(\"storage/experiments/Exchange/96M/repeat=0/config.gin\") as config_file:\n",
"    gin.parse_config(config_file)\n",
"\n",
"train_set, train_loader = get_data(flag='train', batch_size=16)\n",
"\n",
"model = get_model(\"deeptime\",\n",
"                  dim_size=train_set.data_x.shape[1],\n",
"                  datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
"# map_location remaps the stored tensors onto the current device, so a checkpoint\n",
"# saved on a GPU still loads on a CPU-only machine.\n",
"model.load_state_dict(torch.load('storage/experiments/Exchange/96M/repeat=0/model.pth',\n",
"                                 map_location=default_device()))\n",
"model = model.eval()\n",
"\n",
"\n",
"# Run the model on the training sample at index 1, with a batch axis prepended.\n",
"b = train_set[1]\n",
"b = [bb[None, :] for bb in b]\n",
"x, y, x_time, y_time = map(to_tensor, b)\n",
"with torch.no_grad():\n",
"    forecast = model(x, x_time, y_time)\n",
"\n",
"\n",
"# Plot every series: past input and true future solid, forecast dashed.\n",
"plt.title('mlp inr')\n",
"# TABLEAU_COLORS holds ten visible colours; BASE_COLORS has only eight and its\n",
"# last entry is white, which cannot be seen on a white background.\n",
"colors = list(mcolors.TABLEAU_COLORS)\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"l = x.shape[1]   # number of past steps plotted\n",
"l2 = y.shape[1]  # number of future steps; need not equal the number of past steps\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l + l2))\n",
"for i in range(x.shape[-1]):\n",
"    # Wrap around the palette so more series than colours cannot raise IndexError.\n",
"    c = colors[i % len(colors)]\n",
"    plt.plot(i_past, x2[:, i], c=c)\n",
"    plt.plot(i_future, y2[:, i], c=c)\n",
"    plt.plot(i_future, forecast2[:, i], c=c, linestyle='--')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28eca411",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.colors as mcolors\n",
"\n",
"# Load the gin bindings of the 96Mplus2 run; the context manager closes the file.\n",
"gin.clear_config()\n",
"with open(\"storage/experiments/Exchange/96Mplus2/repeat=0/config.gin\") as config_file:\n",
"    gin.parse_config(config_file)\n",
"\n",
"train_set, train_loader = get_data(flag='train', batch_size=16)\n",
"\n",
"model = get_model(\"deeptime2\",\n",
"                  dim_size=train_set.data_x.shape[1],\n",
"                  datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
"# map_location remaps the stored tensors onto the current device, so a checkpoint\n",
"# saved on a GPU still loads on a CPU-only machine.\n",
"model.load_state_dict(torch.load('storage/experiments/Exchange/96Mplus2/repeat=0/model.pth',\n",
"                                 map_location=default_device()))\n",
"model = model.eval()\n",
"\n",
"\n",
"# Run the model on the training sample at index 1, with a batch axis prepended.\n",
"b = train_set[1]\n",
"b = [bb[None, :] for bb in b]\n",
"x, y, x_time, y_time = map(to_tensor, b)\n",
"with torch.no_grad():\n",
"    forecast = model(x, x_time, y_time)\n",
"\n",
"\n",
"# Plot every series: past input and true future solid, forecast dashed.\n",
"# NOTE(review): title was 'inception inr', identical to the 96Mplus plot; renamed to\n",
"# follow the 'mlp inr' / 'mlp inr2' pattern used for the 96M / 96M2 runs - confirm.\n",
"plt.title('inception inr2')\n",
"# TABLEAU_COLORS holds ten visible colours; BASE_COLORS has only eight and its\n",
"# last entry is white, which cannot be seen on a white background.\n",
"colors = list(mcolors.TABLEAU_COLORS)\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"l = x.shape[1]   # number of past steps plotted\n",
"l2 = y.shape[1]  # number of future steps; need not equal the number of past steps\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l + l2))\n",
"for i in range(x.shape[-1]):\n",
"    # Wrap around the palette so more series than colours cannot raise IndexError.\n",
"    c = colors[i % len(colors)]\n",
"    plt.plot(i_past, x2[:, i], c=c)\n",
"    plt.plot(i_future, y2[:, i], c=c)\n",
"    plt.plot(i_future, forecast2[:, i], c=c, linestyle='--')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ca906be",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.colors as mcolors\n",
"\n",
"# Load the gin bindings of the 96M2 run; the context manager closes the file.\n",
"gin.clear_config()\n",
"with open(\"storage/experiments/Exchange/96M2/repeat=0/config.gin\") as config_file:\n",
"    gin.parse_config(config_file)\n",
"\n",
"train_set, train_loader = get_data(flag='train', batch_size=16)\n",
"\n",
"model = get_model(\"deeptime\",\n",
"                  dim_size=train_set.data_x.shape[1],\n",
"                  datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
"# map_location remaps the stored tensors onto the current device, so a checkpoint\n",
"# saved on a GPU still loads on a CPU-only machine.\n",
"model.load_state_dict(torch.load('storage/experiments/Exchange/96M2/repeat=0/model.pth',\n",
"                                 map_location=default_device()))\n",
"model = model.eval()\n",
"\n",
"\n",
"# Run the model on the training sample at index 1, with a batch axis prepended.\n",
"b = train_set[1]\n",
"b = [bb[None, :] for bb in b]\n",
"x, y, x_time, y_time = map(to_tensor, b)\n",
"with torch.no_grad():\n",
"    forecast = model(x, x_time, y_time)\n",
"\n",
"\n",
"# Plot every series: past input and true future solid, forecast dashed.\n",
"plt.title('mlp inr2')\n",
"# TABLEAU_COLORS holds ten visible colours; BASE_COLORS has only eight and its\n",
"# last entry is white, which cannot be seen on a white background.\n",
"colors = list(mcolors.TABLEAU_COLORS)\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"l = x.shape[1]   # number of past steps plotted\n",
"l2 = y.shape[1]  # number of future steps; need not equal the number of past steps\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l + l2))\n",
"for i in range(x.shape[-1]):\n",
"    # Wrap around the palette so more series than colours cannot raise IndexError.\n",
"    c = colors[i % len(colors)]\n",
"    plt.plot(i_past, x2[:, i], c=c)\n",
"    plt.plot(i_future, y2[:, i], c=c)\n",
"    plt.plot(i_future, forecast2[:, i], c=c, linestyle='--')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f8c075b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"formats": "auto:percent,ipynb",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}