mirror of
https://github.com/wassname/DeepTime.git
synced 2026-07-04 22:43:25 +08:00
775 lines
24 KiB
Plaintext
775 lines
24 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b1e031e3",
|
|
"metadata": {},
|
|
"source": [
|
|
"- [x] try just one predictor\n",
|
|
" - [ ] multi input, single output\n",
|
|
"- [x] compare multi\n",
|
|
"- losses:\n",
|
|
" - try logp? nah\n",
|
|
" - mae?\n",
|
|
"- [x] make my own csv with 5m data (maybe 10k rows)\n",
|
|
"- [ ] backtest?"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "7f9e3d73",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:32:18.927915Z",
|
|
"start_time": "2022-11-23T03:32:18.918871Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import warnings\n",
|
|
"warnings.simplefilter(\"ignore\")\n",
|
|
"\n",
|
|
"# autoreload import your package\n",
|
|
"%load_ext autoreload\n",
|
|
"%autoreload 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "4e09086b",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:32:20.144762Z",
|
|
"start_time": "2022-11-23T03:32:18.928928Z"
|
|
},
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"from os.path import join\n",
|
|
"import math\n",
|
|
"import logging\n",
|
|
"from typing import Callable, Optional, Union, Dict, Tuple\n",
|
|
"\n",
|
|
"from matplotlib import pyplot as plt\n",
|
|
"from pathlib import Path\n",
|
|
"import matplotlib.colors as mcolors\n",
|
|
"\n",
|
|
"import gin\n",
|
|
"from fire import Fire\n",
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"from torch.utils.data import DataLoader\n",
|
|
"from torch import optim\n",
|
|
"from torch import nn\n",
|
|
"\n",
|
|
"from experiments.base import Experiment\n",
|
|
"from data.datasets import ForecastDataset\n",
|
|
"from models import get_model\n",
|
|
"from utils.checkpoint import Checkpoint\n",
|
|
"from utils.ops import default_device, to_tensor\n",
|
|
"from utils.losses import get_loss_fn\n",
|
|
"from utils.metrics import calc_metrics\n",
|
|
"\n",
|
|
"from experiments.forecast import get_data\n",
|
|
"gin.enter_interactive_mode()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "66d7f095",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:32:20.171177Z",
|
|
"start_time": "2022-11-23T03:32:20.146544Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"1"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import logging\n",
"import sys\n",
"\n",
"from loguru import logger\n",
"\n",
"# stdlib logging and loguru both report at INFO\n",
"logging.root.setLevel(logging.INFO)\n",
"logger.remove()  # remove all existing loguru sinks before adding ours\n",
"# sys.stdout rather than os.sys.stdout: `os.sys` is not a documented attribute of the os module\n",
"LOG_HANDLER_ID = logger.add(sys.stdout, level=\"INFO\", colorize=True, format=\"<level>{time} | {message}</level>\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d4df5270",
|
|
"metadata": {},
|
|
"source": [
|
|
"# auto"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "04499bef",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:32:20.191130Z",
|
|
"start_time": "2022-11-23T03:32:20.172479Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"\n",
|
|
"def plot(model_name=\"deeptime\", save_path=Path(\"storage/experiments/Exchange/96M/repeat=0\"), i=200, title=None, plot=True):\n",
"    \"\"\"Load one trained model from `save_path` and plot its forecast for training sample `i`.\n",
"\n",
"    Args:\n",
"        model_name: model type passed to `get_model`.\n",
"        save_path: experiment directory containing `config.gin` and `model.pth`.\n",
"        i: index of the training-set sample to forecast.\n",
"        title: plot title; defaults to the last three components of `save_path` joined with \"-\".\n",
"        plot: if False, only compute and return the arrays.\n",
"\n",
"    Returns:\n",
"        (context_y, query_y, forecast, past positions, future positions) for the single\n",
"        item of the batch.\n",
"    \"\"\"\n",
"    save_path = Path(save_path)\n",
"    gin.clear_config()\n",
"    with open(save_path / \"config.gin\") as f:  # close the config file instead of leaking the handle\n",
"        gin.parse_config(f)\n",
"\n",
"    train_set, _ = get_data(flag='train', batch_size=2)\n",
"\n",
"    # NOTE(review): plot_multi also passes `seq_len` to get_model; confirm whether this call needs it too.\n",
"    model = get_model(model_name,\n",
"                      dim_size=train_set.data_x.shape[1],\n",
"                      datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
"    # map_location lets weights saved on another device load on this one\n",
"    model.load_state_dict(torch.load(save_path / 'model.pth', map_location=default_device()))\n",
"    model = model.eval()\n",
"\n",
"    # add a leading batch dimension of size 1 to each element of the sample\n",
"    batch = [to_tensor(t[None, :]) for t in train_set[i]]\n",
"    context_past_x, context_y, query_past_x, query_y, context_time, query_time = batch\n",
"    with torch.no_grad():\n",
"        forecast = model(*batch)\n",
"\n",
"    if title is None:\n",
"        title = \"-\".join(save_path.parts[-3:])\n",
"\n",
"    # use the context/query targets and time lengths, as plot_multi does\n",
"    # (the previous code read the undefined names `x` and `y`)\n",
"    past_len = context_time.shape[1]\n",
"    future_len = query_time.shape[1]\n",
"    x2 = context_y[0].cpu()\n",
"    y2 = query_y[0].cpu()\n",
"    forecast2 = forecast[0].detach().cpu().numpy()\n",
"    i_past = list(range(past_len))\n",
"    i_future = list(range(past_len, past_len + future_len))\n",
"\n",
"    if plot:\n",
"        colors = list(mcolors.BASE_COLORS.keys())\n",
"        plt.title(title)\n",
"        # `ch` instead of `i` so the sample-index argument is not shadowed;\n",
"        # colours wrap around when there are more channels than BASE_COLORS entries\n",
"        for ch in range(context_y.shape[-1]):\n",
"            color = colors[ch % len(colors)]\n",
"            plt.plot(i_past, x2[:, ch], c=color)\n",
"            plt.plot(i_future, y2[:, ch], c=color)\n",
"            plt.plot(i_future, forecast2[:, ch], c=color, linestyle='--')\n",
"    return x2, y2, forecast2, i_past, i_future\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "09dd5ebd",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-22T13:28:35.849491Z",
|
|
"start_time": "2022-11-22T13:28:35.766453Z"
|
|
},
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "dc530891",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:32:20.213967Z",
|
|
"start_time": "2022-11-23T03:32:20.192307Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"list(mcolors.BASE_COLORS.keys())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "da8f502f",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-22T13:38:15.786656Z",
|
|
"start_time": "2022-11-22T13:38:15.303577Z"
|
|
}
|
|
},
|
|
"source": [
|
|
"# view model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ccc9fdb0",
|
|
"metadata": {},
|
|
"source": [
|
|
"# run exps"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e09f2823",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:31:08.583641Z",
|
|
"start_time": "2022-11-23T03:31:08.558690Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "9d13d779",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:32:20.232318Z",
|
|
"start_time": "2022-11-23T03:32:20.215075Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=transformer2,repeat=0/config.gin'), 
Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=inception,repeat=0/config.gin'), 
Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=lstm,repeat=0/config.gin')]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import random\n",
"\n",
"# every experiment under storage/experiments/Stocks that has a config.gin\n",
"configs = list(Path(\"storage/experiments/Stocks\").glob(\"**/config.gin\"))\n",
"configs.sort()\n",
"# run them in a random (unseeded) order\n",
"random.shuffle(configs)\n",
"print(configs)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "4178d85e",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:32:20.248919Z",
|
|
"start_time": "2022-11-23T03:32:20.233964Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from experiments.forecast import ForecastExperiment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8eadaa48",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2022-11-23T04:00:02.202Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=mlp,repeat=0/config.gin\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"INFO:root:epochs: 1, iters: 100 | training loss: 2.41\n",
|
|
"INFO:root:Validation loss decreased (inf --> 1.004). Saving model ...\n",
|
|
"INFO:root:epochs: 2, iters: 100 | training loss: 0.20\n",
|
|
"INFO:root:Validation loss decreased (1.004 --> 0.115). Saving model ...\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for config in configs:\n",
"    save_path = config.parent\n",
"    print(config)\n",
"    try:\n",
"        # construct inside the try as well, so one bad config does not abort the whole sweep\n",
"        exp = ForecastExperiment(config_path=config)\n",
"        exp.run()\n",
"    except Exception:\n",
"        # best-effort sweep: log the full traceback and continue with the next config.\n",
"        # KeyboardInterrupt is not an Exception subclass, so Ctrl-C still stops the loop.\n",
"        logger.exception(\"experiment failed: {}\", config)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e7a3c151",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.312188Z",
|
|
"start_time": "2022-11-23T03:36:04.312180Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"exp.instance()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8b73c64e",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.313334Z",
|
|
"start_time": "2022-11-23T03:36:04.313322Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# save_path = Path('storage/experiments/Stocks/96M2S/repeat=0')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6af1bb1e",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.313962Z",
|
|
"start_time": "2022-11-23T03:36:04.313954Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# gin.clear_config()\n",
|
|
"# config_path = save_path/\"config.gin\"\n",
|
|
"# gin.parse_config(open(config_path))\n",
|
|
"# model_name = gin.query_parameter(\"instance.model_type\")\n",
|
|
"# model_name"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "50a92a7e",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.314594Z",
|
|
"start_time": "2022-11-23T03:36:04.314588Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"# exp = ForecastExperiment(config_path=config_path)\n",
|
|
"# # exp.run()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9ef76b76",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.315217Z",
|
|
"start_time": "2022-11-23T03:36:04.315210Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def save_path2name(save_path: Path) -> str:\n",
"    \"\"\"Turn an experiment directory into a short display name.\n",
"\n",
"    Keeps the last two path components and, in the final (comma-separated\n",
"    `key=value`) component, keeps only the values joined by \"_\":\n",
"\n",
"        Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=mlp,repeat=0')\n",
"        -> '96M2S-None_INR_mlp_0'\n",
"    \"\"\"\n",
"    # Path.parts is separator-agnostic; splitting str(save_path) on '/' failed for Windows paths\n",
"    parts = list(Path(save_path).parts[-2:])\n",
"    values = [tag.split('=')[-1] for tag in parts[-1].split(',')]\n",
"    parts[-1] = '_'.join(values)\n",
"    return \"-\".join(parts)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3637e87d",
|
|
"metadata": {},
|
|
"source": [
|
|
"# view all"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "768530be",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.315891Z",
|
|
"start_time": "2022-11-23T03:36:04.315884Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from torchsummaryX import summary\n",
"\n",
"\n",
"def plot_multi(save_paths=None, i=200, title=None, plot=True, verbose=1,):\n",
"    \"\"\"Overlay the forecasts of several trained experiments on one test-set sample.\n",
"\n",
"    For each experiment directory the gin config is re-parsed, the test split is rebuilt\n",
"    with `get_data`, and `model.pth` is loaded into a fresh model. Only channel 0 is\n",
"    plotted; the past and the true future are drawn from the first experiment.\n",
"\n",
"    Args:\n",
"        save_paths: experiment directories, each containing `config.gin` and `model.pth`.\n",
"            Defaults to `[Path(\"storage/experiments/Exchange/96M/repeat=0\")]`.\n",
"        i: index of the test-set sample to forecast.\n",
"        title: plot title.\n",
"        plot: if False, only compute and return the arrays.\n",
"        verbose: if > 1, print a torchsummaryX summary of each model and its path.\n",
"\n",
"    Returns:\n",
"        (context_y, query_y, forecast, past positions, future positions) for the\n",
"        LAST experiment in `save_paths`.\n",
"    \"\"\"\n",
"    if save_paths is None:\n",
"        # None sentinel instead of a mutable list default; same default directory as before\n",
"        save_paths = [Path(\"storage/experiments/Exchange/96M/repeat=0\")]\n",
"    assert len(save_paths) > 0, \"need at least one experiment directory\"\n",
"    colors = list(mcolors.BASE_COLORS.keys())\n",
"\n",
"    for j, save_path in enumerate(save_paths):\n",
"        save_path = Path(save_path)\n",
"        gin.clear_config()\n",
"        with open(save_path / \"config.gin\") as f:  # close the config file instead of leaking the handle\n",
"            gin.parse_config(f)\n",
"        model_name = gin.query_parameter(\"instance.model_type\")\n",
"\n",
"        test_set, _ = get_data(flag='test', batch_size=3)\n",
"        # NOTE(review): this is the whole shape of the sample's second element (context_y), not a\n",
"        # scalar length - confirm that is what get_model expects for `seq_len`.\n",
"        seq_len = test_set[0][1].shape\n",
"        model = get_model(model_name,\n",
"                          dim_size=test_set.data_x.shape[1],\n",
"                          seq_len=seq_len,\n",
"                          datetime_feats=test_set.timestamps.shape[-1]).to(default_device())\n",
"        # map_location lets weights saved on another device load on this one\n",
"        model.load_state_dict(torch.load(save_path / 'model.pth', map_location=default_device()))\n",
"        model = model.eval()\n",
"\n",
"        # add a leading batch dimension of size 1 to each element of the sample\n",
"        batch = [to_tensor(t[None, :]) for t in test_set[i]]\n",
"\n",
"        if verbose > 1:\n",
"            summary(model, *batch)\n",
"            print(save_path)\n",
"\n",
"        context_past_x, context_y, query_past_x, query_y, context_time, query_time = batch\n",
"        with torch.no_grad():\n",
"            forecast = model(*batch)\n",
"\n",
"        past_len = context_time.shape[1]\n",
"        future_len = query_time.shape[1]\n",
"        x2 = context_y[0].cpu()\n",
"        y2 = query_y[0].cpu()\n",
"        forecast2 = forecast[0].detach().cpu().numpy()\n",
"        i_past = list(range(past_len))\n",
"        i_future = list(range(past_len, past_len + future_len))\n",
"\n",
"        if plot:\n",
"            if j == 0:\n",
"                plt.plot(i_past, x2[:, 0], c=colors[0], label=\"past\")\n",
"                plt.plot(i_future, y2[:, 0], c=colors[0], label=\"future true\", alpha=0.3)\n",
"            plt.plot(i_future, forecast2[:, 0], linestyle='--', label=save_path2name(save_path))\n",
"\n",
"    if plot:\n",
"        plt.legend()\n",
"        plt.title(title)\n",
"    return x2, y2, forecast2, i_past, i_future\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "739ee5e3",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.316469Z",
|
|
"start_time": "2022-11-23T03:36:04.316463Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# experiment directories that contain a _SUCCESS marker file\n",
"success_markers = Path(\"storage/experiments/Stocks/96M2S\").glob(\"**/_SUCCESS\")\n",
"m = sorted(success_markers)\n",
"print(m)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "48e9175b",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.316979Z",
|
|
"start_time": "2022-11-23T03:36:04.316973Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"for success_marker in m:\n",
"    run_dir = success_marker.parent\n",
"    print(save_path2name(run_dir))\n",
"    # metrics.npy holds a pickled object (hence allow_pickle) that is indexed as a nested dict below;\n",
"    # np.load returns such an object wrapped in a 0-d object array, and .item() unwraps it.\n",
"    # This replaces the eval(str(...)) round-trip, which fails whenever a value's repr is not valid Python.\n",
"    # allow_pickle can execute code while loading, so only load run outputs you trust.\n",
"    metrics = np.load(run_dir / 'metrics.npy', allow_pickle=True).item()\n",
"    print(metrics['val']['mape'])\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4b966ef7",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.317585Z",
|
|
"start_time": "2022-11-23T03:36:04.317578Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_set, train_loader = get_data(flag='train')\n",
|
|
"train_set[0][1].shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0b696e85",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.318277Z",
|
|
"start_time": "2022-11-23T03:36:04.318270Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# one figure per finished experiment\n",
"for success_marker in m:\n",
"    try:\n",
"        plt.figure()  # start a new figure; otherwise every call draws onto the same axes\n",
"        plot_multi(\n",
"            save_paths=[success_marker.parent],\n",
"            i=600,\n",
"            verbose=2,\n",
"        )\n",
"        plt.show()\n",
"    except Exception as e:\n",
"        # keep going, but report why; a bare `except:` would also swallow KeyboardInterrupt\n",
"        plt.close()  # discard the partially drawn figure\n",
"        print('failed', success_marker, repr(e))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ac9e5759",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.319520Z",
|
|
"start_time": "2022-11-23T03:36:04.319512Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# overlay every finished experiment's forecast for the same test sample;\n",
"# the trailing semicolon suppresses the returned tuple (instead of ending the cell with a dummy `1`)\n",
"save_paths = [success_marker.parent for success_marker in m]\n",
"plot_multi(\n",
"    save_paths=save_paths,\n",
"    i=200,\n",
"    verbose=0,\n",
");"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b2e27d42",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.320031Z",
|
|
"start_time": "2022-11-23T03:36:04.320024Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"256/24"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "938f06db",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-22T13:17:31.585029Z",
|
|
"start_time": "2022-11-22T13:17:31.528806Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9b4e817c",
|
|
"metadata": {},
|
|
"source": [
|
|
"# check positions in dl"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "df6c9b28",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.320752Z",
|
|
"start_time": "2022-11-23T03:36:04.320745Z"
|
|
},
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_set, train_loader = get_data(flag='test', batch_size=3)\n",
"b = train_set[100]\n",
"context_past_x, context_y, query_past_x, query_y, context_time, query_time = b\n",
"# shape of each of the six elements of one sample\n",
"print([t.shape for t in b])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "fa54042e",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.321671Z",
|
|
"start_time": "2022-11-23T03:36:04.321662Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"(cx_start, cx_end, c_start, c_end,\n",
" qx_start, qx_end, q_start, q_end) = train_set.get_inds(100)\n",
"# display the eight window boundaries for sample 100\n",
"(cx_start, cx_end, c_start, c_end, qx_start, qx_end, q_start, q_end)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2482f0f7",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-23T03:36:04.322543Z",
|
|
"start_time": "2022-11-23T03:36:04.322535Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# one horizontal bar per index window of the sample, stacked at y = 1..4\n",
"windows = [\n",
"    (cx_start, cx_end, dict(color='green', alpha=0.5, label='context_past_x')),\n",
"    (c_start, c_end, dict(color='green', label='context_labels')),\n",
"    (qx_start, qx_end, dict(alpha=0.5, label='query_past_x')),\n",
"    (q_start, q_end, dict(label='query_labels/target')),\n",
"]\n",
"ax = plt.gca()\n",
"for row, (start, end, kwargs) in enumerate(windows, start=1):\n",
"    ax.hlines(row, start, end, **kwargs)\n",
"ax.legend(loc='upper left')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e8fe355e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "3a978978",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "045d3fd5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"cell_metadata_filter": "-all",
|
|
"main_language": "python",
|
|
"notebook_metadata_filter": "-all"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "deeptime",
|
|
"language": "python",
|
|
"name": "deeptime"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.13"
|
|
},
|
|
"toc": {
|
|
"base_numbering": 1,
|
|
"nav_menu": {},
|
|
"number_sections": true,
|
|
"sideBar": true,
|
|
"skip_h1_title": false,
|
|
"title_cell": "Table of Contents",
|
|
"title_sidebar": "Contents",
|
|
"toc_cell": false,
|
|
"toc_position": {},
|
|
"toc_section_display": true,
|
|
"toc_window_display": false
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|