Files
DeepTime/scratch-run_exp.ipynb
T
2022-11-23 12:02:22 +08:00

775 lines
24 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "b1e031e3",
"metadata": {},
"source": [
"- [x] try just one predictor\n",
" - [ ] multi input, single output\n",
"- [x] comparem ulti\n",
"- losses:\n",
" - try logp? nah\n",
" - mae?\n",
"- [x] make my own csv with 5m data (maybe 10k rows)\n",
"- [ ] backtest?"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7f9e3d73",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:18.927915Z",
"start_time": "2022-11-23T03:32:18.918871Z"
}
},
"outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")\n",
"\n",
"# autoreload import your package\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4e09086b",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.144762Z",
"start_time": "2022-11-23T03:32:18.928928Z"
},
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"import os\n",
"from os.path import join\n",
"import math\n",
"import logging\n",
"from typing import Callable, Optional, Union, Dict, Tuple\n",
"\n",
"from matplotlib import pyplot as plt\n",
"from pathlib import Path\n",
"import matplotlib.colors as mcolors\n",
"\n",
"import gin\n",
"from fire import Fire\n",
"import numpy as np\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from torch import optim\n",
"from torch import nn\n",
"\n",
"from experiments.base import Experiment\n",
"from data.datasets import ForecastDataset\n",
"from models import get_model\n",
"from utils.checkpoint import Checkpoint\n",
"from utils.ops import default_device, to_tensor\n",
"from utils.losses import get_loss_fn\n",
"from utils.metrics import calc_metrics\n",
"\n",
"from experiments.forecast import get_data\n",
"gin.enter_interactive_mode()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "66d7f095",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.171177Z",
"start_time": "2022-11-23T03:32:20.146544Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import logging\n",
"logging.root.setLevel(logging.INFO)\n",
"\n",
"from loguru import logger\n",
"logger.remove()\n",
"logger.add(os.sys.stdout, level=\"INFO\", colorize=True, format=\"<level>{time} | {message}</level>\")"
]
},
{
"cell_type": "markdown",
"id": "d4df5270",
"metadata": {},
"source": [
"# auto"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "04499bef",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.191130Z",
"start_time": "2022-11-23T03:32:20.172479Z"
}
},
"outputs": [],
"source": [
"\n",
"\n",
"def plot(model_name=\"deeptime\", save_path=Path(\"storage/experiments/Exchange/96M/repeat=0\"), i=200, title=None, plot=True):\n",
"\n",
" gin.clear_config()\n",
" gin.parse_config(open(save_path/\"config.gin\"))\n",
"\n",
" train_set, train_loader = get_data(flag='train', batch_size=2)\n",
"\n",
" model = get_model(model_name,\n",
" dim_size=train_set.data_x.shape[1],\n",
" datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
" model.load_state_dict(torch.load(save_path/'model.pth'))\n",
" model = model.eval()\n",
"\n",
"\n",
" b = train_set[i]\n",
" b = [bb[None, :] for bb in b]\n",
" b2 = list(map(to_tensor, b))\n",
" context_past_x, context_y, query_past_x, query_y, context_time, query_time = b2\n",
" with torch.no_grad():\n",
" forecast = model(*b2)\n",
"\n",
" if title is None:\n",
" title = str(save_path).split('/')[-3:]\n",
" title = \"-\".join(title)\n",
" \n",
" colors = list(mcolors.BASE_COLORS.keys())\n",
" l = x.shape[1]\n",
" forecast2 = forecast[0].detach().cpu().numpy()\n",
" x2 = x[0].cpu()\n",
" y2 = y[0].cpu()\n",
" l2 = y.shape[1]\n",
" i_past = list(range(l))\n",
" i_future = list(range(l, l+l2))\n",
" \n",
" if plot:\n",
" plt.title(title)\n",
" for i in range(x.shape[-1]):\n",
" plt.plot(i_past, x2[:, i], c=colors[i])\n",
" for i in range(x.shape[-1]):\n",
" plt.plot(i_future, y2[:, i], c=colors[i])\n",
" for i in range(x.shape[-1]):\n",
" plt.plot(i_future, forecast2[:, i], c=colors[i], linestyle='--')\n",
" return x2, y2, forecast2, i_past, i_future\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09dd5ebd",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T13:28:35.849491Z",
"start_time": "2022-11-22T13:28:35.766453Z"
},
"scrolled": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dc530891",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.213967Z",
"start_time": "2022-11-23T03:32:20.192307Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(mcolors.BASE_COLORS.keys())"
]
},
{
"cell_type": "markdown",
"id": "da8f502f",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T13:38:15.786656Z",
"start_time": "2022-11-22T13:38:15.303577Z"
}
},
"source": [
"# view model"
]
},
{
"cell_type": "markdown",
"id": "ccc9fdb0",
"metadata": {},
"source": [
"# run exps"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e09f2823",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:31:08.583641Z",
"start_time": "2022-11-23T03:31:08.558690Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9d13d779",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.232318Z",
"start_time": "2022-11-23T03:32:20.215075Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=lstm,repeat=0/config.gin')]\n"
]
}
],
"source": [
"# list the models we have run...\n",
"configs=sorted(Path(\"storage/experiments/Stocks\").glob(\"**/config.gin\"))\n",
"import random\n",
"random.shuffle(configs)\n",
"print(configs)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "4178d85e",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.248919Z",
"start_time": "2022-11-23T03:32:20.233964Z"
}
},
"outputs": [],
"source": [
"from experiments.forecast import ForecastExperiment"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8eadaa48",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-23T04:00:02.202Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=mlp,repeat=0/config.gin\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:epochs: 1, iters: 100 | training loss: 2.41\n",
"INFO:root:Validation loss decreased (inf --> 1.004). Saving model ...\n",
"INFO:root:epochs: 2, iters: 100 | training loss: 0.20\n",
"INFO:root:Validation loss decreased (1.004 --> 0.115). Saving model ...\n"
]
}
],
"source": [
"for config in configs:\n",
" save_path = config.parent\n",
"\n",
" exp = ForecastExperiment(config_path=config)\n",
" print(config)\n",
" try:\n",
" exp.run()\n",
" except KeyboardInterrupt:\n",
" raise\n",
" except Exception as e:\n",
"# raise\n",
" print(e)\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7a3c151",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.312188Z",
"start_time": "2022-11-23T03:36:04.312180Z"
}
},
"outputs": [],
"source": [
"exp.instance()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b73c64e",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.313334Z",
"start_time": "2022-11-23T03:36:04.313322Z"
}
},
"outputs": [],
"source": [
"# save_path = Path('storage/experiments/Stocks/96M2S/repeat=0')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6af1bb1e",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.313962Z",
"start_time": "2022-11-23T03:36:04.313954Z"
}
},
"outputs": [],
"source": [
"# gin.clear_config()\n",
"# config_path = save_path/\"config.gin\"\n",
"# gin.parse_config(open(config_path))\n",
"# model_name = gin.query_parameter(\"instance.model_type\")\n",
"# model_name"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50a92a7e",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.314594Z",
"start_time": "2022-11-23T03:36:04.314588Z"
}
},
"outputs": [],
"source": [
"\n",
"# exp = ForecastExperiment(config_path=config_path)\n",
"# # exp.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ef76b76",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.315217Z",
"start_time": "2022-11-23T03:36:04.315210Z"
}
},
"outputs": [],
"source": [
"def save_path2name(save_path: Path) -> str:\n",
" \"\"\"\n",
" Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=mlp,repeat=0')\n",
" to \n",
" '96M2S-None_INR_mlp_0'\n",
" \"\"\"\n",
" mtitle = str(save_path).split('/')[-2:]\n",
" tags = mtitle[-1]\n",
" tags = [x.split('=')[-1] for x in tags.split(',')]\n",
" mtitle[-1] = '_'.join(tags)\n",
" mtitle = \"-\".join(mtitle)\n",
" return mtitle\n",
"\n",
"# save_path2name(save_path)\n",
"# save_path"
]
},
{
"cell_type": "markdown",
"id": "3637e87d",
"metadata": {},
"source": [
"# view all"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "768530be",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.315891Z",
"start_time": "2022-11-23T03:36:04.315884Z"
}
},
"outputs": [],
"source": [
"from torchsummaryX import summary\n",
"\n",
"def plot_multi(save_paths=[Path(\"storage/experiments/Exchange/96M/repeat=0\")], i=200, title=None, plot=True, verbose=1,):\n",
" assert len(save_paths)>0\n",
" for j in range(len(save_paths)):\n",
" save_path = save_paths[j]\n",
"\n",
" gin.clear_config()\n",
" gin.parse_config(open(save_path/\"config.gin\"))\n",
" model_name = gin.query_parameter(\"instance.model_type\")\n",
"\n",
" train_set, train_loader = get_data(flag='test', batch_size=3)\n",
" seq_len = train_set[0][1].shape\n",
" model = get_model(model_name,\n",
" dim_size=train_set.data_x.shape[1],\n",
" seq_len=seq_len,\n",
" datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
" model.load_state_dict(torch.load(save_path/'model.pth'))\n",
" model = model.eval()\n",
" \n",
" \n",
"\n",
"\n",
" b = train_set[i]\n",
" b = [bb[None, :] for bb in b]\n",
" b2 = list(map(to_tensor, b))\n",
" \n",
"# b = next(iter(train_loader))\n",
"# print([s.shape for s in b]\n",
" \n",
" if verbose>1:\n",
" \n",
"# print(model)\n",
" summary(model, *b2)\n",
" print(save_path)\n",
" \n",
" context_past_x, context_y, query_past_x, query_y, context_time, query_time = b2\n",
" with torch.no_grad():\n",
" forecast = model(*b2)\n",
" \n",
" colors = list(mcolors.BASE_COLORS.keys())\n",
" l = context_time.shape[1]\n",
" forecast2 = forecast[0].detach().cpu().numpy()\n",
" x2 = context_y[0].cpu()\n",
" y2 = query_y[0].cpu()\n",
" l2 = query_time.shape[1]\n",
" i_past = list(range(l))\n",
" i_future = list(range(l, l+l2))\n",
" \n",
" \n",
"\n",
" if plot:\n",
" \n",
" if j==0:\n",
" plt.plot(i_past, x2[:, 0], c=colors[0], label=f\"past\")\n",
" plt.plot(i_future, y2[:, 0], c=colors[0], label=\"future true\", alpha=0.3)\n",
" mtitle = save_path2name(save_path)\n",
" plt.plot(i_future, forecast2[:, 0], linestyle='--', label=f\"{mtitle}\") # c=colors[j], \n",
" \n",
"\n",
" plt.legend()\n",
" plt.title(title)\n",
" return x2, y2, forecast2, i_past, i_future\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "739ee5e3",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.316469Z",
"start_time": "2022-11-23T03:36:04.316463Z"
}
},
"outputs": [],
"source": [
"# list the models we have run...\n",
"m=sorted(Path(\"storage/experiments/Stocks/96M2S\").glob(\"**/_SUCCESS\"))\n",
"print(m)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48e9175b",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.316979Z",
"start_time": "2022-11-23T03:36:04.316973Z"
}
},
"outputs": [],
"source": [
"for mm in m:\n",
" mtitle = save_path2name(mm.parent)\n",
" print(mtitle)\n",
" m3 = np.load(mm.parent/'metrics.npy', allow_pickle=1)\n",
" m3 = eval(str(m3))\n",
" print(m3['val']['mape'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b966ef7",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.317585Z",
"start_time": "2022-11-23T03:36:04.317578Z"
}
},
"outputs": [],
"source": [
"train_set, train_loader = get_data(flag='train')\n",
"train_set[0][1].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b696e85",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.318277Z",
"start_time": "2022-11-23T03:36:04.318270Z"
}
},
"outputs": [],
"source": [
"save_paths = [mm.parent for mm in m]\n",
"for mm in m:\n",
" try:\n",
" plot_multi(\n",
" save_paths=[mm.parent],\n",
" i=600,\n",
" verbose=2,\n",
" )\n",
" except:\n",
" print('failed', mm)\n",
"# mm.unlink()\n",
" pass\n",
"1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac9e5759",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.319520Z",
"start_time": "2022-11-23T03:36:04.319512Z"
}
},
"outputs": [],
"source": [
"save_paths = [mm.parent for mm in m]\n",
"plot_multi(\n",
" save_paths=save_paths,\n",
" i=200,\n",
" verbose=0,\n",
")\n",
"1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2e27d42",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.320031Z",
"start_time": "2022-11-23T03:36:04.320024Z"
}
},
"outputs": [],
"source": [
"256/24"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "938f06db",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T13:17:31.585029Z",
"start_time": "2022-11-22T13:17:31.528806Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "9b4e817c",
"metadata": {},
"source": [
"# check positions in dl"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df6c9b28",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.320752Z",
"start_time": "2022-11-23T03:36:04.320745Z"
},
"scrolled": true
},
"outputs": [],
"source": [
"train_set, train_loader = get_data(flag='test', batch_size=3)\n",
"b = context_past_x, context_y, query_past_x, query_y, context_time, query_time = train_set[100]\n",
"print([bb.shape for bb in b])\n",
"# context_y, query_y"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa54042e",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.321671Z",
"start_time": "2022-11-23T03:36:04.321662Z"
}
},
"outputs": [],
"source": [
"cx_start, cx_end, c_start, c_end, qx_start, qx_end, q_start, q_end = train_set.get_inds(100)\n",
"cx_start, cx_end, c_start, c_end, qx_start, qx_end, q_start, q_end"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2482f0f7",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.322543Z",
"start_time": "2022-11-23T03:36:04.322535Z"
}
},
"outputs": [],
"source": [
"plt.hlines(1, cx_start, cx_end, color='green', alpha=0.5, label='context_past_x')\n",
"plt.hlines(2, c_start, c_end, color='green', label='context_labels')\n",
"plt.hlines(3, qx_start, qx_end, alpha=0.5, label='query_past_x')\n",
"plt.hlines(4, q_start, q_end, label='query_labels/target')\n",
"plt.legend(loc='upper left')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8fe355e",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a978978",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "045d3fd5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "deeptime",
"language": "python",
"name": "deeptime"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}