diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ebfb806
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,89 @@
+/notebooks/checkpoints/
+/checkpoints/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# DotEnv configuration
+.env
+
+# Database
+*.db
+*.rdb
+
+# Pycharm
+.idea
+
+# VS Code
+.vscode/
+
+# Spyder
+.spyproject/
+
+# Jupyter NB Checkpoints
+.ipynb_checkpoints/
+
+# exclude data from source control by default
+/data/
+
+# Mac OS-specific storage files
+.DS_Store
+
+# vim
+*.swp
+*.swo
diff --git a/README.md b/README.md
index d130438..02cedee 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,6 @@
+Fork info: trying out and plotting the results on stock data.----------
+
+
# ETSformer: Exponential Smoothing Transformers for Time-series Forecasting
@@ -8,6 +11,7 @@
Official PyTorch code repository for the [ETSformer paper](https://arxiv.org/abs/2202.01381). Check out our [blog post](https://blog.salesforceairesearch.com/etsformer-time-series-forecasting/)!
+
* ETSformer is a novel time-series Transformer architecture which exploits the principle of exponential smoothing in improving
Transformers for timeseries forecasting.
* ETSformer is inspired by the classical exponential smoothing methods in
@@ -51,4 +55,4 @@ Please consider citing if you find this code useful to your research.
author={Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven C. H. Hoi},
year={2022},
url={https://arxiv.org/abs/2202.01381},
-}
\ No newline at end of file
+}
diff --git a/dataset/.gitignore b/dataset/.gitignore
index 86d0cb2..b5fb85e 100644
--- a/dataset/.gitignore
+++ b/dataset/.gitignore
@@ -1,4 +1,6 @@
+
# Ignore everything in this directory
*
# Except this file
-!.gitignore
\ No newline at end of file
+!.gitignore
+
diff --git a/mjc_notes.md b/mjc_notes.md
index 9200fea..045a745 100644
--- a/mjc_notes.md
+++ b/mjc_notes.md
@@ -1 +1,7 @@
`conda activate deeptime`
+
+
+TODO:
+- [ ] run on stocks data
+- [ ] multivariate in, univariate out
+- [ ] graph
diff --git a/notebooks/001_001_run.ipynb b/notebooks/001_001_run.ipynb
new file mode 100644
index 0000000..c5e842d
--- /dev/null
+++ b/notebooks/001_001_run.ipynb
@@ -0,0 +1,433 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "cbdce1e4",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-11-28T04:13:58.128994Z",
+ "start_time": "2022-11-28T04:13:58.127086Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "1640aab2",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-11-28T04:13:58.305698Z",
+ "start_time": "2022-11-28T04:13:58.287564Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import warnings\n",
+ "warnings.simplefilter(\"ignore\")\n",
+ "\n",
+ "# autoreload import your package\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "649ddfa9",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-11-28T04:14:16.865244Z",
+ "start_time": "2022-11-28T04:14:16.858147Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from loguru import logger\n",
+ "logger.remove()\n",
+ "logger.add(os.sys.stdout, level=\"ERROR\", colorize=True, format=\"{time} | {message}\")\n",
+ "# import_dir(ta_dir, verbose=False)\n",
+ "warnings.simplefilter(\"ignore\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bc96a9a7",
+ "metadata": {},
+ "source": [
+ "# Args"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "7fcd9d4f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-11-28T04:18:43.064910Z",
+ "start_time": "2022-11-28T04:18:43.050776Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from run import set_seed, get_args, Exp_Main"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "id": "0602682d",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-11-28T05:00:16.840051Z",
+ "start_time": "2022-11-28T05:00:16.816918Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Args in experiment:\n",
+ "Namespace(K=0, activation='sigmoid', batch_size=32, c_out=12, checkpoints='./checkpoints/', d_ff=2048, d_layers=2, d_model=512, damping_learning_rate=0, data='custom', data_path='OXY_2019.csv.gz', dec_in=12, des=\"'Exp'\", devices='0,1,2,3', dropout=0.2, e_layers=2, embed='timeF', enc_in=12, features='M', freq='h', gpu=0, itr=1, label_len=0, learning_rate=0.001, lradj='exponential_with_warmup', min_lr=1e-30, model='ETSformer', model_id='Exchange', n_heads=8, num_workers=10, optim='adam', output_attention=False, patience=5, pred_len=48, root_path='../dataset/stocks/', seq_len=128, smoothing_learning_rate=0, std=0.2, target='RSMKs_18_144_72_2ref_2ref', train_epochs=15, use_gpu=True, use_multi_gpu=False, warmup_epochs=3)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "Namespace(K=0, activation='sigmoid', batch_size=32, c_out=12, checkpoints='./checkpoints/', d_ff=2048, d_layers=2, d_model=512, damping_learning_rate=0, data='custom', data_path='OXY_2019.csv.gz', dec_in=12, des=\"'Exp'\", devices='0,1,2,3', dropout=0.2, e_layers=2, embed='timeF', enc_in=12, features='M', freq='h', gpu=0, itr=1, label_len=0, learning_rate=0.001, lradj='exponential_with_warmup', min_lr=1e-30, model='ETSformer', model_id='Exchange', n_heads=8, num_workers=10, optim='adam', output_attention=False, patience=5, pred_len=48, root_path='../dataset/stocks/', seq_len=128, smoothing_learning_rate=0, std=0.2, target='RSMKs_18_144_72_2ref_2ref', train_epochs=15, use_gpu=True, use_multi_gpu=False, warmup_epochs=3)"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# mimic cli args to avoid code duplication\n",
+ "argv = \"\"\"python -u run.py \\\n",
+ " --root_path ../dataset/stocks/ \\\n",
+ " --data_path OXY_2019.csv.gz \\\n",
+ " --checkpoints ./checkpoints/ \\\n",
+ " --model_id Exchange \\\n",
+ " --model ETSformer \\\n",
+ " --data custom \\\n",
+ " --features M \\\n",
+ " --seq_len 128 \\\n",
+ " --pred_len 48 \\\n",
+ " --e_layers 2 \\\n",
+ " --d_layers 2 \\\n",
+ " --enc_in 12 \\\n",
+ " --dec_in 12 \\\n",
+ " --c_out 12 \\\n",
+ " --des 'Exp' \\\n",
+ " --K 0 \\\n",
+ " --learning_rate 1e-3 \\\n",
+ " --target RSMKs_18_144_72_2ref_2ref \\\n",
+ " --itr 1\n",
+ "\"\"\"\n",
+ "argv = argv.replace(\"\\\\n\", \"\").split()[3:]\n",
+ "args = get_args(argv)\n",
+ "args"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7367a3cb",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-11-28T04:24:49.799842Z",
+ "start_time": "2022-11-28T04:24:49.783939Z"
+ }
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "93ed3ca8",
+ "metadata": {
+ "ExecuteTime": {
+ "start_time": "2022-11-28T05:00:17.097Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0\n",
+ "Use GPU: cuda:0\n",
+ ">>>>>>>start training : Exchange_ETSformer_custom_ftM_sl128_pl48_dm512_nh8_el2_dl2_df2048_K0_lr0.001_'Exp'_0>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
+ "train 37203\n",
+ "val 5294\n",
+ "test 10632\n",
+ "\titers: 100, epoch: 1 | loss: 0.4932106\n",
+ "\tspeed: 0.0565s/iter; left time: 978.9201s\n",
+ "\titers: 200, epoch: 1 | loss: 0.4972929\n",
+ "\tspeed: 0.0485s/iter; left time: 835.0221s\n",
+ "\titers: 300, epoch: 1 | loss: 0.4357589\n",
+ "\tspeed: 0.0476s/iter; left time: 815.1622s\n",
+ "\titers: 400, epoch: 1 | loss: 0.5983973\n",
+ "\tspeed: 0.0477s/iter; left time: 811.9901s\n",
+ "\titers: 500, epoch: 1 | loss: 0.6287826\n",
+ "\tspeed: 0.0484s/iter; left time: 820.1333s\n",
+ "\titers: 600, epoch: 1 | loss: 0.4684563\n",
+ "\tspeed: 0.0488s/iter; left time: 821.1888s\n",
+ "\titers: 700, epoch: 1 | loss: 1.0543735\n",
+ "\tspeed: 0.0484s/iter; left time: 810.2026s\n",
+ "\titers: 800, epoch: 1 | loss: 0.7420598\n",
+ "\tspeed: 0.0474s/iter; left time: 788.7177s\n",
+ "\titers: 900, epoch: 1 | loss: 0.5108218\n",
+ "\tspeed: 0.0474s/iter; left time: 784.3583s\n",
+ "\titers: 1000, epoch: 1 | loss: 0.5239242\n",
+ "\tspeed: 0.0505s/iter; left time: 830.0246s\n",
+ "\titers: 1100, epoch: 1 | loss: 0.8626059\n",
+ "\tspeed: 0.0468s/iter; left time: 764.9884s\n",
+ "Epoch: 1 cost time: 56.83964204788208\n",
+ "Epoch: 1, Steps: 1162 | Train Loss: 0.6849515 Vali Loss: 0.3407791 Test Loss: 0.4381559\n",
+ "Validation loss decreased (inf --> 0.340779). Saving model ...\n",
+ "Updating learning rate to 0.00025\n",
+ "\titers: 100, epoch: 2 | loss: 0.9436644\n",
+ "\tspeed: 0.1952s/iter; left time: 3155.9649s\n",
+ "\titers: 200, epoch: 2 | loss: 0.5133687\n",
+ "\tspeed: 0.0487s/iter; left time: 783.2851s\n",
+ "\titers: 300, epoch: 2 | loss: 0.6383107\n",
+ "\tspeed: 0.0482s/iter; left time: 770.3673s\n",
+ "\titers: 400, epoch: 2 | loss: 1.1335649\n",
+ "\tspeed: 0.0492s/iter; left time: 781.1069s\n",
+ "\titers: 500, epoch: 2 | loss: 0.5905243\n",
+ "\tspeed: 0.0467s/iter; left time: 736.0656s\n",
+ "\titers: 600, epoch: 2 | loss: 0.7256416\n",
+ "\tspeed: 0.0474s/iter; left time: 742.6376s\n",
+ "\titers: 700, epoch: 2 | loss: 0.5506995\n",
+ "\tspeed: 0.0478s/iter; left time: 744.2372s\n",
+ "\titers: 800, epoch: 2 | loss: 0.6441094\n",
+ "\tspeed: 0.0487s/iter; left time: 753.5581s\n",
+ "\titers: 900, epoch: 2 | loss: 0.3798619\n",
+ "\tspeed: 0.0487s/iter; left time: 748.4489s\n",
+ "\titers: 1000, epoch: 2 | loss: 0.4259658\n",
+ "\tspeed: 0.0479s/iter; left time: 731.2631s\n",
+ "\titers: 1100, epoch: 2 | loss: 0.5248989\n",
+ "\tspeed: 0.0491s/iter; left time: 745.1800s\n",
+ "Epoch: 2 cost time: 56.406978130340576\n",
+ "Epoch: 2, Steps: 1162 | Train Loss: 0.6076249 Vali Loss: 0.3033751 Test Loss: 0.3955314\n",
+ "Validation loss decreased (0.340779 --> 0.303375). Saving model ...\n",
+ "Updating learning rate to 0.0005\n",
+ "\titers: 100, epoch: 3 | loss: 0.4865443\n",
+ "\tspeed: 0.2002s/iter; left time: 3003.6628s\n",
+ "\titers: 200, epoch: 3 | loss: 0.5826065\n",
+ "\tspeed: 0.0493s/iter; left time: 734.4833s\n",
+ "\titers: 300, epoch: 3 | loss: 1.2626467\n",
+ "\tspeed: 0.0480s/iter; left time: 710.5487s\n",
+ "\titers: 400, epoch: 3 | loss: 0.4687002\n",
+ "\tspeed: 0.0472s/iter; left time: 694.2853s\n",
+ "\titers: 500, epoch: 3 | loss: 0.3959515\n",
+ "\tspeed: 0.0474s/iter; left time: 693.0790s\n",
+ "\titers: 600, epoch: 3 | loss: 0.4113244\n",
+ "\tspeed: 0.0490s/iter; left time: 710.3280s\n",
+ "\titers: 700, epoch: 3 | loss: 0.4114936\n",
+ "\tspeed: 0.0476s/iter; left time: 685.9448s\n",
+ "\titers: 800, epoch: 3 | loss: 0.5290235\n",
+ "\tspeed: 0.0475s/iter; left time: 679.7407s\n",
+ "\titers: 900, epoch: 3 | loss: 0.4256401\n",
+ "\tspeed: 0.0478s/iter; left time: 678.8945s\n",
+ "\titers: 1000, epoch: 3 | loss: 0.6996748\n",
+ "\tspeed: 0.0491s/iter; left time: 692.3603s\n",
+ "\titers: 1100, epoch: 3 | loss: 0.7242295\n",
+ "\tspeed: 0.0470s/iter; left time: 658.2708s\n",
+ "Epoch: 3 cost time: 56.2166702747345\n",
+ "Epoch: 3, Steps: 1162 | Train Loss: 0.5938104 Vali Loss: 0.3029587 Test Loss: 0.3915149\n",
+ "Validation loss decreased (0.303375 --> 0.302959). Saving model ...\n",
+ "Updating learning rate to 0.00075\n",
+ "\titers: 100, epoch: 4 | loss: 0.4237639\n",
+ "\tspeed: 0.1944s/iter; left time: 2691.7509s\n",
+ "\titers: 200, epoch: 4 | loss: 0.3838455\n",
+ "\tspeed: 0.0469s/iter; left time: 644.7742s\n",
+ "\titers: 300, epoch: 4 | loss: 0.5273813\n",
+ "\tspeed: 0.0465s/iter; left time: 635.0866s\n",
+ "\titers: 400, epoch: 4 | loss: 1.2923048\n",
+ "\tspeed: 0.0482s/iter; left time: 652.7872s\n",
+ "\titers: 500, epoch: 4 | loss: 0.6386693\n",
+ "\tspeed: 0.0488s/iter; left time: 656.6116s\n",
+ "\titers: 600, epoch: 4 | loss: 0.5111505\n",
+ "\tspeed: 0.0483s/iter; left time: 644.4140s\n",
+ "\titers: 700, epoch: 4 | loss: 0.5771207\n",
+ "\tspeed: 0.0470s/iter; left time: 622.1141s\n",
+ "\titers: 800, epoch: 4 | loss: 0.6522042\n",
+ "\tspeed: 0.0471s/iter; left time: 619.3240s\n",
+ "\titers: 900, epoch: 4 | loss: 0.3418834\n",
+ "\tspeed: 0.0474s/iter; left time: 618.8682s\n",
+ "\titers: 1000, epoch: 4 | loss: 0.4811396\n",
+ "\tspeed: 0.0470s/iter; left time: 608.9635s\n",
+ "\titers: 1100, epoch: 4 | loss: 0.6203826\n",
+ "\tspeed: 0.0476s/iter; left time: 611.6955s\n",
+ "Epoch: 4 cost time: 55.756075620651245\n",
+ "Epoch: 4, Steps: 1162 | Train Loss: 0.5855654 Vali Loss: 0.3044514 Test Loss: 0.3943208\n",
+ "EarlyStopping counter: 1 out of 5\n",
+ "Updating learning rate to 0.001\n",
+ "\titers: 100, epoch: 5 | loss: 0.7008136\n",
+ "\tspeed: 0.1930s/iter; left time: 2447.9230s\n",
+ "\titers: 200, epoch: 5 | loss: 0.7495702\n",
+ "\tspeed: 0.0493s/iter; left time: 620.2588s\n",
+ "\titers: 300, epoch: 5 | loss: 1.0079670\n",
+ "\tspeed: 0.0487s/iter; left time: 608.5033s\n",
+ "\titers: 400, epoch: 5 | loss: 0.3510064\n",
+ "\tspeed: 0.0482s/iter; left time: 596.2644s\n",
+ "\titers: 500, epoch: 5 | loss: 0.4754097\n",
+ "\tspeed: 0.0474s/iter; left time: 581.7234s\n",
+ "\titers: 600, epoch: 5 | loss: 0.7314481\n",
+ "\tspeed: 0.0493s/iter; left time: 600.1122s\n",
+ "\titers: 700, epoch: 5 | loss: 0.8495453\n",
+ "\tspeed: 0.0483s/iter; left time: 583.0922s\n",
+ "\titers: 800, epoch: 5 | loss: 0.3931953\n",
+ "\tspeed: 0.0474s/iter; left time: 567.6057s\n",
+ "\titers: 900, epoch: 5 | loss: 0.3912098\n",
+ "\tspeed: 0.0488s/iter; left time: 580.1831s\n",
+ "\titers: 1000, epoch: 5 | loss: 0.4786185\n",
+ "\tspeed: 0.0483s/iter; left time: 568.7581s\n",
+ "\titers: 1100, epoch: 5 | loss: 0.6882768\n",
+ "\tspeed: 0.0471s/iter; left time: 550.3884s\n",
+ "Epoch: 5 cost time: 56.28184771537781\n",
+ "Epoch: 5, Steps: 1162 | Train Loss: 0.5815187 Vali Loss: 0.3013998 Test Loss: 0.4001175\n",
+ "Validation loss decreased (0.302959 --> 0.301400). Saving model ...\n",
+ "Updating learning rate to 0.0005\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "Exp = Exp_Main\n",
+ "\n",
+ "for ii in range(args.itr):\n",
+ " print(ii)\n",
+ " set_seed(ii)\n",
+ " # setting record of experiments\n",
+ " setting = '{}_{}_{}_ft{}_sl{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_K{}_lr{}_{}_{}'.format(\n",
+ " args.model_id,\n",
+ " args.model,\n",
+ " args.data,\n",
+ " args.features,\n",
+ " args.seq_len,\n",
+ " args.pred_len,\n",
+ " args.d_model,\n",
+ " args.n_heads,\n",
+ " args.e_layers,\n",
+ " args.d_layers,\n",
+ " args.d_ff,\n",
+ " args.K,\n",
+ " args.learning_rate,\n",
+ " args.des, ii)\n",
+ "\n",
+ "# if os.path.exists(os.path.join(args.checkpoints, setting)):\n",
+ "# print('skipping exists')\n",
+ "# continue\n",
+ "\n",
+ " exp = Exp(args) # set experiments\n",
+ " print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))\n",
+ " exp.train(setting)\n",
+ "\n",
+ " print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))\n",
+ " exp.test(setting, data='val')\n",
+ " exp.test(setting, data='test')\n",
+ "\n",
+ " torch.cuda.empty_cache()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "74745d73",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "2ed3cbf1",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-11-28T04:48:15.407951Z",
+ "start_time": "2022-11-28T04:48:13.634492Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "> \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/ETSformer/models/etsformer/encoder.py\u001b[0m(155)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n",
+ "\u001b[0;32m 153 \u001b[0;31m \u001b[0mgrowth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgrowth\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0m\u001b[0;32m 154 \u001b[0;31m \u001b[0mseason\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseason\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0m\u001b[0;32m--> 155 \u001b[0;31m \u001b[0mlevel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0m\u001b[0;32m 156 \u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlevel\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mseason\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux_values\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgrowth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0m\u001b[0;32m 157 \u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrearrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'b t h d -> b t (h d)'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0m\n",
+ "ipdb> q\n"
+ ]
+ }
+ ],
+ "source": [
+ "%debug"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2faf8241",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "609a7018",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "deeptime",
+ "language": "python",
+ "name": "deeptime"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.13"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/run.py b/run.py
index a4da50f..fb17e73 100644
--- a/run.py
+++ b/run.py
@@ -13,112 +13,119 @@ def set_seed(seed):
seed += 1
torch.manual_seed(seed)
-parser = argparse.ArgumentParser(description='ETSformer: Exponential Smoothing Transformers for Time-series Forecasting')
-# basic config
-parser.add_argument('--model_id', type=str, required=True, default='test', help='model id')
-parser.add_argument('--model', type=str, required=True, default='ETSformer',
- help='model name, options: [ETSformer]')
+def get_args(argv=None):
+ parser = argparse.ArgumentParser(description='ETSformer: Exponential Smoothing Transformers for Time-series Forecasting')
-# data loader
-parser.add_argument('--data', type=str, required=True, default='ETTm1', help='dataset type')
-parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file')
-parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
-parser.add_argument('--features', type=str, default='M',
- help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
-parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
-parser.add_argument('--freq', type=str, default='h',
- help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
-parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
+ # basic config
+ parser.add_argument('--model_id', type=str, required=True, default='test', help='model id')
+ parser.add_argument('--model', type=str, required=True, default='ETSformer',
+ help='model name, options: [ETSformer]')
-# forecasting task
-parser.add_argument('--seq_len', type=int, required=True, help='input sequence length')
-parser.add_argument('--label_len', type=int, default=0, help='start token length')
-parser.add_argument('--pred_len', type=int, required=True, help='prediction sequence length')
+ # data loader
+ parser.add_argument('--data', type=str, required=True, default='ETTm1', help='dataset type')
+ parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file')
+ parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
+ parser.add_argument('--features', type=str, default='M',
+ help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
+ parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
+ parser.add_argument('--freq', type=str, default='h',
+ help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
+ parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
-# model define
-parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
-parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
-parser.add_argument('--c_out', type=int, default=7, help='output size')
-parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
-parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
-parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
-parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
-parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
-parser.add_argument('--K', type=int, default=1, help='Top-K Fourier bases')
-parser.add_argument('--dropout', type=float, default=0.2, help='dropout')
-parser.add_argument('--embed', type=str, default='timeF',
- help='time features encoding, options:[timeF, fixed, learned]')
-parser.add_argument('--activation', type=str, default='sigmoid', help='activation')
+ # forecasting task
+ parser.add_argument('--seq_len', type=int, required=True, help='input sequence length')
+ parser.add_argument('--label_len', type=int, default=0, help='start token length')
+ parser.add_argument('--pred_len', type=int, required=True, help='prediction sequence length')
-parser.add_argument('--min_lr', type=float, default=1e-30)
-parser.add_argument('--warmup_epochs', type=int, default=3)
-parser.add_argument('--std', type=float, default=0.2)
+ # model define
+ parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
+ parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
+ parser.add_argument('--c_out', type=int, default=7, help='output size')
+ parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
+ parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
+ parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
+ parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
+ parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
+ parser.add_argument('--K', type=int, default=1, help='Top-K Fourier bases')
+ parser.add_argument('--dropout', type=float, default=0.2, help='dropout')
+ parser.add_argument('--embed', type=str, default='timeF',
+ help='time features encoding, options:[timeF, fixed, learned]')
+ parser.add_argument('--activation', type=str, default='sigmoid', help='activation')
-parser.add_argument('--smoothing_learning_rate', type=float, default=0, help='optimizer learning rate')
-parser.add_argument('--damping_learning_rate', type=float, default=0, help='optimizer learning rate')
-parser.add_argument('--output_attention', type=bool, default=False)
+ parser.add_argument('--min_lr', type=float, default=1e-30)
+ parser.add_argument('--warmup_epochs', type=int, default=3)
+ parser.add_argument('--std', type=float, default=0.2)
-# optimization
-parser.add_argument('--optim', type=str, default='adam', help='optimizer')
-parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
-parser.add_argument('--itr', type=int, default=1, help='experiments times')
-parser.add_argument('--train_epochs', type=int, default=15, help='train epochs')
-parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
-parser.add_argument('--patience', type=int, default=5, help='early stopping patience')
-parser.add_argument('--learning_rate', type=float, default=1e-4, help='optimizer learning rate')
-parser.add_argument('--des', type=str, default='test', help='exp description')
-parser.add_argument('--lradj', type=str, default='exponential_with_warmup', help='adjust learning rate')
+ parser.add_argument('--smoothing_learning_rate', type=float, default=0, help='optimizer learning rate')
+ parser.add_argument('--damping_learning_rate', type=float, default=0, help='optimizer learning rate')
+ parser.add_argument('--output_attention', type=bool, default=False)
-# GPU
-parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
-parser.add_argument('--gpu', type=int, default=0, help='gpu')
-parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
-parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')
+ # optimization
+ parser.add_argument('--optim', type=str, default='adam', help='optimizer')
+ parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
+ parser.add_argument('--itr', type=int, default=1, help='experiments times')
+ parser.add_argument('--train_epochs', type=int, default=15, help='train epochs')
+ parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
+ parser.add_argument('--patience', type=int, default=5, help='early stopping patience')
+ parser.add_argument('--learning_rate', type=float, default=1e-4, help='optimizer learning rate')
+ parser.add_argument('--des', type=str, default='test', help='exp description')
+ parser.add_argument('--lradj', type=str, default='exponential_with_warmup', help='adjust learning rate')
-args = parser.parse_args()
+ # GPU
+ parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
+ parser.add_argument('--gpu', type=int, default=0, help='gpu')
+ parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
+ parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')
-args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
+ args = parser.parse_args(argv)
-if args.use_gpu and args.use_multi_gpu:
- args.dvices = args.devices.replace(' ', '')
- device_ids = args.devices.split(',')
- args.device_ids = [int(id_) for id_ in device_ids]
- args.gpu = args.device_ids[0]
+ args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
-print('Args in experiment:')
-print(args)
+ if args.use_gpu and args.use_multi_gpu:
+ args.dvices = args.devices.replace(' ', '')
+ device_ids = args.devices.split(',')
+ args.device_ids = [int(id_) for id_ in device_ids]
+ args.gpu = args.device_ids[0]
-Exp = Exp_Main
+ print('Args in experiment:')
+ print(args)
+ return args
-for ii in range(args.itr):
- set_seed(ii)
- # setting record of experiments
- setting = '{}_{}_{}_ft{}_sl{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_K{}_lr{}_{}_{}'.format(
- args.model_id,
- args.model,
- args.data,
- args.features,
- args.seq_len,
- args.pred_len,
- args.d_model,
- args.n_heads,
- args.e_layers,
- args.d_layers,
- args.d_ff,
- args.K,
- args.learning_rate,
- args.des, ii)
+if __name__=="__main__":
+
+ args = get_args()
- if os.path.exists(os.path.join(args.checkpoints, setting)):
- continue
+ Exp = Exp_Main
- exp = Exp(args) # set experiments
- print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
- exp.train(setting)
+ for ii in range(args.itr):
+ set_seed(ii)
+ # setting record of experiments
+ setting = '{}_{}_{}_ft{}_sl{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_K{}_lr{}_{}_{}'.format(
+ args.model_id,
+ args.model,
+ args.data,
+ args.features,
+ args.seq_len,
+ args.pred_len,
+ args.d_model,
+ args.n_heads,
+ args.e_layers,
+ args.d_layers,
+ args.d_ff,
+ args.K,
+ args.learning_rate,
+ args.des, ii)
- print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
- exp.test(setting, data='val')
- exp.test(setting, data='test')
+ if os.path.exists(os.path.join(args.checkpoints, setting)):
+ continue
- torch.cuda.empty_cache()
+ exp = Exp(args) # set experiments
+ print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
+ exp.train(setting)
+
+ print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
+ exp.test(setting, data='val')
+ exp.test(setting, data='test')
+
+ torch.cuda.empty_cache()
diff --git a/scripts/ECL.sh b/scripts/ECL.sh
old mode 100644
new mode 100755
diff --git a/scripts/ETTm2.sh b/scripts/ETTm2.sh
old mode 100644
new mode 100755
diff --git a/scripts/ETTm2_univar.sh b/scripts/ETTm2_univar.sh
old mode 100644
new mode 100755
diff --git a/scripts/Exchange.sh b/scripts/Exchange.sh
old mode 100644
new mode 100755
diff --git a/scripts/Exchange_univar.sh b/scripts/Exchange_univar.sh
old mode 100644
new mode 100755
diff --git a/scripts/ILI.sh b/scripts/ILI.sh
old mode 100644
new mode 100755
diff --git a/scripts/Traffic.sh b/scripts/Traffic.sh
old mode 100644
new mode 100755
diff --git a/scripts/Weather.sh b/scripts/Weather.sh
old mode 100644
new mode 100755
diff --git a/scripts/grid_search.sh b/scripts/grid_search.sh
old mode 100644
new mode 100755