mirror of
https://github.com/wassname/DeepTime.git
synced 2026-06-27 21:50:22 +08:00
freeze reqs before I install vb2
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
+24
-2
@@ -100,14 +100,19 @@ class ForecastDataset(Dataset):
|
||||
data = df_data.values
|
||||
|
||||
# is will be our past data, including the y col
|
||||
self.data_x = data[border1:border2]
|
||||
self.data_x = data[border1:border2]
|
||||
|
||||
self.dates = df_raw['date'][border1:border2]
|
||||
# y is just the col we predict
|
||||
self.data_y = data[border1:border2][:, [-1]]
|
||||
self.timestamps = get_time_features(pd.to_datetime(df_raw.date[border1:border2].values),
|
||||
self.timestamps = get_time_features(pd.to_datetime(self.dates.values),
|
||||
normalise=self.normalise_time_features,
|
||||
features=self.time_features)
|
||||
self.n_time = len(self.data_x)
|
||||
self.n_time_samples = self.n_time - self.lookback_len * 2 - self.horizon_len + 1 + self.gap
|
||||
|
||||
o = self.horizon_len + self.lookback_len
|
||||
self.index = self.dates.iloc[o:].iloc[:self.n_time_samples]
|
||||
|
||||
def get_borders(self, df_raw: pd.DataFrame) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
set_type = {'train': 0, 'val': 1, 'test': 2}[self.flag]
|
||||
@@ -133,6 +138,9 @@ class ForecastDataset(Dataset):
|
||||
return self.n_time_samples
|
||||
|
||||
def get_inds(self, idx):
|
||||
|
||||
|
||||
|
||||
cx_start = idx
|
||||
cx_end = cx_start + self.lookback_len
|
||||
c_start = cx_end + self.gap
|
||||
@@ -144,6 +152,20 @@ class ForecastDataset(Dataset):
|
||||
qx_end = q_start
|
||||
qx_start = qx_end - self.lookback_len
|
||||
|
||||
####
|
||||
|
||||
# q_start = idx
|
||||
# q_end = q_start + self.horizon_len
|
||||
|
||||
# qx_end = q_start
|
||||
# qx_start = qx_end - self.lookback_len
|
||||
|
||||
# c_end = q_start - self.gap
|
||||
# c_start = c_end - self.horizon_len
|
||||
|
||||
# cx_end = c_start - self.gap
|
||||
# cx_start = cx_end - self.lookback_len
|
||||
|
||||
return cx_start, cx_end, c_start, c_end, qx_start, qx_end, q_start, q_end
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
|
||||
@@ -24,7 +24,7 @@ Checkpoint.patience = 7
|
||||
deeptime3.layer_size = 32
|
||||
deeptime3.inr_layers = 5
|
||||
deeptime3.dropout = 0.3
|
||||
deeptime3.base_learner = 'Ridge'
|
||||
deeptime3.lrn = 'Ridge'
|
||||
deeptime3.n_fourier_feats = 2048
|
||||
deeptime3.scales = [0.01, 0.1, 1, 5, 10, 20, 50, 100]
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ build.variables_dict = {
|
||||
# 'ForecastDataset.lookback_mult': [1, 3, 5, 7, 9],
|
||||
# 'ForecastDataset.horizon_len': [6, 12, 24, 48, 96, 192, 336, 720],
|
||||
# 'ForecastDataset.features': ['m', 'h', 'd'],
|
||||
'deeptime3.base_learner': ['Ridge', 'None', 'Transformer'],
|
||||
'deeptime3.lrn': ['Ridge', 'None', 'Transformer'],
|
||||
'deeptime3.inr': ['INR', 'INRPlus2'],
|
||||
'deeptime3.encoder': ['inception', 'lstm', 'mlp', 'lstm2', 'transformer', 'transformer2', 'none'],
|
||||
'deeptime3.enc': ['inception', 'lstm', 'mlp', 'lstm2', 'transformer', 'transformer2', 'none'],
|
||||
# 'deeptime3.dropout': [0.0, 0.1, 0.3, 0.5,],
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ Checkpoint.patience = 7
|
||||
deeptime3.layer_size = 256
|
||||
deeptime3.inr_layers = 5
|
||||
deeptime3.dropout = 0.1
|
||||
deeptime3.base_learner = 'Ridge'
|
||||
deeptime3.lrn = 'Ridge'
|
||||
deeptime3.n_fourier_feats = 4096
|
||||
deeptime3.scales = [0.01, 0.1, 1, 5, 10, 20, 50, 100]
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ class ForecastExperiment(Experiment):
|
||||
test_set, test_loader = get_data(flag='test')
|
||||
|
||||
dim_size=train_set.data_x.shape[1]
|
||||
seq_len = train_set[0][1].shape[0]
|
||||
seq_len = train_set[0][0].shape[0]
|
||||
model = get_model(model_type,
|
||||
dim_size=dim_size,
|
||||
seq_len=seq_len,
|
||||
@@ -51,7 +51,7 @@ class ForecastExperiment(Experiment):
|
||||
metrics = {'val': val_metrics, 'test': test_metrics}
|
||||
# np.save(join(self.root, 'metrics.npy'), {'val': val_metrics, 'test': test_metrics})
|
||||
metrics = serialize(metrics)
|
||||
json.dump(metrics, open(join(self.root, 'metrics.npy'), 'w'))
|
||||
json.dump(metrics, open(join(self.root, 'metrics.json'), 'w'))
|
||||
|
||||
val_metrics = {f'ValMetric/{k}': v for k, v in val_metrics.items()}
|
||||
test_metrics = {f'TestMetric/{k}': v for k, v in test_metrics.items()}
|
||||
|
||||
+2
-2
@@ -71,8 +71,8 @@ TODO:
|
||||
|
||||
- [x] M2S mode
|
||||
- [ ] add other INR's
|
||||
- [ ] add None as learner
|
||||
- [ ] no encoder?
|
||||
- [ ] add None as lrn
|
||||
- [ ] no enc?
|
||||
|
||||
```
|
||||
python -m experiments.forecast --config_path=experiments/configs/hp_search/Stocks.gin build_experiment
|
||||
|
||||
+22
-22
@@ -20,38 +20,38 @@ from models.modules.encoders import LSTMEncoder, TransformerEncoder2, Transforme
|
||||
# from models.modules.regressors import RidgeRegressor
|
||||
|
||||
@gin.configurable()
|
||||
def deeptime3(dim_size:int, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float, dropout: float, base_learner: str, encoder:str, inr: str, seq_len: int):
|
||||
return DeepTIMe3(dim_size, datetime_feats, layer_size, inr_layers, n_fourier_feats, scales, dropout, base_learner, encoder, inr, seq_len)
|
||||
def deeptime3(dim_size:int, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float, dropout: float, lrn: str, enc:str, inr: str, seq_len: int):
|
||||
return DeepTIMe3(dim_size, datetime_feats, layer_size, inr_layers, n_fourier_feats, scales, dropout, lrn, enc, inr, seq_len)
|
||||
|
||||
|
||||
class DeepTIMe3(nn.Module):
|
||||
def __init__(self, dim_size: int, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float, dropout: float=0.3, base_learner:str='Ridge', encoder:str='inception', inr:str='INR', seq_len: int=46):
|
||||
def __init__(self, dim_size: int, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float, dropout: float=0.3, lrn:str='Ridge', enc:str='inception', inr:str='INR', seq_len: int=46):
|
||||
super().__init__()
|
||||
|
||||
# encode the past
|
||||
encoded_size = layer_size
|
||||
encoder_features = 24
|
||||
encoder_layers = 3
|
||||
if encoder == 'inception':
|
||||
self.encoder = InceptionEncoder(
|
||||
if enc == 'inception':
|
||||
self.enc = InceptionEncoder(
|
||||
c_in=dim_size, c_out=encoded_size, dilation=6,
|
||||
layer_size=17, layers=encoder_layers, dropout=dropout,
|
||||
)
|
||||
elif encoder == 'lstm':
|
||||
self.encoder = LSTMEncoder(c_in=dim_size, c_out=encoded_size, dropout=dropout, layers=encoder_layers, layer_size=24)
|
||||
elif encoder == 'lstm2':
|
||||
self.encoder = LSTMEncoder2(c_in=dim_size, c_out=encoded_size, dropout=dropout, layers=encoder_layers, layer_size=32, seq_len=seq_len)
|
||||
elif encoder == 'mlp':
|
||||
self.encoder = MLPEncoder(c_in=dim_size, c_out=encoded_size, dropout=dropout, layers=encoder_layers, layer_size=256)
|
||||
elif encoder == 'transformer':
|
||||
self.encoder = TransformerEncoder(c_in=dim_size, c_out=encoded_size, dropout=dropout, layers=encoder_layers, layer_size=256, seq_len=seq_len)
|
||||
elif encoder == 'transformer2':
|
||||
self.encoder = TransformerEncoder2(c_in=dim_size, c_out=encoded_size, dropout=dropout, layers=encoder_layers, layer_size=256, seq_len=seq_len)
|
||||
elif encoder == 'none':
|
||||
self.encoder = None
|
||||
elif enc == 'lstm':
|
||||
self.enc = LSTMEncoder(c_in=dim_size, c_out=encoded_size, dropout=dropout, layers=encoder_layers, layer_size=24)
|
||||
elif enc == 'lstm2':
|
||||
self.enc = LSTMEncoder2(c_in=dim_size, c_out=encoded_size, dropout=dropout, layers=encoder_layers, layer_size=32, seq_len=seq_len)
|
||||
elif enc == 'mlp':
|
||||
self.enc = MLPEncoder(c_in=dim_size, c_out=encoded_size, dropout=dropout, layers=encoder_layers, layer_size=256)
|
||||
elif enc == 'transformer':
|
||||
self.enc = TransformerEncoder(c_in=dim_size, c_out=encoded_size, dropout=dropout, layers=encoder_layers, layer_size=256, seq_len=seq_len)
|
||||
elif enc == 'transformer2':
|
||||
self.enc = TransformerEncoder2(c_in=dim_size, c_out=encoded_size, dropout=dropout, layers=encoder_layers, layer_size=128, seq_len=seq_len)
|
||||
elif enc == 'none':
|
||||
self.enc = None
|
||||
encoded_size = 0
|
||||
else:
|
||||
raise NotADirectoryError(encoder)
|
||||
raise NotADirectoryError(enc)
|
||||
|
||||
# translate coords to a representation, given a summary of the past
|
||||
coord_size = 1
|
||||
@@ -66,7 +66,7 @@ class DeepTIMe3(nn.Module):
|
||||
raise NotImplementedError(inr)
|
||||
|
||||
# meta learn y given a representation
|
||||
self.regressionhead = RegressionHead(base_learner=base_learner, d=layer_size, dropout=dropout)
|
||||
self.regressionhead = RegressionHead(lrn=lrn, d=layer_size, dropout=dropout)
|
||||
|
||||
self.datetime_feats = datetime_feats
|
||||
self.inr_layers = inr_layers
|
||||
@@ -83,8 +83,8 @@ class DeepTIMe3(nn.Module):
|
||||
|
||||
# we summarize the past into a single hidden layer. Then repeat it for each coordinate
|
||||
past_len = time.shape[1]
|
||||
if self.encoder is not None:
|
||||
encoded_x = self.encoder(past_x)
|
||||
if self.enc is not None:
|
||||
encoded_x = self.enc(past_x)
|
||||
encoded_x = repeat(encoded_x, "b f -> b t f", t=past_len)
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ class DeepTIMe3(nn.Module):
|
||||
coords = repeat(coords, "1 t 1 -> b t 1", b=time.shape[0])
|
||||
|
||||
# combine and run INR to decode the representation
|
||||
if self.encoder is not None:
|
||||
if self.enc is not None:
|
||||
context_input = torch.cat([encoded_x, coords, time], dim=-1)
|
||||
else:
|
||||
context_input = torch.cat([coords, time], dim=-1)
|
||||
|
||||
@@ -126,7 +126,8 @@ class CausalInceptionTimePlus(nn.Sequential):
|
||||
dilations = np.array([max(1, d*dilation) for d in range(depth)])
|
||||
d=np.array([dilations**i for i in range(3)]).T
|
||||
rf = ((ks-1)*d).sum(0)
|
||||
print(f"receptive field {rf}={ks-1}*{d}")
|
||||
# print(f"receptive field {rf}={ks-1}*{d}")
|
||||
print(f"receptive field {rf}")
|
||||
|
||||
def create_head(self, nf, c_out, seq_len, flatten=False, concat_pool=False, fc_dropout=0., bn=False, y_range=None):
|
||||
if flatten:
|
||||
|
||||
@@ -133,14 +133,14 @@ class TransformerEncoder2(nn.Module):
|
||||
super().__init__()
|
||||
# d_model (82) must be divisible by n_heads (4)
|
||||
layer_size = layer_size // n_heads * n_heads
|
||||
d_model = layer_size // 2
|
||||
d_model = layer_size // 4
|
||||
self.net = TSPerceiver(
|
||||
c_in=c_in,
|
||||
c_out=c_out,
|
||||
seq_len=seq_len,
|
||||
|
||||
# cat_szs=0, n_cont=0,
|
||||
n_latents=layer_size, d_latent=layer_size//4,
|
||||
n_latents=layer_size, d_latent=d_model,
|
||||
# d_context=None,
|
||||
self_per_cross_attn=1,
|
||||
# share_weights=True, cross_n_heads=1, d_head=None,
|
||||
@@ -229,7 +229,7 @@ class LSTMEncoder2(nn.Module):
|
||||
depth=layers,
|
||||
lstm_dropout=conv_dropout,
|
||||
fc_dropout=dropout,
|
||||
pre_norm=False, use_token=True, use_pe=True,
|
||||
pre_norm=False, use_token=False, use_pe=False,
|
||||
use_bn=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -63,17 +63,17 @@ class TransformerHead(nn.Module):
|
||||
|
||||
|
||||
class RegressionHead(nn.Module):
|
||||
def __init__(self, base_learner='Ridge', d=512, enable_scale=True, dropout=0.1, num_heads=16):
|
||||
def __init__(self, lrn='Ridge', d=512, enable_scale=True, dropout=0.1, num_heads=16):
|
||||
super().__init__()
|
||||
if ('Ridge' in base_learner):
|
||||
if ('Ridge' in lrn):
|
||||
# the regular DeepTime one
|
||||
self.head = RidgeRegressor()
|
||||
elif ("None" in base_learner):
|
||||
elif ("None" in lrn):
|
||||
self.head = SumHead(d=d, dropout=dropout)
|
||||
elif ("Transformer" in base_learner):
|
||||
elif ("Transformer" in lrn):
|
||||
self.head = TransformerHead(d=d, dropout=dropout, num_heads=num_heads)
|
||||
else:
|
||||
raise NotImplementedError(base_learner)
|
||||
raise NotImplementedError(lrn)
|
||||
|
||||
# Add a learnable scale
|
||||
self.enable_scale = enable_scale
|
||||
|
||||
Regular → Executable
@@ -149,9 +149,11 @@ dependencies:
|
||||
- grpcio==1.50.0
|
||||
- imbalanced-learn==0.9.1
|
||||
- joblib==1.2.0
|
||||
- json-tricks==3.16.1
|
||||
- kiwisolver==1.4.4
|
||||
- langcodes==3.3.0
|
||||
- llvmlite==0.39.1
|
||||
- loguru==0.6.0
|
||||
- markdown==3.4.1
|
||||
- matplotlib==3.6.2
|
||||
- murmurhash==1.0.9
|
||||
@@ -185,6 +187,8 @@ dependencies:
|
||||
- threadpoolctl==3.1.0
|
||||
- torch==1.10.0+cu113
|
||||
- torchaudio==0.10.0+cu113
|
||||
- torchinfo==1.7.1
|
||||
- torchsummaryx==1.3.0
|
||||
- torchvision==0.11.1+cu113
|
||||
- tqdm==4.64.1
|
||||
- tsai==0.3.4
|
||||
|
||||
@@ -55,6 +55,7 @@ jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1659959867326/work
|
||||
jeepney @ file:///home/conda/feedstock_root/build_artifacts/jeepney_1649085214306/work
|
||||
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
|
||||
joblib==1.2.0
|
||||
json-tricks==3.16.1
|
||||
jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1667361745641/work
|
||||
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1668623095912/work
|
||||
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1668030817979/work
|
||||
@@ -63,6 +64,7 @@ keyring @ file:///home/conda/feedstock_root/build_artifacts/keyring_166769673355
|
||||
kiwisolver==1.4.4
|
||||
langcodes==3.3.0
|
||||
llvmlite==0.39.1
|
||||
loguru==0.6.0
|
||||
Markdown==3.4.1
|
||||
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1666770195345/work
|
||||
matplotlib==3.6.2
|
||||
@@ -134,6 +136,8 @@ tomlkit @ file:///home/conda/feedstock_root/build_artifacts/tomlkit_166686418860
|
||||
toolz @ file:///home/conda/feedstock_root/build_artifacts/toolz_1657485559105/work
|
||||
torch==1.10.0+cu113
|
||||
torchaudio==0.10.0+cu113
|
||||
torchinfo==1.7.1
|
||||
torchsummaryX==1.3.0
|
||||
torchvision==0.11.1+cu113
|
||||
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1666788592778/work
|
||||
tqdm==4.64.1
|
||||
|
||||
@@ -1,778 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b1e031e3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- [x] try just one predictor\n",
|
||||
" - [ ] multi input, single output\n",
|
||||
"- [x] comparem ulti\n",
|
||||
"- losses:\n",
|
||||
" - try logp? nah\n",
|
||||
" - mae?\n",
|
||||
"- [x] make my own csv with 5m data (maybe 10k rows)\n",
|
||||
"- [ ] backtest?"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "7f9e3d73",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T05:42:46.869262Z",
|
||||
"start_time": "2022-11-23T05:42:46.859832Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import warnings\n",
|
||||
"warnings.simplefilter(\"ignore\")\n",
|
||||
"\n",
|
||||
"# autoreload import your package\n",
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "4e09086b",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T05:42:48.073268Z",
|
||||
"start_time": "2022-11-23T05:42:46.870484Z"
|
||||
},
|
||||
"lines_to_next_cell": 0
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from os.path import join\n",
|
||||
"import math\n",
|
||||
"import logging\n",
|
||||
"from typing import Callable, Optional, Union, Dict, Tuple\n",
|
||||
"\n",
|
||||
"from matplotlib import pyplot as plt\n",
|
||||
"from pathlib import Path\n",
|
||||
"import matplotlib.colors as mcolors\n",
|
||||
"\n",
|
||||
"import gin\n",
|
||||
"from fire import Fire\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"from torch import optim\n",
|
||||
"from torch import nn\n",
|
||||
"\n",
|
||||
"from experiments.base import Experiment\n",
|
||||
"from data.datasets import ForecastDataset\n",
|
||||
"from models import get_model\n",
|
||||
"from utils.checkpoint import Checkpoint\n",
|
||||
"from utils.ops import default_device, to_tensor\n",
|
||||
"from utils.losses import get_loss_fn\n",
|
||||
"from utils.metrics import calc_metrics\n",
|
||||
"\n",
|
||||
"from experiments.forecast import get_data\n",
|
||||
"gin.enter_interactive_mode()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "66d7f095",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T05:42:48.097114Z",
|
||||
"start_time": "2022-11-23T05:42:48.074670Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"1"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.root.setLevel(logging.INFO)\n",
|
||||
"\n",
|
||||
"from loguru import logger\n",
|
||||
"logger.remove()\n",
|
||||
"logger.add(os.sys.stdout, level=\"INFO\", colorize=True, format=\"<level>{time} | {message}</level>\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d4df5270",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# auto"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "04499bef",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T05:42:48.115157Z",
|
||||
"start_time": "2022-11-23T05:42:48.098029Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"def plot(model_name=\"deeptime\", save_path=Path(\"storage/experiments/Exchange/96M/repeat=0\"), i=200, title=None, plot=True):\n",
|
||||
"\n",
|
||||
" gin.clear_config()\n",
|
||||
" gin.parse_config(open(save_path/\"config.gin\"))\n",
|
||||
"\n",
|
||||
" train_set, train_loader = get_data(flag='train', batch_size=2)\n",
|
||||
"\n",
|
||||
" model = get_model(model_name,\n",
|
||||
" dim_size=train_set.data_x.shape[1],\n",
|
||||
" datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
|
||||
" model.load_state_dict(torch.load(save_path/'model.pth'))\n",
|
||||
" model = model.eval()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" b = train_set[i]\n",
|
||||
" b = [bb[None, :] for bb in b]\n",
|
||||
" b2 = list(map(to_tensor, b))\n",
|
||||
" context_past_x, context_y, query_past_x, query_y, context_time, query_time = b2\n",
|
||||
" with torch.no_grad():\n",
|
||||
" forecast = model(*b2)\n",
|
||||
"\n",
|
||||
" if title is None:\n",
|
||||
" title = str(save_path).split('/')[-3:]\n",
|
||||
" title = \"-\".join(title)\n",
|
||||
" \n",
|
||||
" colors = list(mcolors.BASE_COLORS.keys())\n",
|
||||
" l = x.shape[1]\n",
|
||||
" forecast2 = forecast[0].detach().cpu().numpy()\n",
|
||||
" x2 = x[0].cpu()\n",
|
||||
" y2 = y[0].cpu()\n",
|
||||
" l2 = y.shape[1]\n",
|
||||
" i_past = list(range(l))\n",
|
||||
" i_future = list(range(l, l+l2))\n",
|
||||
" \n",
|
||||
" if plot:\n",
|
||||
" plt.title(title)\n",
|
||||
" for i in range(x.shape[-1]):\n",
|
||||
" plt.plot(i_past, x2[:, i], c=colors[i])\n",
|
||||
" for i in range(x.shape[-1]):\n",
|
||||
" plt.plot(i_future, y2[:, i], c=colors[i])\n",
|
||||
" for i in range(x.shape[-1]):\n",
|
||||
" plt.plot(i_future, forecast2[:, i], c=colors[i], linestyle='--')\n",
|
||||
" return x2, y2, forecast2, i_past, i_future\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6615bb09",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-22T13:28:35.849491Z",
|
||||
"start_time": "2022-11-22T13:28:35.766453Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "ec47ae46",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T05:42:48.138242Z",
|
||||
"start_time": "2022-11-23T05:42:48.115973Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"list(mcolors.BASE_COLORS.keys())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0a298ce2",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-22T13:38:15.786656Z",
|
||||
"start_time": "2022-11-22T13:38:15.303577Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# view model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "40002390",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# run exps"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2e1e58e7",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:31:08.583641Z",
|
||||
"start_time": "2022-11-23T03:31:08.558690Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "91b92e80",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T05:42:48.154921Z",
|
||||
"start_time": "2022-11-23T05:42:48.139190Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# list the models we have run...\n",
|
||||
"configs=sorted(Path(\"storage/experiments/Stocks\").glob(\"**/config.gin\"))\n",
|
||||
"import random\n",
|
||||
"random.shuffle(configs)\n",
|
||||
"# print(configs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "8bb001f7",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T05:42:48.171079Z",
|
||||
"start_time": "2022-11-23T05:42:48.155764Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from experiments.forecast import ForecastExperiment\n",
|
||||
"from tqdm.auto import tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "029641dd",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.859Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "49eccaa5eb6942e29ed4943d5b92cf89",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/42 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=none,repeat=0/config.gin\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:root:epochs: 1, iters: 100 | training loss: 1.84\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for config in tqdm(configs):\n",
|
||||
" save_path = config.parent\n",
|
||||
"\n",
|
||||
" exp = ForecastExperiment(config_path=config)\n",
|
||||
" print(config)\n",
|
||||
" try:\n",
|
||||
" exp.run()\n",
|
||||
" except KeyboardInterrupt:\n",
|
||||
" raise\n",
|
||||
" except Exception as e:\n",
|
||||
" raise\n",
|
||||
" print(e)\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0472a84b",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%debug"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "03a28a90",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exp.instance()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a31ae85a",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# save_path = Path('storage/experiments/Stocks/96M2S/repeat=0')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1a9d823d",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# gin.clear_config()\n",
|
||||
"# config_path = save_path/\"config.gin\"\n",
|
||||
"# gin.parse_config(open(config_path))\n",
|
||||
"# model_name = gin.query_parameter(\"instance.model_type\")\n",
|
||||
"# model_name"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5e4983e7",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# exp = ForecastExperiment(config_path=config_path)\n",
|
||||
"# # exp.run()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cdd0fbbf",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def save_path2name(save_path: Path) -> str:\n",
|
||||
" \"\"\"\n",
|
||||
" Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=mlp,repeat=0')\n",
|
||||
" to \n",
|
||||
" '96M2S-None_INR_mlp_0'\n",
|
||||
" \"\"\"\n",
|
||||
" mtitle = str(save_path).split('/')[-2:]\n",
|
||||
" tags = mtitle[-1]\n",
|
||||
" tags = [x.split('=')[-1] for x in tags.split(',')]\n",
|
||||
" mtitle[-1] = '_'.join(tags)\n",
|
||||
" mtitle = \"-\".join(mtitle)\n",
|
||||
" return mtitle\n",
|
||||
"\n",
|
||||
"# save_path2name(save_path)\n",
|
||||
"# save_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "321106ef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# view all"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "768530be",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torchsummaryX import summary\n",
|
||||
"\n",
|
||||
"def plot_multi(save_paths=[Path(\"storage/experiments/Exchange/96M/repeat=0\")], i=200, title=None, plot=True, verbose=1,):\n",
|
||||
" assert len(save_paths)>0\n",
|
||||
" for j in range(len(save_paths)):\n",
|
||||
" save_path = save_paths[j]\n",
|
||||
"\n",
|
||||
" gin.clear_config()\n",
|
||||
" gin.parse_config(open(save_path/\"config.gin\"))\n",
|
||||
" model_name = gin.query_parameter(\"instance.model_type\")\n",
|
||||
"\n",
|
||||
" train_set, train_loader = get_data(flag='test', batch_size=3)\n",
|
||||
" seq_len = train_set[0][1].shape[0]\n",
|
||||
" model = get_model(model_name,\n",
|
||||
" dim_size=train_set.data_x.shape[1],\n",
|
||||
" seq_len=seq_len,\n",
|
||||
" datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
|
||||
" model.load_state_dict(torch.load(save_path/'model.pth'))\n",
|
||||
" model = model.eval()\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
" b = train_set[i]\n",
|
||||
" b = [bb[None, :] for bb in b]\n",
|
||||
" b2 = list(map(to_tensor, b))\n",
|
||||
" \n",
|
||||
"# b = next(iter(train_loader))\n",
|
||||
"# print([s.shape for s in b]\n",
|
||||
" \n",
|
||||
" if verbose>1:\n",
|
||||
" \n",
|
||||
"# print(model)\n",
|
||||
" summary(model, *b2)\n",
|
||||
" print(save_path)\n",
|
||||
" \n",
|
||||
" context_past_x, context_y, query_past_x, query_y, context_time, query_time = b2\n",
|
||||
" with torch.no_grad():\n",
|
||||
" forecast = model(*b2)\n",
|
||||
" \n",
|
||||
" colors = list(mcolors.BASE_COLORS.keys())\n",
|
||||
" l = context_time.shape[1]\n",
|
||||
" forecast2 = forecast[0].detach().cpu().numpy()\n",
|
||||
" x2 = context_y[0].cpu()\n",
|
||||
" y2 = query_y[0].cpu()\n",
|
||||
" l2 = query_time.shape[1]\n",
|
||||
" i_past = list(range(l))\n",
|
||||
" i_future = list(range(l, l+l2))\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"\n",
|
||||
" if plot:\n",
|
||||
" \n",
|
||||
" if j==0:\n",
|
||||
" plt.plot(i_past, x2[:, 0], c=colors[0], label=f\"past\")\n",
|
||||
" plt.plot(i_future, y2[:, 0], c=colors[0], label=\"future true\", alpha=0.3)\n",
|
||||
" mtitle = save_path2name(save_path)\n",
|
||||
" plt.plot(i_future, forecast2[:, 0], linestyle='--', label=f\"{mtitle}\") # c=colors[j], \n",
|
||||
" \n",
|
||||
"\n",
|
||||
" plt.legend()\n",
|
||||
" plt.title(title)\n",
|
||||
" return x2, y2, forecast2, i_past, i_future\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "739ee5e3",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# list the models we have run...\n",
|
||||
"m=sorted(Path(\"storage/experiments/Stocks/96M2S\").glob(\"**/_SUCCESS\"))\n",
|
||||
"print(m)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f1ae652d",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for mm in m:\n",
|
||||
" mtitle = save_path2name(mm.parent)\n",
|
||||
" print(mtitle)\n",
|
||||
" m3 = np.load(mm.parent/'metrics.npy', allow_pickle=1)\n",
|
||||
" m3 = eval(str(m3))\n",
|
||||
" print(m3['val']['mape'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "79bda4bf",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_set, train_loader = get_data(flag='train')\n",
|
||||
"train_set[0][1].shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b06e53a9",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"save_paths = [mm.parent for mm in m]\n",
|
||||
"for mm in m:\n",
|
||||
" try:\n",
|
||||
" plot_multi(\n",
|
||||
" save_paths=[mm.parent],\n",
|
||||
" i=600,\n",
|
||||
" verbose=2,\n",
|
||||
" )\n",
|
||||
" except:\n",
|
||||
" print('failed', mm)\n",
|
||||
"# mm.unlink()\n",
|
||||
" pass\n",
|
||||
"1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "25b9b9c5",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"save_paths = [mm.parent for mm in m]\n",
|
||||
"plot_multi(\n",
|
||||
" save_paths=save_paths,\n",
|
||||
" i=200,\n",
|
||||
" verbose=0,\n",
|
||||
")\n",
|
||||
"1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e173ea6f",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"256/24"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9ca6f142",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-22T13:17:31.585029Z",
|
||||
"start_time": "2022-11-22T13:17:31.528806Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fce6db58",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# check positions in dl"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5d574727",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_set, train_loader = get_data(flag='test', batch_size=3)\n",
|
||||
"b = context_past_x, context_y, query_past_x, query_y, context_time, query_time = train_set[100]\n",
|
||||
"print([bb.shape for bb in b])\n",
|
||||
"# context_y, query_y"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "74ebc120",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cx_start, cx_end, c_start, c_end, qx_start, qx_end, q_start, q_end = train_set.get_inds(100)\n",
|
||||
"cx_start, cx_end, c_start, c_end, qx_start, qx_end, q_start, q_end"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d289f6bd",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.hlines(1, cx_start, cx_end, color='green', alpha=0.5, label='context_past_x')\n",
|
||||
"plt.hlines(2, c_start, c_end, color='green', label='context_labels')\n",
|
||||
"plt.hlines(3, qx_start, qx_end, alpha=0.5, label='query_past_x')\n",
|
||||
"plt.hlines(4, q_start, q_end, label='query_labels/target')\n",
|
||||
"plt.legend(loc='upper left')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "879d9d45",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "37d7641b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "061b3b77",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"jupytext": {
|
||||
"cell_metadata_filter": "-all",
|
||||
"main_language": "python",
|
||||
"notebook_metadata_filter": "-all"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "deeptime",
|
||||
"language": "python",
|
||||
"name": "deeptime"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.13"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
"nav_menu": {},
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"skip_h1_title": false,
|
||||
"title_cell": "Table of Contents",
|
||||
"title_sidebar": "Contents",
|
||||
"toc_cell": false,
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": false
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
-869
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user