running in nb

This commit is contained in:
wassname
2022-11-28 13:06:08 +08:00
parent e86284f233
commit 244766827e
15 changed files with 634 additions and 93 deletions
+89
View File
@@ -0,0 +1,89 @@
/notebooks/checkpoints/
/checkpoints/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
# Translations
*.mo
*.pot
# Django stuff:
*.log
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# DotEnv configuration
.env
# Database
*.db
*.rdb
# Pycharm
.idea
# VS Code
.vscode/
# Spyder
.spyproject/
# Jupyter NB Checkpoints
.ipynb_checkpoints/
# exclude data from source control by default
/data/
# Mac OS-specific storage files
.DS_Store
# vim
*.swp
*.swo
+4
View File
@@ -1,3 +1,6 @@
Fork info: trying out and plotting the results on stock data.----------
# ETSformer: Exponential Smoothing Transformers for Time-series Forecasting
<p align="center">
@@ -8,6 +11,7 @@
Official PyTorch code repository for the [ETSformer paper](https://arxiv.org/abs/2202.01381). Check out our [blog post](https://blog.salesforceairesearch.com/etsformer-time-series-forecasting/)!
* ETSformer is a novel time-series Transformer architecture which exploits the principle of exponential smoothing in improving
Transformers for timeseries forecasting.
* ETSformer is inspired by the classical exponential smoothing methods in
+2
View File
@@ -1,4 +1,6 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore
+6
View File
@@ -1 +1,7 @@
`conda activate deeptime`
TODO:
- [ ] run on stocks data
- [ ] multivariate in, univariate out
- [ ] graph
+433
View File
@@ -0,0 +1,433 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "cbdce1e4",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T04:13:58.128994Z",
"start_time": "2022-11-28T04:13:58.127086Z"
}
},
"outputs": [],
"source": [
"import os\n",
"os.sys.path.append('..')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1640aab2",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T04:13:58.305698Z",
"start_time": "2022-11-28T04:13:58.287564Z"
}
},
"outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")\n",
"\n",
"# autoreload import your package\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "649ddfa9",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T04:14:16.865244Z",
"start_time": "2022-11-28T04:14:16.858147Z"
}
},
"outputs": [],
"source": [
"from loguru import logger\n",
"logger.remove()\n",
"logger.add(os.sys.stdout, level=\"ERROR\", colorize=True, format=\"<level>{time} | {message}</level>\")\n",
"# import_dir(ta_dir, verbose=False)\n",
"warnings.simplefilter(\"ignore\")"
]
},
{
"cell_type": "markdown",
"id": "bc96a9a7",
"metadata": {},
"source": [
"# Args"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "7fcd9d4f",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T04:18:43.064910Z",
"start_time": "2022-11-28T04:18:43.050776Z"
}
},
"outputs": [],
"source": [
"from run import set_seed, get_args, Exp_Main"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "0602682d",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T05:00:16.840051Z",
"start_time": "2022-11-28T05:00:16.816918Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Args in experiment:\n",
"Namespace(K=0, activation='sigmoid', batch_size=32, c_out=12, checkpoints='./checkpoints/', d_ff=2048, d_layers=2, d_model=512, damping_learning_rate=0, data='custom', data_path='OXY_2019.csv.gz', dec_in=12, des=\"'Exp'\", devices='0,1,2,3', dropout=0.2, e_layers=2, embed='timeF', enc_in=12, features='M', freq='h', gpu=0, itr=1, label_len=0, learning_rate=0.001, lradj='exponential_with_warmup', min_lr=1e-30, model='ETSformer', model_id='Exchange', n_heads=8, num_workers=10, optim='adam', output_attention=False, patience=5, pred_len=48, root_path='../dataset/stocks/', seq_len=128, smoothing_learning_rate=0, std=0.2, target='RSMKs_18_144_72_2ref_2ref', train_epochs=15, use_gpu=True, use_multi_gpu=False, warmup_epochs=3)\n"
]
},
{
"data": {
"text/plain": [
"Namespace(K=0, activation='sigmoid', batch_size=32, c_out=12, checkpoints='./checkpoints/', d_ff=2048, d_layers=2, d_model=512, damping_learning_rate=0, data='custom', data_path='OXY_2019.csv.gz', dec_in=12, des=\"'Exp'\", devices='0,1,2,3', dropout=0.2, e_layers=2, embed='timeF', enc_in=12, features='M', freq='h', gpu=0, itr=1, label_len=0, learning_rate=0.001, lradj='exponential_with_warmup', min_lr=1e-30, model='ETSformer', model_id='Exchange', n_heads=8, num_workers=10, optim='adam', output_attention=False, patience=5, pred_len=48, root_path='../dataset/stocks/', seq_len=128, smoothing_learning_rate=0, std=0.2, target='RSMKs_18_144_72_2ref_2ref', train_epochs=15, use_gpu=True, use_multi_gpu=False, warmup_epochs=3)"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# mimic cli args to avoid code duplication\n",
"argv = \"\"\"python -u run.py \\\n",
" --root_path ../dataset/stocks/ \\\n",
" --data_path OXY_2019.csv.gz \\\n",
" --checkpoints ./checkpoints/ \\\n",
" --model_id Exchange \\\n",
" --model ETSformer \\\n",
" --data custom \\\n",
" --features M \\\n",
" --seq_len 128 \\\n",
" --pred_len 48 \\\n",
" --e_layers 2 \\\n",
" --d_layers 2 \\\n",
" --enc_in 12 \\\n",
" --dec_in 12 \\\n",
" --c_out 12 \\\n",
" --des 'Exp' \\\n",
" --K 0 \\\n",
" --learning_rate 1e-3 \\\n",
" --target RSMKs_18_144_72_2ref_2ref \\\n",
" --itr 1\n",
"\"\"\"\n",
"argv = argv.replace(\"\\\\n\", \"\").split()[3:]\n",
"args = get_args(argv)\n",
"args"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7367a3cb",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T04:24:49.799842Z",
"start_time": "2022-11-28T04:24:49.783939Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "93ed3ca8",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T05:00:17.097Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"Use GPU: cuda:0\n",
">>>>>>>start training : Exchange_ETSformer_custom_ftM_sl128_pl48_dm512_nh8_el2_dl2_df2048_K0_lr0.001_'Exp'_0>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
"train 37203\n",
"val 5294\n",
"test 10632\n",
"\titers: 100, epoch: 1 | loss: 0.4932106\n",
"\tspeed: 0.0565s/iter; left time: 978.9201s\n",
"\titers: 200, epoch: 1 | loss: 0.4972929\n",
"\tspeed: 0.0485s/iter; left time: 835.0221s\n",
"\titers: 300, epoch: 1 | loss: 0.4357589\n",
"\tspeed: 0.0476s/iter; left time: 815.1622s\n",
"\titers: 400, epoch: 1 | loss: 0.5983973\n",
"\tspeed: 0.0477s/iter; left time: 811.9901s\n",
"\titers: 500, epoch: 1 | loss: 0.6287826\n",
"\tspeed: 0.0484s/iter; left time: 820.1333s\n",
"\titers: 600, epoch: 1 | loss: 0.4684563\n",
"\tspeed: 0.0488s/iter; left time: 821.1888s\n",
"\titers: 700, epoch: 1 | loss: 1.0543735\n",
"\tspeed: 0.0484s/iter; left time: 810.2026s\n",
"\titers: 800, epoch: 1 | loss: 0.7420598\n",
"\tspeed: 0.0474s/iter; left time: 788.7177s\n",
"\titers: 900, epoch: 1 | loss: 0.5108218\n",
"\tspeed: 0.0474s/iter; left time: 784.3583s\n",
"\titers: 1000, epoch: 1 | loss: 0.5239242\n",
"\tspeed: 0.0505s/iter; left time: 830.0246s\n",
"\titers: 1100, epoch: 1 | loss: 0.8626059\n",
"\tspeed: 0.0468s/iter; left time: 764.9884s\n",
"Epoch: 1 cost time: 56.83964204788208\n",
"Epoch: 1, Steps: 1162 | Train Loss: 0.6849515 Vali Loss: 0.3407791 Test Loss: 0.4381559\n",
"Validation loss decreased (inf --> 0.340779). Saving model ...\n",
"Updating learning rate to 0.00025\n",
"\titers: 100, epoch: 2 | loss: 0.9436644\n",
"\tspeed: 0.1952s/iter; left time: 3155.9649s\n",
"\titers: 200, epoch: 2 | loss: 0.5133687\n",
"\tspeed: 0.0487s/iter; left time: 783.2851s\n",
"\titers: 300, epoch: 2 | loss: 0.6383107\n",
"\tspeed: 0.0482s/iter; left time: 770.3673s\n",
"\titers: 400, epoch: 2 | loss: 1.1335649\n",
"\tspeed: 0.0492s/iter; left time: 781.1069s\n",
"\titers: 500, epoch: 2 | loss: 0.5905243\n",
"\tspeed: 0.0467s/iter; left time: 736.0656s\n",
"\titers: 600, epoch: 2 | loss: 0.7256416\n",
"\tspeed: 0.0474s/iter; left time: 742.6376s\n",
"\titers: 700, epoch: 2 | loss: 0.5506995\n",
"\tspeed: 0.0478s/iter; left time: 744.2372s\n",
"\titers: 800, epoch: 2 | loss: 0.6441094\n",
"\tspeed: 0.0487s/iter; left time: 753.5581s\n",
"\titers: 900, epoch: 2 | loss: 0.3798619\n",
"\tspeed: 0.0487s/iter; left time: 748.4489s\n",
"\titers: 1000, epoch: 2 | loss: 0.4259658\n",
"\tspeed: 0.0479s/iter; left time: 731.2631s\n",
"\titers: 1100, epoch: 2 | loss: 0.5248989\n",
"\tspeed: 0.0491s/iter; left time: 745.1800s\n",
"Epoch: 2 cost time: 56.406978130340576\n",
"Epoch: 2, Steps: 1162 | Train Loss: 0.6076249 Vali Loss: 0.3033751 Test Loss: 0.3955314\n",
"Validation loss decreased (0.340779 --> 0.303375). Saving model ...\n",
"Updating learning rate to 0.0005\n",
"\titers: 100, epoch: 3 | loss: 0.4865443\n",
"\tspeed: 0.2002s/iter; left time: 3003.6628s\n",
"\titers: 200, epoch: 3 | loss: 0.5826065\n",
"\tspeed: 0.0493s/iter; left time: 734.4833s\n",
"\titers: 300, epoch: 3 | loss: 1.2626467\n",
"\tspeed: 0.0480s/iter; left time: 710.5487s\n",
"\titers: 400, epoch: 3 | loss: 0.4687002\n",
"\tspeed: 0.0472s/iter; left time: 694.2853s\n",
"\titers: 500, epoch: 3 | loss: 0.3959515\n",
"\tspeed: 0.0474s/iter; left time: 693.0790s\n",
"\titers: 600, epoch: 3 | loss: 0.4113244\n",
"\tspeed: 0.0490s/iter; left time: 710.3280s\n",
"\titers: 700, epoch: 3 | loss: 0.4114936\n",
"\tspeed: 0.0476s/iter; left time: 685.9448s\n",
"\titers: 800, epoch: 3 | loss: 0.5290235\n",
"\tspeed: 0.0475s/iter; left time: 679.7407s\n",
"\titers: 900, epoch: 3 | loss: 0.4256401\n",
"\tspeed: 0.0478s/iter; left time: 678.8945s\n",
"\titers: 1000, epoch: 3 | loss: 0.6996748\n",
"\tspeed: 0.0491s/iter; left time: 692.3603s\n",
"\titers: 1100, epoch: 3 | loss: 0.7242295\n",
"\tspeed: 0.0470s/iter; left time: 658.2708s\n",
"Epoch: 3 cost time: 56.2166702747345\n",
"Epoch: 3, Steps: 1162 | Train Loss: 0.5938104 Vali Loss: 0.3029587 Test Loss: 0.3915149\n",
"Validation loss decreased (0.303375 --> 0.302959). Saving model ...\n",
"Updating learning rate to 0.00075\n",
"\titers: 100, epoch: 4 | loss: 0.4237639\n",
"\tspeed: 0.1944s/iter; left time: 2691.7509s\n",
"\titers: 200, epoch: 4 | loss: 0.3838455\n",
"\tspeed: 0.0469s/iter; left time: 644.7742s\n",
"\titers: 300, epoch: 4 | loss: 0.5273813\n",
"\tspeed: 0.0465s/iter; left time: 635.0866s\n",
"\titers: 400, epoch: 4 | loss: 1.2923048\n",
"\tspeed: 0.0482s/iter; left time: 652.7872s\n",
"\titers: 500, epoch: 4 | loss: 0.6386693\n",
"\tspeed: 0.0488s/iter; left time: 656.6116s\n",
"\titers: 600, epoch: 4 | loss: 0.5111505\n",
"\tspeed: 0.0483s/iter; left time: 644.4140s\n",
"\titers: 700, epoch: 4 | loss: 0.5771207\n",
"\tspeed: 0.0470s/iter; left time: 622.1141s\n",
"\titers: 800, epoch: 4 | loss: 0.6522042\n",
"\tspeed: 0.0471s/iter; left time: 619.3240s\n",
"\titers: 900, epoch: 4 | loss: 0.3418834\n",
"\tspeed: 0.0474s/iter; left time: 618.8682s\n",
"\titers: 1000, epoch: 4 | loss: 0.4811396\n",
"\tspeed: 0.0470s/iter; left time: 608.9635s\n",
"\titers: 1100, epoch: 4 | loss: 0.6203826\n",
"\tspeed: 0.0476s/iter; left time: 611.6955s\n",
"Epoch: 4 cost time: 55.756075620651245\n",
"Epoch: 4, Steps: 1162 | Train Loss: 0.5855654 Vali Loss: 0.3044514 Test Loss: 0.3943208\n",
"EarlyStopping counter: 1 out of 5\n",
"Updating learning rate to 0.001\n",
"\titers: 100, epoch: 5 | loss: 0.7008136\n",
"\tspeed: 0.1930s/iter; left time: 2447.9230s\n",
"\titers: 200, epoch: 5 | loss: 0.7495702\n",
"\tspeed: 0.0493s/iter; left time: 620.2588s\n",
"\titers: 300, epoch: 5 | loss: 1.0079670\n",
"\tspeed: 0.0487s/iter; left time: 608.5033s\n",
"\titers: 400, epoch: 5 | loss: 0.3510064\n",
"\tspeed: 0.0482s/iter; left time: 596.2644s\n",
"\titers: 500, epoch: 5 | loss: 0.4754097\n",
"\tspeed: 0.0474s/iter; left time: 581.7234s\n",
"\titers: 600, epoch: 5 | loss: 0.7314481\n",
"\tspeed: 0.0493s/iter; left time: 600.1122s\n",
"\titers: 700, epoch: 5 | loss: 0.8495453\n",
"\tspeed: 0.0483s/iter; left time: 583.0922s\n",
"\titers: 800, epoch: 5 | loss: 0.3931953\n",
"\tspeed: 0.0474s/iter; left time: 567.6057s\n",
"\titers: 900, epoch: 5 | loss: 0.3912098\n",
"\tspeed: 0.0488s/iter; left time: 580.1831s\n",
"\titers: 1000, epoch: 5 | loss: 0.4786185\n",
"\tspeed: 0.0483s/iter; left time: 568.7581s\n",
"\titers: 1100, epoch: 5 | loss: 0.6882768\n",
"\tspeed: 0.0471s/iter; left time: 550.3884s\n",
"Epoch: 5 cost time: 56.28184771537781\n",
"Epoch: 5, Steps: 1162 | Train Loss: 0.5815187 Vali Loss: 0.3013998 Test Loss: 0.4001175\n",
"Validation loss decreased (0.302959 --> 0.301400). Saving model ...\n",
"Updating learning rate to 0.0005\n"
]
}
],
"source": [
"\n",
"Exp = Exp_Main\n",
"\n",
"for ii in range(args.itr):\n",
" print(ii)\n",
" set_seed(ii)\n",
" # setting record of experiments\n",
" setting = '{}_{}_{}_ft{}_sl{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_K{}_lr{}_{}_{}'.format(\n",
" args.model_id,\n",
" args.model,\n",
" args.data,\n",
" args.features,\n",
" args.seq_len,\n",
" args.pred_len,\n",
" args.d_model,\n",
" args.n_heads,\n",
" args.e_layers,\n",
" args.d_layers,\n",
" args.d_ff,\n",
" args.K,\n",
" args.learning_rate,\n",
" args.des, ii)\n",
"\n",
"# if os.path.exists(os.path.join(args.checkpoints, setting)):\n",
"# print('skipping exists')\n",
"# continue\n",
"\n",
" exp = Exp(args) # set experiments\n",
" print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))\n",
" exp.train(setting)\n",
"\n",
" print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))\n",
" exp.test(setting, data='val')\n",
" exp.test(setting, data='test')\n",
"\n",
" torch.cuda.empty_cache()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74745d73",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 49,
"id": "2ed3cbf1",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T04:48:15.407951Z",
"start_time": "2022-11-28T04:48:13.634492Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/ETSformer/models/etsformer/encoder.py\u001b[0m(155)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n",
"\u001b[0;32m 153 \u001b[0;31m \u001b[0mgrowth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgrowth\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 154 \u001b[0;31m \u001b[0mseason\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseason\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m--> 155 \u001b[0;31m \u001b[0mlevel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 156 \u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlevel\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mseason\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux_values\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgrowth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 157 \u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrearrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'b t h d -> b t (h d)'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\n",
"ipdb> q\n"
]
}
],
"source": [
"%debug"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2faf8241",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "609a7018",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "deeptime",
"language": "python",
"name": "deeptime"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+64 -57
View File
@@ -13,85 +13,92 @@ def set_seed(seed):
seed += 1
torch.manual_seed(seed)
parser = argparse.ArgumentParser(description='ETSformer: Exponential Smoothing Transformers for Time-series Forecasting')
# basic config
parser.add_argument('--model_id', type=str, required=True, default='test', help='model id')
parser.add_argument('--model', type=str, required=True, default='ETSformer',
def get_args(argv=None):
parser = argparse.ArgumentParser(description='ETSformer: Exponential Smoothing Transformers for Time-series Forecasting')
# basic config
parser.add_argument('--model_id', type=str, required=True, default='test', help='model id')
parser.add_argument('--model', type=str, required=True, default='ETSformer',
help='model name, options: [ETSformer]')
# data loader
parser.add_argument('--data', type=str, required=True, default='ETTm1', help='dataset type')
parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file')
parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
parser.add_argument('--features', type=str, default='M',
# data loader
parser.add_argument('--data', type=str, required=True, default='ETTm1', help='dataset type')
parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file')
parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
parser.add_argument('--features', type=str, default='M',
help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
parser.add_argument('--freq', type=str, default='h',
parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
parser.add_argument('--freq', type=str, default='h',
help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
# forecasting task
parser.add_argument('--seq_len', type=int, required=True, help='input sequence length')
parser.add_argument('--label_len', type=int, default=0, help='start token length')
parser.add_argument('--pred_len', type=int, required=True, help='prediction sequence length')
# forecasting task
parser.add_argument('--seq_len', type=int, required=True, help='input sequence length')
parser.add_argument('--label_len', type=int, default=0, help='start token length')
parser.add_argument('--pred_len', type=int, required=True, help='prediction sequence length')
# model define
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
parser.add_argument('--c_out', type=int, default=7, help='output size')
parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
parser.add_argument('--K', type=int, default=1, help='Top-K Fourier bases')
parser.add_argument('--dropout', type=float, default=0.2, help='dropout')
parser.add_argument('--embed', type=str, default='timeF',
# model define
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
parser.add_argument('--c_out', type=int, default=7, help='output size')
parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
parser.add_argument('--K', type=int, default=1, help='Top-K Fourier bases')
parser.add_argument('--dropout', type=float, default=0.2, help='dropout')
parser.add_argument('--embed', type=str, default='timeF',
help='time features encoding, options:[timeF, fixed, learned]')
parser.add_argument('--activation', type=str, default='sigmoid', help='activation')
parser.add_argument('--activation', type=str, default='sigmoid', help='activation')
parser.add_argument('--min_lr', type=float, default=1e-30)
parser.add_argument('--warmup_epochs', type=int, default=3)
parser.add_argument('--std', type=float, default=0.2)
parser.add_argument('--min_lr', type=float, default=1e-30)
parser.add_argument('--warmup_epochs', type=int, default=3)
parser.add_argument('--std', type=float, default=0.2)
parser.add_argument('--smoothing_learning_rate', type=float, default=0, help='optimizer learning rate')
parser.add_argument('--damping_learning_rate', type=float, default=0, help='optimizer learning rate')
parser.add_argument('--output_attention', type=bool, default=False)
parser.add_argument('--smoothing_learning_rate', type=float, default=0, help='optimizer learning rate')
parser.add_argument('--damping_learning_rate', type=float, default=0, help='optimizer learning rate')
parser.add_argument('--output_attention', type=bool, default=False)
# optimization
parser.add_argument('--optim', type=str, default='adam', help='optimizer')
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
parser.add_argument('--itr', type=int, default=1, help='experiments times')
parser.add_argument('--train_epochs', type=int, default=15, help='train epochs')
parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
parser.add_argument('--patience', type=int, default=5, help='early stopping patience')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='optimizer learning rate')
parser.add_argument('--des', type=str, default='test', help='exp description')
parser.add_argument('--lradj', type=str, default='exponential_with_warmup', help='adjust learning rate')
# optimization
parser.add_argument('--optim', type=str, default='adam', help='optimizer')
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
parser.add_argument('--itr', type=int, default=1, help='experiments times')
parser.add_argument('--train_epochs', type=int, default=15, help='train epochs')
parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
parser.add_argument('--patience', type=int, default=5, help='early stopping patience')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='optimizer learning rate')
parser.add_argument('--des', type=str, default='test', help='exp description')
parser.add_argument('--lradj', type=str, default='exponential_with_warmup', help='adjust learning rate')
# GPU
parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
parser.add_argument('--gpu', type=int, default=0, help='gpu')
parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')
# GPU
parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
parser.add_argument('--gpu', type=int, default=0, help='gpu')
parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')
args = parser.parse_args()
args = parser.parse_args(argv)
args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
if args.use_gpu and args.use_multi_gpu:
if args.use_gpu and args.use_multi_gpu:
args.dvices = args.devices.replace(' ', '')
device_ids = args.devices.split(',')
args.device_ids = [int(id_) for id_ in device_ids]
args.gpu = args.device_ids[0]
print('Args in experiment:')
print(args)
print('Args in experiment:')
print(args)
return args
Exp = Exp_Main
if __name__=="__main__":
for ii in range(args.itr):
args = get_args()
Exp = Exp_Main
for ii in range(args.itr):
set_seed(ii)
# setting record of experiments
setting = '{}_{}_{}_ft{}_sl{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_K{}_lr{}_{}_{}'.format(
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File