mirror of
https://github.com/wassname/DeepTime.git
synced 2026-06-27 15:13:37 +08:00
333 lines
9.4 KiB
Plaintext
333 lines
9.4 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4e09086b",
|
|
"metadata": {
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"import os\n",
|
|
"from os.path import join\n",
|
|
"import math\n",
|
|
"import logging\n",
|
|
"from typing import Callable, Optional, Union, Dict, Tuple\n",
|
|
"\n",
|
|
"from matplotlib import pyplot as plt\n",
|
|
"\n",
|
|
"import gin\n",
|
|
"from fire import Fire\n",
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"from torch.utils.data import DataLoader\n",
|
|
"from torch import optim\n",
|
|
"from torch import nn\n",
|
|
"\n",
|
|
"from experiments.base import Experiment\n",
|
|
"from data.datasets import ForecastDataset\n",
|
|
"from models import get_model\n",
|
|
"from utils.checkpoint import Checkpoint\n",
|
|
"from utils.ops import default_device, to_tensor\n",
|
|
"from utils.losses import get_loss_fn\n",
|
|
"from utils.metrics import calc_metrics\n",
|
|
"\n",
|
|
"from experiments.forecast import get_data\n",
|
|
"gin.enter_interactive_mode()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "516c4869",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Start from an empty gin configuration, then load the bindings of the run to inspect.\n",
"gin.clear_config()\n",
"# Alternative run (disabled):\n",
"# gin.parse_config(open(\"storage/experiments/Exchange/96M/repeat=0/config.gin\"))\n",
"# Parse inside a context manager so the config file handle is closed as soon as gin has read it.\n",
"with open(\"storage/experiments/Exchange/96Mplus/repeat=0/config.gin\") as config_file:\n",
"    gin.parse_config(config_file)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "3e056c6f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Load the training split; later cells read `train_set` for model sizing and for sample windows.\n",
"train_set, train_loader = get_data(\n",
"    batch_size=16,\n",
"    flag='train',\n",
")\n",
"# Disabled scratch code: pull the first array of one window for a quick look.\n",
"# x, _, _, _ =train_set[0]\n",
"# x = x * 1.0"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "07f093f2",
|
|
"metadata": {
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# x -= x[0]\n",
|
|
"# x /= x.std()\n",
|
|
"# plt.plot(x)\n",
|
|
"# # %%"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "7aeee9df",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Build the \"deeptime2\" model with sizes read from the loaded dataset, then restore its weights.\n",
"device = default_device()\n",
"model = get_model(\"deeptime2\",\n",
"                  dim_size=train_set.data_x.shape[1],\n",
"                  datetime_feats=train_set.timestamps.shape[-1]).to(device)\n",
"# map_location remaps the stored tensors onto the active device, so a checkpoint that was\n",
"# saved on a GPU still loads on a CPU-only machine.\n",
"state_dict = torch.load('storage/experiments/Exchange/96Mplus/repeat=0/model.pth',\n",
"                        map_location=device)\n",
"model.load_state_dict(state_dict)\n",
"# Put the module in evaluation mode before running inference.\n",
"model = model.eval()"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b6e8ca1f",
|
|
"metadata": {
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
"# Take one training window and give each of its arrays a leading batch axis.\n",
"sample = train_set[1]\n",
"batched = [field[None, :] for field in sample]\n",
"x, y, x_time, y_time = (to_tensor(field) for field in batched)\n",
"# Forward pass only: gradient tracking is switched off.\n",
"with torch.no_grad():\n",
"    forecast = model(x, x_time, y_time)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "638da5c2",
|
|
"metadata": {
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "66d7f095",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"import matplotlib.colors as mcolors\n",
"\n",
"# Plot, for every series of the sample: history (solid), ground truth (solid), forecast (dashed).\n",
"plt.title('inception inr')\n",
"# Tableau palette instead of BASE_COLORS: the base set has only eight entries and one of them\n",
"# is white, which does not show on white axes. Indexing is modulo the palette length so any\n",
"# number of series can be drawn without an IndexError.\n",
"colors = list(mcolors.TABLEAU_COLORS)\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"l = x.shape[1]   # number of history steps\n",
"l2 = y.shape[1]  # number of future steps; not assumed to equal the history length\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l + l2))\n",
"n_series = x.shape[-1]\n",
"# Three separate passes keep the original draw order: history, then truth, then forecasts on top.\n",
"for i in range(n_series):\n",
"    plt.plot(i_past, x2[:, i], c=colors[i % len(colors)])\n",
"for i in range(n_series):\n",
"    plt.plot(i_future, y2[:, i], c=colors[i % len(colors)])\n",
"for i in range(n_series):\n",
"    plt.plot(i_future, forecast2[:, i], c=colors[i % len(colors)], linestyle='--')"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "47bf8d52",
|
|
"metadata": {
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
"import matplotlib.colors as mcolors\n",
"\n",
"# Config and checkpoint are read from the same experiment directory, named once here.\n",
"experiment_dir = 'storage/experiments/Exchange/96M/repeat=0'\n",
"\n",
"gin.clear_config()\n",
"# Context manager closes the config file once gin has parsed it.\n",
"with open(f'{experiment_dir}/config.gin') as config_file:\n",
"    gin.parse_config(config_file)\n",
"\n",
"train_set, train_loader = get_data(flag='train', batch_size=16)\n",
"\n",
"device = default_device()\n",
"model = get_model(\"deeptime\",\n",
"                  dim_size=train_set.data_x.shape[1],\n",
"                  datetime_feats=train_set.timestamps.shape[-1]).to(device)\n",
"# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.\n",
"model.load_state_dict(torch.load(f'{experiment_dir}/model.pth', map_location=device))\n",
"model = model.eval()\n",
"\n",
"# One training window, with a leading batch axis added to each of its arrays.\n",
"sample = train_set[1]\n",
"batched = [field[None, :] for field in sample]\n",
"x, y, x_time, y_time = map(to_tensor, batched)\n",
"with torch.no_grad():\n",
"    forecast = model(x, x_time, y_time)\n",
"\n",
"# Plot history (solid), ground truth (solid) and forecast (dashed) for every series.\n",
"plt.title('mlp inr')\n",
"# Tableau palette instead of BASE_COLORS: the base set has only eight entries and one of them\n",
"# is white, which does not show on white axes. Indexing is modulo the palette length.\n",
"colors = list(mcolors.TABLEAU_COLORS)\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"l = x.shape[1]   # number of history steps\n",
"l2 = y.shape[1]  # number of future steps; not assumed to equal the history length\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l + l2))\n",
"n_series = x.shape[-1]\n",
"# Three separate passes keep the original draw order: history, then truth, then forecasts on top.\n",
"for i in range(n_series):\n",
"    plt.plot(i_past, x2[:, i], c=colors[i % len(colors)])\n",
"for i in range(n_series):\n",
"    plt.plot(i_future, y2[:, i], c=colors[i % len(colors)])\n",
"for i in range(n_series):\n",
"    plt.plot(i_future, forecast2[:, i], c=colors[i % len(colors)], linestyle='--')"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "28eca411",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"import matplotlib.colors as mcolors\n",
"\n",
"# Config and checkpoint are read from the same experiment directory, named once here.\n",
"experiment_dir = 'storage/experiments/Exchange/96Mplus2/repeat=0'\n",
"\n",
"gin.clear_config()\n",
"# Context manager closes the config file once gin has parsed it.\n",
"with open(f'{experiment_dir}/config.gin') as config_file:\n",
"    gin.parse_config(config_file)\n",
"\n",
"train_set, train_loader = get_data(flag='train', batch_size=16)\n",
"\n",
"device = default_device()\n",
"model = get_model(\"deeptime2\",\n",
"                  dim_size=train_set.data_x.shape[1],\n",
"                  datetime_feats=train_set.timestamps.shape[-1]).to(device)\n",
"# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.\n",
"model.load_state_dict(torch.load(f'{experiment_dir}/model.pth', map_location=device))\n",
"model = model.eval()\n",
"\n",
"# One training window, with a leading batch axis added to each of its arrays.\n",
"sample = train_set[1]\n",
"batched = [field[None, :] for field in sample]\n",
"x, y, x_time, y_time = map(to_tensor, batched)\n",
"with torch.no_grad():\n",
"    forecast = model(x, x_time, y_time)\n",
"\n",
"# Plot history (solid), ground truth (solid) and forecast (dashed) for every series.\n",
"plt.title('inception inr')\n",
"# Tableau palette instead of BASE_COLORS: the base set has only eight entries and one of them\n",
"# is white, which does not show on white axes. Indexing is modulo the palette length.\n",
"colors = list(mcolors.TABLEAU_COLORS)\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"l = x.shape[1]   # number of history steps\n",
"l2 = y.shape[1]  # number of future steps\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l + l2))\n",
"n_series = x.shape[-1]\n",
"# Three separate passes keep the original draw order: history, then truth, then forecasts on top.\n",
"for i in range(n_series):\n",
"    plt.plot(i_past, x2[:, i], c=colors[i % len(colors)])\n",
"for i in range(n_series):\n",
"    plt.plot(i_future, y2[:, i], c=colors[i % len(colors)])\n",
"for i in range(n_series):\n",
"    plt.plot(i_future, forecast2[:, i], c=colors[i % len(colors)], linestyle='--')"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6ca906be",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"import matplotlib.colors as mcolors\n",
"\n",
"# Config and checkpoint are read from the same experiment directory, named once here.\n",
"experiment_dir = 'storage/experiments/Exchange/96M2/repeat=0'\n",
"\n",
"gin.clear_config()\n",
"# Context manager closes the config file once gin has parsed it.\n",
"with open(f'{experiment_dir}/config.gin') as config_file:\n",
"    gin.parse_config(config_file)\n",
"\n",
"train_set, train_loader = get_data(flag='train', batch_size=16)\n",
"\n",
"device = default_device()\n",
"model = get_model(\"deeptime\",\n",
"                  dim_size=train_set.data_x.shape[1],\n",
"                  datetime_feats=train_set.timestamps.shape[-1]).to(device)\n",
"# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.\n",
"model.load_state_dict(torch.load(f'{experiment_dir}/model.pth', map_location=device))\n",
"model = model.eval()\n",
"\n",
"# One training window, with a leading batch axis added to each of its arrays.\n",
"sample = train_set[1]\n",
"batched = [field[None, :] for field in sample]\n",
"x, y, x_time, y_time = map(to_tensor, batched)\n",
"with torch.no_grad():\n",
"    forecast = model(x, x_time, y_time)\n",
"\n",
"# Plot history (solid), ground truth (solid) and forecast (dashed) for every series.\n",
"plt.title('mlp inr2')\n",
"# Tableau palette instead of BASE_COLORS: the base set has only eight entries and one of them\n",
"# is white, which does not show on white axes. Indexing is modulo the palette length.\n",
"colors = list(mcolors.TABLEAU_COLORS)\n",
"forecast2 = forecast[0].detach().cpu().numpy()\n",
"x2 = x[0].cpu()\n",
"y2 = y[0].cpu()\n",
"l = x.shape[1]   # number of history steps\n",
"l2 = y.shape[1]  # number of future steps\n",
"i_past = list(range(l))\n",
"i_future = list(range(l, l + l2))\n",
"n_series = x.shape[-1]\n",
"# Three separate passes keep the original draw order: history, then truth, then forecasts on top.\n",
"for i in range(n_series):\n",
"    plt.plot(i_past, x2[:, i], c=colors[i % len(colors)])\n",
"for i in range(n_series):\n",
"    plt.plot(i_future, y2[:, i], c=colors[i % len(colors)])\n",
"for i in range(n_series):\n",
"    plt.plot(i_future, forecast2[:, i], c=colors[i % len(colors)], linestyle='--')"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8f8c075b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"cell_metadata_filter": "-all",
|
|
"formats": "auto:percent,ipynb",
|
|
"main_language": "python",
|
|
"notebook_metadata_filter": "-all"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.9"
|
|
},
|
|
"toc": {
|
|
"base_numbering": 1,
|
|
"nav_menu": {},
|
|
"number_sections": true,
|
|
"sideBar": true,
|
|
"skip_h1_title": false,
|
|
"title_cell": "Table of Contents",
|
|
"title_sidebar": "Contents",
|
|
"toc_cell": false,
|
|
"toc_position": {},
|
|
"toc_section_display": true,
|
|
"toc_window_display": false
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|