partially converted to M2S (multi inputs)

This commit is contained in:
wassname
2022-11-22 16:42:58 +08:00
parent 00ae1d8b8f
commit 9072613d03
7 changed files with 647 additions and 38 deletions
+31 -17
View File
@@ -43,14 +43,15 @@ class ForecastDataset(Dataset):
"""
assert flag in ('train', 'val', 'test'), \
f"flag should be one of (train, val, test)"
assert features in ('M', 'S'), \
f"features should be one of (M: multivar, S: univar)"
assert features in ('M', 'S', 'M2S'), \
f"features should be one of (M: multivar, S: univar, M2S: multi inputs 2 single output)"
assert (lookback_len is not None) ^ (lookback_mult is not None), \
f"only 'lookback_len' xor 'lookback_mult' should be specified"
self.flag = flag
self.lookback_len = int(horizon_len * lookback_mult) if lookback_mult is not None else lookback_len
self.lookback_aux_len = lookback_aux_len
self.gap = 0
self.horizon_len = horizon_len
self.scale = scale
self.cross_learn = cross_learn
@@ -84,8 +85,11 @@ class ForecastDataset(Dataset):
elif self.features == 'S':
df_data = df_raw[[self.target]]
self.n_dims = 1
elif self.features == 'M2S':
df_data = df_raw[cols + [self.target]]
self.n_dims = 1 # len(cols + [self.target])
else:
raise ValueError
raise NotImplementedError(self.features)
self.scaler = StandardScaler()
if self.scale:
@@ -95,13 +99,15 @@ class ForecastDataset(Dataset):
else:
data = df_data.values
self.data_x = data[border1:border2]
self.data_y = data[border1:border2]
# is will be our past data, including the y col
self.data_x = data[border1:border2]
# y is just the col we predict
self.data_y = data[border1:border2][:, [-1]]
self.timestamps = get_time_features(pd.to_datetime(df_raw.date[border1:border2].values),
normalise=self.normalise_time_features,
features=self.time_features)
self.n_time = len(self.data_x)
self.n_time_samples = self.n_time - self.lookback_len - self.horizon_len + 1
self.n_time_samples = self.n_time - self.lookback_len * 2 - self.horizon_len + 1 + self.gap
def get_borders(self, df_raw: pd.DataFrame) -> Tuple[List[int], List[int], List[int], List[int]]:
set_type = {'train': 0, 'val': 1, 'test': 2}[self.flag]
@@ -133,18 +139,26 @@ class ForecastDataset(Dataset):
idx = idx % self.n_time_samples
else:
dim_slice = slice(None)
cx_start = idx
cx_end = cx_start + self.lookback_len
c_start = cx_end + self.gap
c_end = c_start + self.horizon_len
qx_start = cx_end + self.gap
qx_end = qx_start + self.lookback_len
q_start = qx_end + self.gap
q_end = q_start + self.horizon_len
x_start = idx
x_end = x_start + self.lookback_len
y_start = x_end - self.lookback_aux_len
y_end = y_start + self.lookback_aux_len + self.horizon_len
x = self.data_x[x_start:x_end, dim_slice]
y = self.data_y[y_start:y_end, dim_slice]
x_time = self.timestamps[x_start:x_end]
y_time = self.timestamps[y_start:y_end]
return x, y, x_time, y_time
context_past_x = self.data_x[cx_start:cx_end, dim_slice]
context_y = self.data_y[c_start:c_end, dim_slice]
context_time = self.timestamps[c_start:c_end]
query_past_x = self.data_x[qx_start:qx_end, dim_slice]
query_y = self.data_y[q_start:q_end, dim_slice]
query_time = self.timestamps[q_start:q_end]
return context_past_x, context_y, query_past_x, query_y, context_time, query_time
def inverse_transform(self, data):
return self.scaler.inverse_transform(data)
+37
View File
@@ -0,0 +1,37 @@
build.experiment_name = 'Stocks/96M2S'
build.module = 'experiments.forecast'
build.repeat = 1
build.variables_dict = {
}
instance.model_type = 'deeptime3'
instance.save_vals = False
get_optimizer.lr = 1e-3
get_optimizer.lambda_lr = 1.
get_optimizer.weight_decay = 0.
get_scheduler.warmup_epochs = 5
get_data.batch_size = 256
train.loss_name = 'mse'
train.epochs = 50
train.clip = 10.
Checkpoint.patience = 7
deeptime3.layer_size = 32
deeptime3.inr_layers = 5
deeptime3.n_fourier_feats = 4096
deeptime3.scales = [0.01, 0.1, 1, 5, 10, 20, 50, 100]
ForecastDataset.data_path = 'stocks/OXY_2019.csv.gz'
ForecastDataset.target = 'RSMKs_18_144_72'
ForecastDataset.scale = True
ForecastDataset.cross_learn = False
ForecastDataset.time_features = []
ForecastDataset.normalise_time_features = True
ForecastDataset.features = 'M2S'
ForecastDataset.horizon_len = 96
ForecastDataset.lookback_mult = 1
+13 -12
View File
@@ -139,16 +139,16 @@ def train(model: nn.Module,
model.train()
for it, data in enumerate(train_loader):
optimizer.zero_grad()
x, y, x_time, y_time = map(to_tensor, data)
forecast = model(x, x_time, y_time)
data2 = list(map(to_tensor, data))
context_past_x, context_y, query_past_x, query_y, context_time, query_time = data2
forecast = model(*data2)
if isinstance(forecast, tuple):
# for models which require reconstruction + forecast loss
loss = training_loss_fn(forecast[0], x) + \
training_loss_fn(forecast[1], y)
loss = training_loss_fn(forecast[0], context_y) + \
training_loss_fn(forecast[1], query_y)
else:
loss = training_loss_fn(forecast, y)
loss = training_loss_fn(forecast, query_y)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()
@@ -188,22 +188,23 @@ def validate(model: nn.Module,
inps = []
total_loss = []
for it, data in enumerate(loader):
x, y, x_time, y_time = map(to_tensor, data)
data2 = list(map(to_tensor, data))
context_past_x, context_y, query_past_x, query_y, context_time, query_time = data2
if x.shape[0] == 1:
if context_past_x.shape[0] == 1:
# skip final batch if batch_size == 1
# due to bug in torch.linalg.solve which raises error when batch_size == 1
continue
forecast = model(x, x_time, y_time)
forecast = model(*data2)
if report_metrics:
preds.append(forecast)
trues.append(y)
trues.append(query_y)
if save_path is not None:
inps.append(x)
inps.append(context_y)
else:
loss = loss_fn(forecast, y, reduction='none')
loss = loss_fn(forecast, query_y, reduction='none')
total_loss.append(loss)
if report_metrics:
+88
View File
@@ -0,0 +1,88 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
from typing import Optional
import gin
import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange, repeat, reduce
from models.modules.causalinception import CausalInceptionTimePlus, CausalConv1d
from models.modules.inrplus2 import INRPlus2
from models.modules.regressors import RidgeRegressor
@gin.configurable()
def deeptime3(dim_size:int, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float):
return DeepTIMe3(dim_size, datetime_feats, layer_size, inr_layers, n_fourier_feats, scales)
class DeepTIMe3(nn.Module):
def __init__(self, dim_size: int, datetime_feats: int, layer_size: int, inr_layers: int, n_fourier_feats: int, scales: float, dropout: float=0.3):
super().__init__()
# encode the past
encoded_size = layer_size//2
self.encoder = CausalInceptionTimePlus(
c_in=dim_size, c_out=encoded_size,
# nf=32, depth=6,
nf=17, depth=3,
bn=True,
ks=[39, 19, 3],
coord=True, fc_dropout=dropout,
)
# translate coords to a representation, given a summary of the past
coord_size = 1
in_feats=datetime_feats+encoded_size+coord_size
self.inr = INRPlus2(in_feats=in_feats, layers=inr_layers, layer_size=layer_size,
n_fourier_feats=n_fourier_feats, scales=scales, dropout=dropout)
# meta learn y given a representation
self.adaptive_weights = RidgeRegressor()
self.datetime_feats = datetime_feats
self.inr_layers = inr_layers
self.layer_size = layer_size
self.n_fourier_feats = n_fourier_feats
self.scales = scales
def encode_and_decode(self, past_x, time, offset=0):
"""
h_past = encode(past) # get representation of past
representation = decode(h_past, coords)
i = length of past, so we can offset the coords
"""
encoded_x = self.encoder(past_x.transpose(2, 1))
# relative coordinates are the same for each batch, so we make them once and repeat them
past_len = time.shape[1]
encoded_x = repeat(encoded_x, "b f -> b t f", t=past_len)
coords = self.get_coords(past_len).to(time.device) + offset
coords = repeat(coords, "1 t 1 -> b t 1", b=time.shape[0])
context_input = torch.cat([encoded_x, coords, time], dim=-1)
context_repr = self.inr(context_input)
return context_repr
def forward(self, context_past_x, context_y, query_past_x, query_y, context_time, query_time) -> Tensor:
context_reprs = self.encode_and_decode(context_past_x, context_time)
query_reprs = self.encode_and_decode(query_past_x, query_time, offset=context_reprs.shape[1])
w, b = self.adaptive_weights(context_reprs, context_y)
preds = self.forecast(query_reprs, w, b)
return preds
def forecast(self, inp: Tensor, w: Tensor, b: Tensor) -> Tensor:
return torch.einsum('... d o, ... t d -> ... t o', [w, inp]) + b
def get_coords(self, lookback_len: int) -> Tensor:
coords = torch.linspace(0, 1, lookback_len)
return rearrange(coords, 't -> 1 t 1')
+5 -2
View File
@@ -4,13 +4,16 @@ import torch
from .DeepTIMe import deeptime
from .DeepTIMe2 import deeptime2
from .DeepTIMe3 import deeptime3
def get_model(model_type: str, **kwargs: Union[int, float]) -> torch.nn.Module:
if model_type == 'deeptime':
model = deeptime(datetime_feats=kwargs['datetime_feats'])
model = deeptime(datetime_feats=kwargs['datetime_feats'], dim_size=kwargs['dim_size'])
elif model_type=="deeptime2":
model = deeptime2(datetime_feats=kwargs['datetime_feats'])
model = deeptime2(datetime_feats=kwargs['datetime_feats'], dim_size=kwargs['dim_size'])
elif model_type=="deeptime3":
model = deeptime3(datetime_feats=kwargs['datetime_feats'], dim_size=kwargs['dim_size'])
else:
raise ValueError(f"Unknown model type {model_type}")
return model
+9 -7
View File
@@ -24,13 +24,14 @@ class INRPlus2(nn.Module):
def __init__(self, in_feats: int, layers: int, layer_size: int, n_fourier_feats: int, scales: float,
dropout: Optional[float] = 0.5, bn=False, *args, **kwargs):
super().__init__()
self.features = nn.Linear(in_feats, layer_size) if n_fourier_feats == 0 \
self.n_fourier_feats = n_fourier_feats
self.features = nn.Linear(in_feats, in_feats) if n_fourier_feats == 0 \
else GaussianFourierFeatureTransform(in_feats, n_fourier_feats, scales)
in_size = layer_size if n_fourier_feats == 0 \
in_size = in_feats if n_fourier_feats == 0 \
else n_fourier_feats+in_feats
# import pdb; pdb.set_trace()
self.layers = CausalInceptionTimePlus(
in_size-1, layer_size, seq_len=None, nf=layer_size, depth=layers,
in_size, layer_size, seq_len=None, nf=layer_size, depth=layers,
flatten=False, concat_pool=False, fc_dropout=dropout, conv_dropout=0.05, bn=bn, y_range=None, custom_head=custom_head, ks=[139, 19, 3], dilation=2, *args, **kwargs
)
# layers = [INRPlusLayer(in_size, layer_size, dropout=dropout)] + \
@@ -38,6 +39,7 @@ class INRPlus2(nn.Module):
# self.layers = nn.Sequential(*layers)
def forward(self, x: Tensor) -> Tensor:
x = self.features(x)
# import pdb; pdb.set_trace()
return self.layers(x.permute((0, 2, 1))).permute((0, 2, 1))
f = self.features(x)
if self.n_fourier_feats>0:
f = torch.concat([f, x], -1)
return self.layers(f.permute((0, 2, 1))).permute((0, 2, 1))
+464
View File
@@ -0,0 +1,464 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "b1e031e3",
"metadata": {},
"source": [
"- [x] try just one predictor\n",
" - [ ] multi input, single output\n",
"- [x] comparem ulti\n",
"- losses:\n",
" - try logp? nah\n",
" - mae?\n",
"- [x] make my own csv with 5m data (maybe 10k rows)\n",
"- [ ] backtest?"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7f9e3d73",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T08:32:51.601212Z",
"start_time": "2022-11-22T08:32:51.590095Z"
}
},
"outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")\n",
"\n",
"# autoreload import your package\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4e09086b",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T08:32:52.773863Z",
"start_time": "2022-11-22T08:32:51.602648Z"
},
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"import os\n",
"from os.path import join\n",
"import math\n",
"import logging\n",
"from typing import Callable, Optional, Union, Dict, Tuple\n",
"\n",
"from matplotlib import pyplot as plt\n",
"from pathlib import Path\n",
"import matplotlib.colors as mcolors\n",
"\n",
"import gin\n",
"from fire import Fire\n",
"import numpy as np\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from torch import optim\n",
"from torch import nn\n",
"\n",
"from experiments.base import Experiment\n",
"from data.datasets import ForecastDataset\n",
"from models import get_model\n",
"from utils.checkpoint import Checkpoint\n",
"from utils.ops import default_device, to_tensor\n",
"from utils.losses import get_loss_fn\n",
"from utils.metrics import calc_metrics\n",
"\n",
"from experiments.forecast import get_data\n",
"gin.enter_interactive_mode()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "66d7f095",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-19T23:55:34.939075Z",
"start_time": "2022-11-19T23:55:34.820277Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "d4df5270",
"metadata": {},
"source": [
"# auto"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "04499bef",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T08:32:52.793863Z",
"start_time": "2022-11-22T08:32:52.776011Z"
}
},
"outputs": [],
"source": [
"\n",
"\n",
"def plot(model_name=\"deeptime\", save_path=Path(\"storage/experiments/Exchange/96M/repeat=0\"), i=200, title=None, plot=True):\n",
"\n",
" gin.clear_config()\n",
" gin.parse_config(open(save_path/\"config.gin\"))\n",
"\n",
" train_set, train_loader = get_data(flag='train', batch_size=2)\n",
"\n",
" model = get_model(model_name,\n",
" dim_size=train_set.data_x.shape[1],\n",
" datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
" model.load_state_dict(torch.load(save_path/'model.pth'))\n",
" model = model.eval()\n",
"\n",
"\n",
" b = train_set[i]\n",
" b = [bb[None, :] for bb in b]\n",
" x, y, x_time, y_time = map(to_tensor, b)\n",
" with torch.no_grad():\n",
" forecast = model(x, x_time, y_time)\n",
"\n",
" if title is None:\n",
" title = str(save_path).split('/')[-3:]\n",
" title = \"-\".join(title)\n",
" \n",
" colors = list(mcolors.BASE_COLORS.keys())\n",
" l = x.shape[1]\n",
" forecast2 = forecast[0].detach().cpu().numpy()\n",
" x2 = x[0].cpu()\n",
" y2 = y[0].cpu()\n",
" l2 = y.shape[1]\n",
" i_past = list(range(l))\n",
" i_future = list(range(l, l+l2))\n",
" \n",
" if plot:\n",
" plt.title(title)\n",
" for i in range(x.shape[-1]):\n",
" plt.plot(i_past, x2[:, i], c=colors[i])\n",
" for i in range(x.shape[-1]):\n",
" plt.plot(i_future, y2[:, i], c=colors[i])\n",
" for i in range(x.shape[-1]):\n",
" plt.plot(i_future, forecast2[:, i], c=colors[i], linestyle='--')\n",
" return x2, y2, forecast2, i_past, i_future\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "768530be",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T08:32:52.819217Z",
"start_time": "2022-11-22T08:32:52.794951Z"
}
},
"outputs": [],
"source": [
"\n",
"\n",
"def plot_multi(save_paths=[Path(\"storage/experiments/Exchange/96M/repeat=0\")], i=200, title=None, plot=True):\n",
" for j in range(len(save_paths)):\n",
" save_path = save_paths[j]\n",
"\n",
" gin.clear_config()\n",
" gin.parse_config(open(save_path/\"config.gin\"))\n",
" model_name = gin.query_parameter(\"instance.model_type\")\n",
"\n",
" train_set, train_loader = get_data(flag='test', batch_size=3)\n",
"\n",
" model = get_model(model_name,\n",
" dim_size=train_set.data_x.shape[1],\n",
" datetime_feats=train_set.timestamps.shape[-1]).to(default_device())\n",
" model.load_state_dict(torch.load(save_path/'model.pth'))\n",
" model = model.eval()\n",
"\n",
"\n",
" b = train_set[i]\n",
" b = [bb[None, :] for bb in b]\n",
" \n",
" b = next(iter(train_loader))\n",
" print([s.shape for s in b])\n",
" \n",
" x, y, x_time, y_time = map(to_tensor, b)\n",
"# print(b)\n",
" with torch.no_grad():\n",
" forecast = model(x, x_time, y_time)\n",
" \n",
" colors = list(mcolors.BASE_COLORS.keys())\n",
" l = x.shape[1]\n",
" forecast2 = forecast[0].detach().cpu().numpy()\n",
" x2 = x[0].cpu()\n",
" y2 = y[0].cpu()\n",
" l2 = y.shape[1]\n",
" i_past = list(range(l))\n",
" i_future = list(range(l, l+l2))\n",
"\n",
" if plot:\n",
" plt.plot(i_past, x2[:, 0], c=colors[0], label=f\"past\")\n",
" plt.plot(i_future, y2[:, 0], c=colors[0], label=\"future true\", alpha=0.3)\n",
" \n",
" mtitle = str(save_path).split('/')[-2:-1]\n",
" mtitle = \"-\".join(mtitle)\n",
" plt.plot(i_future, forecast2[:, 0], c=colors[j], linestyle='--', label=f\"{mtitle}\")\n",
" plt.legend()\n",
" plt.title(title)\n",
" return x2, y2, forecast2, i_past, i_future\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "739ee5e3",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T08:32:52.843086Z",
"start_time": "2022-11-22T08:32:52.820267Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Path('storage/experiments/Stocks/96M/repeat=0/_SUCCESS'), Path('storage/experiments/Stocks/96S/repeat=0/_SUCCESS'), Path('storage/experiments/Stocks/96Splus/repeat=0/_SUCCESS'), Path('storage/experiments/Stocks/96Splusshort/repeat=0/_SUCCESS'), Path('storage/experiments/Stocks/96Sshort/repeat=0/_SUCCESS')]\n"
]
}
],
"source": [
"# list the models we have run...\n",
"m=sorted(Path(\"storage/experiments/Stocks\").glob(\"**/_SUCCESS\"))\n",
"print(m)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0529c377",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T08:32:52.857950Z",
"start_time": "2022-11-22T08:32:52.844134Z"
}
},
"outputs": [],
"source": [
"save_path = Path('storage/experiments/Stocks/96M2S/repeat=0')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2ce3c75a",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T08:32:52.879200Z",
"start_time": "2022-11-22T08:32:52.859136Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"'deeptime3'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gin.clear_config()\n",
"config_path = save_path/\"config.gin\"\n",
"gin.parse_config(open(config_path))\n",
"model_name = gin.query_parameter(\"instance.model_type\")\n",
"model_name"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ef1989b7",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T08:32:52.904019Z",
"start_time": "2022-11-22T08:32:52.880345Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<experiments.forecast.ForecastExperiment at 0x7f4044fc3fa0>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from experiments.forecast import ForecastExperiment\n",
"exp = ForecastExperiment(config_path=config_path)\n",
"exp"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "915e5648",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T08:32:56.783272Z",
"start_time": "2022-11-22T08:32:52.905696Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"receptive field [114 72 12]=[38 18 2]*[[1 1 1]\n",
" [1 1 1]\n",
" [1 2 4]]\n",
"129 in_feats\n",
"receptive field [690 378 242]=[138 18 2]*[[ 1 1 1]\n",
" [ 1 2 4]\n",
" [ 1 4 16]\n",
" [ 1 6 36]\n",
" [ 1 8 64]]\n",
"torch.Size([256, 96, 129])\n",
"torch.Size([256, 96, 129])\n"
]
},
{
"ename": "RuntimeError",
"evalue": "CUDA out of memory. Tried to allocate 24.00 MiB (GPU 0; 10.74 GiB total capacity; 8.00 GiB already allocated; 50.12 MiB free; 8.16 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n In call to configurable 'train' (<function train at 0x7f4045284ee0>)\n In call to configurable 'instance' (<function ForecastExperiment.instance at 0x7f4045284550>)\n In call to configurable 'run' (<function Experiment.run at 0x7f4092a49550>)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mexp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/gin/config.py:1605\u001b[0m, in \u001b[0;36m_make_gin_wrapper.<locals>.gin_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1603\u001b[0m scope_info \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m in scope \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(scope_str) \u001b[38;5;28;01mif\u001b[39;00m scope_str \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 1604\u001b[0m err_str \u001b[38;5;241m=\u001b[39m err_str\u001b[38;5;241m.\u001b[39mformat(name, fn_or_cls, scope_info)\n\u001b[0;32m-> 1605\u001b[0m \u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maugment_exception_message_and_reraise\u001b[49m\u001b[43m(\u001b[49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merr_str\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/gin/utils.py:41\u001b[0m, in \u001b[0;36maugment_exception_message_and_reraise\u001b[0;34m(exception, message)\u001b[0m\n\u001b[1;32m 39\u001b[0m proxy \u001b[38;5;241m=\u001b[39m ExceptionProxy()\n\u001b[1;32m 40\u001b[0m ExceptionProxy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtype\u001b[39m(exception)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\n\u001b[0;32m---> 41\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m proxy\u001b[38;5;241m.\u001b[39mwith_traceback(exception\u001b[38;5;241m.\u001b[39m__traceback__) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28mNone\u001b[39m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/gin/config.py:1582\u001b[0m, in \u001b[0;36m_make_gin_wrapper.<locals>.gin_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1579\u001b[0m new_kwargs\u001b[38;5;241m.\u001b[39mupdate(kwargs)\n\u001b[1;32m 1581\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1582\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnew_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnew_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1583\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e: \u001b[38;5;66;03m# pylint: disable=broad-except\u001b[39;00m\n\u001b[1;32m 1584\u001b[0m err_str \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/experiments/base.py:96\u001b[0m, in \u001b[0;36mExperiment.run\u001b[0;34m(self, timer)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 95\u001b[0m Path(running_flag)\u001b[38;5;241m.\u001b[39munlink()\n\u001b[0;32m---> 96\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m:\n\u001b[1;32m 98\u001b[0m Path(running_flag)\u001b[38;5;241m.\u001b[39munlink()\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/experiments/base.py:93\u001b[0m, in \u001b[0;36mExperiment.run\u001b[0;34m(self, timer)\u001b[0m\n\u001b[1;32m 90\u001b[0m Path(running_flag)\u001b[38;5;241m.\u001b[39mtouch()\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 93\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minstance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 95\u001b[0m Path(running_flag)\u001b[38;5;241m.\u001b[39munlink()\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/gin/config.py:1605\u001b[0m, in \u001b[0;36m_make_gin_wrapper.<locals>.gin_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1603\u001b[0m scope_info \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m in scope \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(scope_str) \u001b[38;5;28;01mif\u001b[39;00m scope_str \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 1604\u001b[0m err_str \u001b[38;5;241m=\u001b[39m err_str\u001b[38;5;241m.\u001b[39mformat(name, fn_or_cls, scope_info)\n\u001b[0;32m-> 1605\u001b[0m \u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maugment_exception_message_and_reraise\u001b[49m\u001b[43m(\u001b[49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merr_str\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/gin/utils.py:41\u001b[0m, in \u001b[0;36maugment_exception_message_and_reraise\u001b[0;34m(exception, message)\u001b[0m\n\u001b[1;32m 39\u001b[0m proxy \u001b[38;5;241m=\u001b[39m ExceptionProxy()\n\u001b[1;32m 40\u001b[0m ExceptionProxy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtype\u001b[39m(exception)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\n\u001b[0;32m---> 41\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m proxy\u001b[38;5;241m.\u001b[39mwith_traceback(exception\u001b[38;5;241m.\u001b[39m__traceback__) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28mNone\u001b[39m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/gin/config.py:1582\u001b[0m, in \u001b[0;36m_make_gin_wrapper.<locals>.gin_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1579\u001b[0m new_kwargs\u001b[38;5;241m.\u001b[39mupdate(kwargs)\n\u001b[1;32m 1581\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1582\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnew_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnew_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1583\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e: \u001b[38;5;66;03m# pylint: disable=broad-except\u001b[39;00m\n\u001b[1;32m 1584\u001b[0m err_str \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/experiments/forecast.py:40\u001b[0m, in \u001b[0;36mForecastExperiment.instance\u001b[0;34m(self, model_type, save_vals)\u001b[0m\n\u001b[1;32m 37\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m Checkpoint(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mroot)\n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# train forecasting task\u001b[39;00m\n\u001b[0;32m---> 40\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheckpoint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_loader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# testing\u001b[39;00m\n\u001b[1;32m 43\u001b[0m val_metrics \u001b[38;5;241m=\u001b[39m validate(model, loader\u001b[38;5;241m=\u001b[39mval_loader, report_metrics\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/gin/config.py:1605\u001b[0m, in \u001b[0;36m_make_gin_wrapper.<locals>.gin_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1603\u001b[0m scope_info \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m in scope \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(scope_str) \u001b[38;5;28;01mif\u001b[39;00m scope_str \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 1604\u001b[0m err_str \u001b[38;5;241m=\u001b[39m err_str\u001b[38;5;241m.\u001b[39mformat(name, fn_or_cls, scope_info)\n\u001b[0;32m-> 1605\u001b[0m \u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maugment_exception_message_and_reraise\u001b[49m\u001b[43m(\u001b[49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merr_str\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/gin/utils.py:41\u001b[0m, in \u001b[0;36maugment_exception_message_and_reraise\u001b[0;34m(exception, message)\u001b[0m\n\u001b[1;32m 39\u001b[0m proxy \u001b[38;5;241m=\u001b[39m ExceptionProxy()\n\u001b[1;32m 40\u001b[0m ExceptionProxy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtype\u001b[39m(exception)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\n\u001b[0;32m---> 41\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m proxy\u001b[38;5;241m.\u001b[39mwith_traceback(exception\u001b[38;5;241m.\u001b[39m__traceback__) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28mNone\u001b[39m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/gin/config.py:1582\u001b[0m, in \u001b[0;36m_make_gin_wrapper.<locals>.gin_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1579\u001b[0m new_kwargs\u001b[38;5;241m.\u001b[39mupdate(kwargs)\n\u001b[1;32m 1581\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1582\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnew_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnew_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1583\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e: \u001b[38;5;66;03m# pylint: disable=broad-except\u001b[39;00m\n\u001b[1;32m 1584\u001b[0m err_str \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/experiments/forecast.py:144\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, checkpoint, train_loader, val_loader, test_loader, loss_name, epochs, clip)\u001b[0m\n\u001b[1;32m 142\u001b[0m data2 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmap\u001b[39m(to_tensor, data)\n\u001b[1;32m 143\u001b[0m context_past_x, context_y, query_past_x, query_y, context_time, query_time \u001b[38;5;241m=\u001b[39m data2\n\u001b[0;32m--> 144\u001b[0m forecast \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcontext_past_x\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext_y\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery_past_x\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext_time\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery_time\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(forecast, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 147\u001b[0m \u001b[38;5;66;03m# for models which require reconstruction + forecast loss\u001b[39;00m\n\u001b[1;32m 148\u001b[0m loss \u001b[38;5;241m=\u001b[39m training_loss_fn(forecast[\u001b[38;5;241m0\u001b[39m], context_y) \u001b[38;5;241m+\u001b[39m \\\n\u001b[1;32m 149\u001b[0m training_loss_fn(forecast[\u001b[38;5;241m1\u001b[39m], query_y)\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/models/DeepTIMe3.py:77\u001b[0m, in \u001b[0;36mDeepTIMe3.forward\u001b[0;34m(self, context_past_x, context_y, query_past_x, context_time, query_time)\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, context_past_x, context_y, query_past_x, context_time, query_time) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m 76\u001b[0m context_reprs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mencode_and_decode(context_past_x, context_time)\n\u001b[0;32m---> 77\u001b[0m query_reprs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode_and_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mquery_past_x\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery_time\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moffset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcontext_reprs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 79\u001b[0m w, b \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madaptive_weights(context_reprs, context_y)\n\u001b[1;32m 80\u001b[0m preds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforecast(query_reprs, w, b)\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/models/DeepTIMe3.py:71\u001b[0m, in \u001b[0;36mDeepTIMe3.encode_and_decode\u001b[0;34m(self, past_x, time, offset)\u001b[0m\n\u001b[1;32m 68\u001b[0m context_input \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([encoded_x, coords, time], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 70\u001b[0m \u001b[38;5;28mprint\u001b[39m(context_input\u001b[38;5;241m.\u001b[39mshape)\n\u001b[0;32m---> 71\u001b[0m context_repr \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minr\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcontext_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m context_repr\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/models/modules/inrplus2.py:45\u001b[0m, in \u001b[0;36mINRPlus2.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_fourier_feats\u001b[38;5;241m>\u001b[39m\u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 44\u001b[0m f \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mconcat([f, x], \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 45\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayers\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpermute\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mpermute((\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m))\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/container.py:141\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 141\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/container.py:141\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 141\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/models/modules/causalinception.py:92\u001b[0m, in \u001b[0;36mInceptionBlockPlus.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdepth):\n\u001b[1;32m 91\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkeep_prob[i] \u001b[38;5;241m>\u001b[39m random\u001b[38;5;241m.\u001b[39mrandom() \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining:\n\u001b[0;32m---> 92\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minception\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresidual \u001b[38;5;129;01mand\u001b[39;00m i \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m3\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 94\u001b[0m res \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact[i\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m3\u001b[39m](\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madd(x, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshortcut[i\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m3\u001b[39m](res)))\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/models/modules/causalinception.py:52\u001b[0m, in \u001b[0;36mInceptionModulePlus.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 50\u001b[0m input_tensor \u001b[38;5;241m=\u001b[39m x\n\u001b[1;32m 51\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbottleneck(x)\n\u001b[0;32m---> 52\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconcat([l(x) \u001b[38;5;28;01mfor\u001b[39;00m l \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconvs] \u001b[38;5;241m+\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmp_conv\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_tensor\u001b[49m\u001b[43m)\u001b[49m])\n\u001b[1;32m 53\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm(x)\n\u001b[1;32m 54\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv_dropout(x)\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/container.py:141\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 141\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/container.py:141\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 141\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/tsai/models/layers.py:148\u001b[0m, in \u001b[0;36mCausalConv1d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m--> 148\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mCausalConv1d\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpad\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__padding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/conv.py:301\u001b[0m, in \u001b[0;36mConv1d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 300\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 301\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/nn/modules/conv.py:297\u001b[0m, in \u001b[0;36mConv1d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 293\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 294\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv1d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m 295\u001b[0m weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m 296\u001b[0m _single(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 297\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv1d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 298\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/fastai/torch_core.py:378\u001b[0m, in \u001b[0;36mTensorBase.__torch_function__\u001b[0;34m(cls, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 376\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mdebug \u001b[38;5;129;01mand\u001b[39;00m func\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m__str__\u001b[39m\u001b[38;5;124m'\u001b[39m,\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m__repr__\u001b[39m\u001b[38;5;124m'\u001b[39m): \u001b[38;5;28mprint\u001b[39m(func, types, args, kwargs)\n\u001b[1;32m 377\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _torch_handled(args, \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_opt, func): types \u001b[38;5;241m=\u001b[39m (torch\u001b[38;5;241m.\u001b[39mTensor,)\n\u001b[0;32m--> 378\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__torch_function__\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mifnone\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 379\u001b[0m dict_objs \u001b[38;5;241m=\u001b[39m _find_args(args) \u001b[38;5;28;01mif\u001b[39;00m args \u001b[38;5;28;01melse\u001b[39;00m _find_args(\u001b[38;5;28mlist\u001b[39m(kwargs\u001b[38;5;241m.\u001b[39mvalues()))\n\u001b[1;32m 380\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28missubclass\u001b[39m(\u001b[38;5;28mtype\u001b[39m(res),TensorBase) \u001b[38;5;129;01mand\u001b[39;00m dict_objs: res\u001b[38;5;241m.\u001b[39mset_meta(dict_objs[\u001b[38;5;241m0\u001b[39m],as_copy\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/_tensor.py:1051\u001b[0m, in \u001b[0;36mTensor.__torch_function__\u001b[0;34m(cls, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 1048\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mNotImplemented\u001b[39m\n\u001b[1;32m 1050\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _C\u001b[38;5;241m.\u001b[39mDisableTorchFunction():\n\u001b[0;32m-> 1051\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m func \u001b[38;5;129;01min\u001b[39;00m get_default_nowrap_functions():\n\u001b[1;32m 1053\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ret\n",
"\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 24.00 MiB (GPU 0; 10.74 GiB total capacity; 8.00 GiB already allocated; 50.12 MiB free; 8.16 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n In call to configurable 'train' (<function train at 0x7f4045284ee0>)\n In call to configurable 'instance' (<function ForecastExperiment.instance at 0x7f4045284550>)\n In call to configurable 'run' (<function Experiment.run at 0x7f4092a49550>)"
]
}
],
"source": [
"exp.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "66b15b6a",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T08:32:56.784632Z",
"start_time": "2022-11-22T08:32:56.784625Z"
}
},
"outputs": [],
"source": [
"%debug"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "374d6ca6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "deeptime",
"language": "python",
"name": "deeptime"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}