mirror of
https://github.com/wassname/ETSformer.git
synced 2026-06-27 16:59:02 +08:00
running in nb
This commit is contained in:
+89
@@ -0,0 +1,89 @@
|
||||
/notebooks/checkpoints/
|
||||
/checkpoints/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
env/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# DotEnv configuration
|
||||
.env
|
||||
|
||||
# Database
|
||||
*.db
|
||||
*.rdb
|
||||
|
||||
# Pycharm
|
||||
.idea
|
||||
|
||||
# VS Code
|
||||
.vscode/
|
||||
|
||||
# Spyder
|
||||
.spyproject/
|
||||
|
||||
# Jupyter NB Checkpoints
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# exclude data from source control by default
|
||||
/data/
|
||||
|
||||
# Mac OS-specific storage files
|
||||
.DS_Store
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
*.swo
|
||||
@@ -1,3 +1,6 @@
|
||||
Fork info: trying out and plotting the results on stock data.----------
|
||||
|
||||
|
||||
# ETSformer: Exponential Smoothing Transformers for Time-series Forecasting
|
||||
|
||||
<p align="center">
|
||||
@@ -8,6 +11,7 @@
|
||||
|
||||
Official PyTorch code repository for the [ETSformer paper](https://arxiv.org/abs/2202.01381). Check out our [blog post](https://blog.salesforceairesearch.com/etsformer-time-series-forecasting/)!
|
||||
|
||||
|
||||
* ETSformer is a novel time-series Transformer architecture which exploits the principle of exponential smoothing in improving
|
||||
Transformers for timeseries forecasting.
|
||||
* ETSformer is inspired by the classical exponential smoothing methods in
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
|
||||
# Ignore everything in this directory
|
||||
*
|
||||
# Except this file
|
||||
!.gitignore
|
||||
|
||||
|
||||
@@ -1 +1,7 @@
|
||||
`conda activate deeptime`
|
||||
|
||||
|
||||
TODO:
|
||||
- [ ] run on stocks data
|
||||
- [ ] multivariate in, univariate out
|
||||
- [ ] graph
|
||||
|
||||
@@ -0,0 +1,433 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "cbdce1e4",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T04:13:58.128994Z",
|
||||
"start_time": "2022-11-28T04:13:58.127086Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.sys.path.append('..')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "1640aab2",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T04:13:58.305698Z",
|
||||
"start_time": "2022-11-28T04:13:58.287564Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import warnings\n",
|
||||
"warnings.simplefilter(\"ignore\")\n",
|
||||
"\n",
|
||||
"# autoreload import your package\n",
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "649ddfa9",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T04:14:16.865244Z",
|
||||
"start_time": "2022-11-28T04:14:16.858147Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from loguru import logger\n",
|
||||
"logger.remove()\n",
|
||||
"logger.add(os.sys.stdout, level=\"ERROR\", colorize=True, format=\"<level>{time} | {message}</level>\")\n",
|
||||
"# import_dir(ta_dir, verbose=False)\n",
|
||||
"warnings.simplefilter(\"ignore\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bc96a9a7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Args"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"id": "7fcd9d4f",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T04:18:43.064910Z",
|
||||
"start_time": "2022-11-28T04:18:43.050776Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from run import set_seed, get_args, Exp_Main"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"id": "0602682d",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T05:00:16.840051Z",
|
||||
"start_time": "2022-11-28T05:00:16.816918Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Args in experiment:\n",
|
||||
"Namespace(K=0, activation='sigmoid', batch_size=32, c_out=12, checkpoints='./checkpoints/', d_ff=2048, d_layers=2, d_model=512, damping_learning_rate=0, data='custom', data_path='OXY_2019.csv.gz', dec_in=12, des=\"'Exp'\", devices='0,1,2,3', dropout=0.2, e_layers=2, embed='timeF', enc_in=12, features='M', freq='h', gpu=0, itr=1, label_len=0, learning_rate=0.001, lradj='exponential_with_warmup', min_lr=1e-30, model='ETSformer', model_id='Exchange', n_heads=8, num_workers=10, optim='adam', output_attention=False, patience=5, pred_len=48, root_path='../dataset/stocks/', seq_len=128, smoothing_learning_rate=0, std=0.2, target='RSMKs_18_144_72_2ref_2ref', train_epochs=15, use_gpu=True, use_multi_gpu=False, warmup_epochs=3)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Namespace(K=0, activation='sigmoid', batch_size=32, c_out=12, checkpoints='./checkpoints/', d_ff=2048, d_layers=2, d_model=512, damping_learning_rate=0, data='custom', data_path='OXY_2019.csv.gz', dec_in=12, des=\"'Exp'\", devices='0,1,2,3', dropout=0.2, e_layers=2, embed='timeF', enc_in=12, features='M', freq='h', gpu=0, itr=1, label_len=0, learning_rate=0.001, lradj='exponential_with_warmup', min_lr=1e-30, model='ETSformer', model_id='Exchange', n_heads=8, num_workers=10, optim='adam', output_attention=False, patience=5, pred_len=48, root_path='../dataset/stocks/', seq_len=128, smoothing_learning_rate=0, std=0.2, target='RSMKs_18_144_72_2ref_2ref', train_epochs=15, use_gpu=True, use_multi_gpu=False, warmup_epochs=3)"
|
||||
]
|
||||
},
|
||||
"execution_count": 51,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# mimic cli args to avoid code duplication\n",
|
||||
"argv = \"\"\"python -u run.py \\\n",
|
||||
" --root_path ../dataset/stocks/ \\\n",
|
||||
" --data_path OXY_2019.csv.gz \\\n",
|
||||
" --checkpoints ./checkpoints/ \\\n",
|
||||
" --model_id Exchange \\\n",
|
||||
" --model ETSformer \\\n",
|
||||
" --data custom \\\n",
|
||||
" --features M \\\n",
|
||||
" --seq_len 128 \\\n",
|
||||
" --pred_len 48 \\\n",
|
||||
" --e_layers 2 \\\n",
|
||||
" --d_layers 2 \\\n",
|
||||
" --enc_in 12 \\\n",
|
||||
" --dec_in 12 \\\n",
|
||||
" --c_out 12 \\\n",
|
||||
" --des 'Exp' \\\n",
|
||||
" --K 0 \\\n",
|
||||
" --learning_rate 1e-3 \\\n",
|
||||
" --target RSMKs_18_144_72_2ref_2ref \\\n",
|
||||
" --itr 1\n",
|
||||
"\"\"\"\n",
|
||||
"argv = argv.replace(\"\\\\n\", \"\").split()[3:]\n",
|
||||
"args = get_args(argv)\n",
|
||||
"args"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7367a3cb",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T04:24:49.799842Z",
|
||||
"start_time": "2022-11-28T04:24:49.783939Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "93ed3ca8",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-28T05:00:17.097Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0\n",
|
||||
"Use GPU: cuda:0\n",
|
||||
">>>>>>>start training : Exchange_ETSformer_custom_ftM_sl128_pl48_dm512_nh8_el2_dl2_df2048_K0_lr0.001_'Exp'_0>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
|
||||
"train 37203\n",
|
||||
"val 5294\n",
|
||||
"test 10632\n",
|
||||
"\titers: 100, epoch: 1 | loss: 0.4932106\n",
|
||||
"\tspeed: 0.0565s/iter; left time: 978.9201s\n",
|
||||
"\titers: 200, epoch: 1 | loss: 0.4972929\n",
|
||||
"\tspeed: 0.0485s/iter; left time: 835.0221s\n",
|
||||
"\titers: 300, epoch: 1 | loss: 0.4357589\n",
|
||||
"\tspeed: 0.0476s/iter; left time: 815.1622s\n",
|
||||
"\titers: 400, epoch: 1 | loss: 0.5983973\n",
|
||||
"\tspeed: 0.0477s/iter; left time: 811.9901s\n",
|
||||
"\titers: 500, epoch: 1 | loss: 0.6287826\n",
|
||||
"\tspeed: 0.0484s/iter; left time: 820.1333s\n",
|
||||
"\titers: 600, epoch: 1 | loss: 0.4684563\n",
|
||||
"\tspeed: 0.0488s/iter; left time: 821.1888s\n",
|
||||
"\titers: 700, epoch: 1 | loss: 1.0543735\n",
|
||||
"\tspeed: 0.0484s/iter; left time: 810.2026s\n",
|
||||
"\titers: 800, epoch: 1 | loss: 0.7420598\n",
|
||||
"\tspeed: 0.0474s/iter; left time: 788.7177s\n",
|
||||
"\titers: 900, epoch: 1 | loss: 0.5108218\n",
|
||||
"\tspeed: 0.0474s/iter; left time: 784.3583s\n",
|
||||
"\titers: 1000, epoch: 1 | loss: 0.5239242\n",
|
||||
"\tspeed: 0.0505s/iter; left time: 830.0246s\n",
|
||||
"\titers: 1100, epoch: 1 | loss: 0.8626059\n",
|
||||
"\tspeed: 0.0468s/iter; left time: 764.9884s\n",
|
||||
"Epoch: 1 cost time: 56.83964204788208\n",
|
||||
"Epoch: 1, Steps: 1162 | Train Loss: 0.6849515 Vali Loss: 0.3407791 Test Loss: 0.4381559\n",
|
||||
"Validation loss decreased (inf --> 0.340779). Saving model ...\n",
|
||||
"Updating learning rate to 0.00025\n",
|
||||
"\titers: 100, epoch: 2 | loss: 0.9436644\n",
|
||||
"\tspeed: 0.1952s/iter; left time: 3155.9649s\n",
|
||||
"\titers: 200, epoch: 2 | loss: 0.5133687\n",
|
||||
"\tspeed: 0.0487s/iter; left time: 783.2851s\n",
|
||||
"\titers: 300, epoch: 2 | loss: 0.6383107\n",
|
||||
"\tspeed: 0.0482s/iter; left time: 770.3673s\n",
|
||||
"\titers: 400, epoch: 2 | loss: 1.1335649\n",
|
||||
"\tspeed: 0.0492s/iter; left time: 781.1069s\n",
|
||||
"\titers: 500, epoch: 2 | loss: 0.5905243\n",
|
||||
"\tspeed: 0.0467s/iter; left time: 736.0656s\n",
|
||||
"\titers: 600, epoch: 2 | loss: 0.7256416\n",
|
||||
"\tspeed: 0.0474s/iter; left time: 742.6376s\n",
|
||||
"\titers: 700, epoch: 2 | loss: 0.5506995\n",
|
||||
"\tspeed: 0.0478s/iter; left time: 744.2372s\n",
|
||||
"\titers: 800, epoch: 2 | loss: 0.6441094\n",
|
||||
"\tspeed: 0.0487s/iter; left time: 753.5581s\n",
|
||||
"\titers: 900, epoch: 2 | loss: 0.3798619\n",
|
||||
"\tspeed: 0.0487s/iter; left time: 748.4489s\n",
|
||||
"\titers: 1000, epoch: 2 | loss: 0.4259658\n",
|
||||
"\tspeed: 0.0479s/iter; left time: 731.2631s\n",
|
||||
"\titers: 1100, epoch: 2 | loss: 0.5248989\n",
|
||||
"\tspeed: 0.0491s/iter; left time: 745.1800s\n",
|
||||
"Epoch: 2 cost time: 56.406978130340576\n",
|
||||
"Epoch: 2, Steps: 1162 | Train Loss: 0.6076249 Vali Loss: 0.3033751 Test Loss: 0.3955314\n",
|
||||
"Validation loss decreased (0.340779 --> 0.303375). Saving model ...\n",
|
||||
"Updating learning rate to 0.0005\n",
|
||||
"\titers: 100, epoch: 3 | loss: 0.4865443\n",
|
||||
"\tspeed: 0.2002s/iter; left time: 3003.6628s\n",
|
||||
"\titers: 200, epoch: 3 | loss: 0.5826065\n",
|
||||
"\tspeed: 0.0493s/iter; left time: 734.4833s\n",
|
||||
"\titers: 300, epoch: 3 | loss: 1.2626467\n",
|
||||
"\tspeed: 0.0480s/iter; left time: 710.5487s\n",
|
||||
"\titers: 400, epoch: 3 | loss: 0.4687002\n",
|
||||
"\tspeed: 0.0472s/iter; left time: 694.2853s\n",
|
||||
"\titers: 500, epoch: 3 | loss: 0.3959515\n",
|
||||
"\tspeed: 0.0474s/iter; left time: 693.0790s\n",
|
||||
"\titers: 600, epoch: 3 | loss: 0.4113244\n",
|
||||
"\tspeed: 0.0490s/iter; left time: 710.3280s\n",
|
||||
"\titers: 700, epoch: 3 | loss: 0.4114936\n",
|
||||
"\tspeed: 0.0476s/iter; left time: 685.9448s\n",
|
||||
"\titers: 800, epoch: 3 | loss: 0.5290235\n",
|
||||
"\tspeed: 0.0475s/iter; left time: 679.7407s\n",
|
||||
"\titers: 900, epoch: 3 | loss: 0.4256401\n",
|
||||
"\tspeed: 0.0478s/iter; left time: 678.8945s\n",
|
||||
"\titers: 1000, epoch: 3 | loss: 0.6996748\n",
|
||||
"\tspeed: 0.0491s/iter; left time: 692.3603s\n",
|
||||
"\titers: 1100, epoch: 3 | loss: 0.7242295\n",
|
||||
"\tspeed: 0.0470s/iter; left time: 658.2708s\n",
|
||||
"Epoch: 3 cost time: 56.2166702747345\n",
|
||||
"Epoch: 3, Steps: 1162 | Train Loss: 0.5938104 Vali Loss: 0.3029587 Test Loss: 0.3915149\n",
|
||||
"Validation loss decreased (0.303375 --> 0.302959). Saving model ...\n",
|
||||
"Updating learning rate to 0.00075\n",
|
||||
"\titers: 100, epoch: 4 | loss: 0.4237639\n",
|
||||
"\tspeed: 0.1944s/iter; left time: 2691.7509s\n",
|
||||
"\titers: 200, epoch: 4 | loss: 0.3838455\n",
|
||||
"\tspeed: 0.0469s/iter; left time: 644.7742s\n",
|
||||
"\titers: 300, epoch: 4 | loss: 0.5273813\n",
|
||||
"\tspeed: 0.0465s/iter; left time: 635.0866s\n",
|
||||
"\titers: 400, epoch: 4 | loss: 1.2923048\n",
|
||||
"\tspeed: 0.0482s/iter; left time: 652.7872s\n",
|
||||
"\titers: 500, epoch: 4 | loss: 0.6386693\n",
|
||||
"\tspeed: 0.0488s/iter; left time: 656.6116s\n",
|
||||
"\titers: 600, epoch: 4 | loss: 0.5111505\n",
|
||||
"\tspeed: 0.0483s/iter; left time: 644.4140s\n",
|
||||
"\titers: 700, epoch: 4 | loss: 0.5771207\n",
|
||||
"\tspeed: 0.0470s/iter; left time: 622.1141s\n",
|
||||
"\titers: 800, epoch: 4 | loss: 0.6522042\n",
|
||||
"\tspeed: 0.0471s/iter; left time: 619.3240s\n",
|
||||
"\titers: 900, epoch: 4 | loss: 0.3418834\n",
|
||||
"\tspeed: 0.0474s/iter; left time: 618.8682s\n",
|
||||
"\titers: 1000, epoch: 4 | loss: 0.4811396\n",
|
||||
"\tspeed: 0.0470s/iter; left time: 608.9635s\n",
|
||||
"\titers: 1100, epoch: 4 | loss: 0.6203826\n",
|
||||
"\tspeed: 0.0476s/iter; left time: 611.6955s\n",
|
||||
"Epoch: 4 cost time: 55.756075620651245\n",
|
||||
"Epoch: 4, Steps: 1162 | Train Loss: 0.5855654 Vali Loss: 0.3044514 Test Loss: 0.3943208\n",
|
||||
"EarlyStopping counter: 1 out of 5\n",
|
||||
"Updating learning rate to 0.001\n",
|
||||
"\titers: 100, epoch: 5 | loss: 0.7008136\n",
|
||||
"\tspeed: 0.1930s/iter; left time: 2447.9230s\n",
|
||||
"\titers: 200, epoch: 5 | loss: 0.7495702\n",
|
||||
"\tspeed: 0.0493s/iter; left time: 620.2588s\n",
|
||||
"\titers: 300, epoch: 5 | loss: 1.0079670\n",
|
||||
"\tspeed: 0.0487s/iter; left time: 608.5033s\n",
|
||||
"\titers: 400, epoch: 5 | loss: 0.3510064\n",
|
||||
"\tspeed: 0.0482s/iter; left time: 596.2644s\n",
|
||||
"\titers: 500, epoch: 5 | loss: 0.4754097\n",
|
||||
"\tspeed: 0.0474s/iter; left time: 581.7234s\n",
|
||||
"\titers: 600, epoch: 5 | loss: 0.7314481\n",
|
||||
"\tspeed: 0.0493s/iter; left time: 600.1122s\n",
|
||||
"\titers: 700, epoch: 5 | loss: 0.8495453\n",
|
||||
"\tspeed: 0.0483s/iter; left time: 583.0922s\n",
|
||||
"\titers: 800, epoch: 5 | loss: 0.3931953\n",
|
||||
"\tspeed: 0.0474s/iter; left time: 567.6057s\n",
|
||||
"\titers: 900, epoch: 5 | loss: 0.3912098\n",
|
||||
"\tspeed: 0.0488s/iter; left time: 580.1831s\n",
|
||||
"\titers: 1000, epoch: 5 | loss: 0.4786185\n",
|
||||
"\tspeed: 0.0483s/iter; left time: 568.7581s\n",
|
||||
"\titers: 1100, epoch: 5 | loss: 0.6882768\n",
|
||||
"\tspeed: 0.0471s/iter; left time: 550.3884s\n",
|
||||
"Epoch: 5 cost time: 56.28184771537781\n",
|
||||
"Epoch: 5, Steps: 1162 | Train Loss: 0.5815187 Vali Loss: 0.3013998 Test Loss: 0.4001175\n",
|
||||
"Validation loss decreased (0.302959 --> 0.301400). Saving model ...\n",
|
||||
"Updating learning rate to 0.0005\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"Exp = Exp_Main\n",
|
||||
"\n",
|
||||
"for ii in range(args.itr):\n",
|
||||
" print(ii)\n",
|
||||
" set_seed(ii)\n",
|
||||
" # setting record of experiments\n",
|
||||
" setting = '{}_{}_{}_ft{}_sl{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_K{}_lr{}_{}_{}'.format(\n",
|
||||
" args.model_id,\n",
|
||||
" args.model,\n",
|
||||
" args.data,\n",
|
||||
" args.features,\n",
|
||||
" args.seq_len,\n",
|
||||
" args.pred_len,\n",
|
||||
" args.d_model,\n",
|
||||
" args.n_heads,\n",
|
||||
" args.e_layers,\n",
|
||||
" args.d_layers,\n",
|
||||
" args.d_ff,\n",
|
||||
" args.K,\n",
|
||||
" args.learning_rate,\n",
|
||||
" args.des, ii)\n",
|
||||
"\n",
|
||||
"# if os.path.exists(os.path.join(args.checkpoints, setting)):\n",
|
||||
"# print('skipping exists')\n",
|
||||
"# continue\n",
|
||||
"\n",
|
||||
" exp = Exp(args) # set experiments\n",
|
||||
" print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))\n",
|
||||
" exp.train(setting)\n",
|
||||
"\n",
|
||||
" print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))\n",
|
||||
" exp.test(setting, data='val')\n",
|
||||
" exp.test(setting, data='test')\n",
|
||||
"\n",
|
||||
" torch.cuda.empty_cache()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "74745d73",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"id": "2ed3cbf1",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T04:48:15.407951Z",
|
||||
"start_time": "2022-11-28T04:48:13.634492Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"> \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/ETSformer/models/etsformer/encoder.py\u001b[0m(155)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 153 \u001b[0;31m \u001b[0mgrowth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgrowth\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 154 \u001b[0;31m \u001b[0mseason\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseason\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m--> 155 \u001b[0;31m \u001b[0mlevel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 156 \u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlevel\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mseason\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux_values\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgrowth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 157 \u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrearrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'b t h d -> b t (h d)'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n",
|
||||
"ipdb> q\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%debug"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2faf8241",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "609a7018",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "deeptime",
|
||||
"language": "python",
|
||||
"name": "deeptime"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.13"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
"nav_menu": {},
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"skip_h1_title": false,
|
||||
"title_cell": "Table of Contents",
|
||||
"title_sidebar": "Contents",
|
||||
"toc_cell": false,
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": false
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -13,6 +13,8 @@ def set_seed(seed):
|
||||
seed += 1
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def get_args(argv=None):
|
||||
parser = argparse.ArgumentParser(description='ETSformer: Exponential Smoothing Transformers for Time-series Forecasting')
|
||||
|
||||
# basic config
|
||||
@@ -76,7 +78,7 @@ parser.add_argument('--gpu', type=int, default=0, help='gpu')
|
||||
parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
|
||||
parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
|
||||
|
||||
@@ -88,6 +90,11 @@ if args.use_gpu and args.use_multi_gpu:
|
||||
|
||||
print('Args in experiment:')
|
||||
print(args)
|
||||
return args
|
||||
|
||||
if __name__=="__main__":
|
||||
|
||||
args = get_args()
|
||||
|
||||
Exp = Exp_Main
|
||||
|
||||
|
||||
Regular → Executable
Regular → Executable
Regular → Executable
Regular → Executable
Regular → Executable
Regular → Executable
Regular → Executable
Regular → Executable
Regular → Executable
Reference in New Issue
Block a user