mirror of
https://github.com/wassname/ETSformer.git
synced 2026-06-27 16:43:49 +08:00
461 lines
11 KiB
Plaintext
461 lines
11 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "0c996f31",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-28T07:35:38.923189Z",
|
|
"start_time": "2022-11-28T07:35:38.921272Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"os.sys.path.append('..')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "23a27691",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-28T07:35:38.952520Z",
|
|
"start_time": "2022-11-28T07:35:38.923991Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import warnings\n",
|
|
"warnings.simplefilter(\"ignore\")\n",
|
|
"\n",
|
|
"# autoreload import your package\n",
|
|
"%load_ext autoreload\n",
|
|
"%autoreload 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "e60bfdb0",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-28T07:35:38.970942Z",
|
|
"start_time": "2022-11-28T07:35:38.953975Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from loguru import logger\n",
|
|
"logger.remove()\n",
|
|
"logger.add(os.sys.stdout, level=\"ERROR\", colorize=True, format=\"<level>{time} | {message}</level>\")\n",
|
|
"# import_dir(ta_dir, verbose=False)\n",
|
|
"warnings.simplefilter(\"ignore\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "d52338d2",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-28T07:35:39.765300Z",
|
|
"start_time": "2022-11-28T07:35:38.972353Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import pandas as pd\n",
|
|
"from matplotlib import pyplot as plt\n",
|
|
"%matplotlib inline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "20d78524",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Args"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "750cdcc3",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-28T07:35:40.047980Z",
|
|
"start_time": "2022-11-28T07:35:39.766356Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from run import set_seed, get_args, Exp_Main"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b3a40300",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2022-11-28T07:36:20.551Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Args in experiment:\n",
|
|
"Namespace(K=0, activation='sigmoid', batch_size=32, c_out=1, checkpoints='./checkpoints/', d_ff=2048, d_layers=2, d_model=512, damping_learning_rate=0, data='custom', data_path='OXY_2019.csv.gz', dec_in=7, des=\"'Exp'\", devices='0,1,2,3', dropout=0.2, e_layers=2, embed='timeF', enc_in=1, features='S', freq='h', gpu=0, itr=1, label_len=0, learning_rate=0.001, lradj='exponential_with_warmup', min_lr=1e-30, model='ETSformer', model_id='Exchange', n_heads=8, num_workers=0, optim='adam', output_attention=False, patience=5, pred_len=48, root_path='../dataset/stocks/', seq_len=128, smoothing_learning_rate=0, std=0.2, target='RSMKs_18_144_72_2ref_2ref', train_epochs=15, use_gpu=True, use_multi_gpu=False, warmup_epochs=3)\n",
|
|
"Use GPU: cuda:0\n",
|
|
">>>>>>>start training : Exchange_ETSformer_custom_ftS_sl128_pl48_dm512_nh8_el2_dl2_df2048_K0_lr0.001_'Exp'_0>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
|
|
"train 37203\n",
|
|
"val 5294\n",
|
|
"test 10632\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "60de03af10a54c008b02a13e686d7a8e",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/15 [00:00<?, ?epoch/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"train: 0%| | 0/1162 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# mimic cli args to avoid code duplication\n",
|
|
"argv = \"\"\"python -u run.py \\\n",
|
|
" --root_path ../dataset/stocks/ \\\n",
|
|
" --data_path OXY_2019.csv.gz \\\n",
|
|
" --checkpoints ./checkpoints/ \\\n",
|
|
" --model_id Exchange \\\n",
|
|
" --model ETSformer \\\n",
|
|
" --data custom \\\n",
|
|
" --features S \\\n",
|
|
" --seq_len 128 \\\n",
|
|
" --pred_len 48 \\\n",
|
|
" --e_layers 2 \\\n",
|
|
" --d_layers 2 \\\n",
|
|
" --enc_in 1 \\\n",
|
|
" --c_out 1 \\\n",
|
|
" --num_workers 0 \\\n",
|
|
" --des 'Exp' \\\n",
|
|
" --K 0 \\\n",
|
|
" --learning_rate 1e-3 \\\n",
|
|
" --target RSMKs_18_144_72_2ref_2ref \\\n",
|
|
" --itr 1\n",
|
|
"\"\"\"\n",
|
|
"argv = argv.replace(\"\\\\n\", \"\").split()[3:]\n",
|
|
"args = get_args(argv)\n",
|
|
"args\n",
|
|
"\n",
|
|
"\n",
|
|
"Exp = Exp_Main\n",
|
|
"ii=0\n",
|
|
"set_seed(ii)\n",
|
|
"# setting record of experiments\n",
|
|
"setting = '{}_{}_{}_ft{}_sl{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_K{}_lr{}_{}_{}'.format(\n",
|
|
" args.model_id,\n",
|
|
" args.model,\n",
|
|
" args.data,\n",
|
|
" args.features,\n",
|
|
" args.seq_len,\n",
|
|
" args.pred_len,\n",
|
|
" args.d_model,\n",
|
|
" args.n_heads,\n",
|
|
" args.e_layers,\n",
|
|
" args.d_layers,\n",
|
|
" args.d_ff,\n",
|
|
" args.K,\n",
|
|
" args.learning_rate,\n",
|
|
" args.des, ii)\n",
|
|
"\n",
|
|
"# if os.path.exists(os.path.join(args.checkpoints, setting)):\n",
|
|
"# print('skipping exists')\n",
|
|
"# continue\n",
|
|
"\n",
|
|
"exp = Exp(args) # set experiments\n",
|
|
"print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))\n",
|
|
"exp.train(setting)\n",
|
|
"\n",
|
|
"print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))\n",
|
|
"exp.test(setting, data='val')\n",
|
|
"exp.test(setting, data='test')\n",
|
|
"\n",
|
|
"torch.cuda.empty_cache()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "72811788",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-28T05:28:20.356399Z",
|
|
"start_time": "2022-11-28T05:28:20.339961Z"
|
|
}
|
|
},
|
|
"source": [
|
|
"# Plot"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "92b67ed7",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2022-11-28T07:36:25.268Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"setting"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ebea79f7",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2022-11-28T07:36:25.268Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"ds, dl = exp._get_data('test')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9833f218",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2022-11-28T07:36:25.268Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"preds, trues = exp.test(setting, data='test')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "5c88dbf7",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2022-11-28T07:36:25.268Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"i=990\n",
|
|
"(batch_x, batch_y, batch_x_mark, batch_y_mark) = ds[i]\n",
|
|
"l1 = batch_x.shape[0]\n",
|
|
"l2 = batch_y.shape[0]\n",
|
|
"plt.plot(range(l1), batch_x[:, -1], label='past')\n",
|
|
"plt.plot(range(l1, l1+l2), batch_y[:, -1], color='blue', ls='--', label='true')\n",
|
|
"plt.plot(range(l1, l1+l2), preds[i], label='pred');\n",
|
|
"plt.legend(loc='lower left')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0a25923b",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-11-28T07:33:57.703528Z",
|
|
"start_time": "2022-11-28T07:33:57.689258Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "665b98c3",
|
|
"metadata": {},
|
|
"source": [
|
|
"# TODO check index"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "fc6ff0cd",
|
|
"metadata": {},
|
|
"source": [
|
|
"- s_end = index + self.seq_len\n",
|
|
"- r_begin = index + self.seq_len - self.label_len\n",
|
|
"- r_end = index + self.seq_len + self.pred_len"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f482f601",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2022-11-28T07:36:25.268Z"
|
|
},
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"f = ds.root_path + '/' + ds.data_path\n",
|
|
"df = pd.read_csv(f).set_index('date', drop=False)\n",
|
|
"df = df[ds.cols[1:]]\n",
|
|
"# df[:] = ds.scaler.transform(df.values)\n",
|
|
"df\n",
|
|
"# df"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1893beae",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2022-11-28T07:36:25.268Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from data_provider.data_loader import Dataset_Custom\n",
|
|
"ds2 = Dataset_Custom(\n",
|
|
" ds.root_path,\n",
|
|
" flag=\"test\",\n",
|
|
"# size=ds.size,\n",
|
|
" size=[ds.seq_len, ds.label_len, ds.pred_len],\n",
|
|
" features=ds.features,\n",
|
|
" data_path=ds.data_path,\n",
|
|
" target=ds.target,\n",
|
|
" scale=False,\n",
|
|
" timeenc=ds.timeenc,\n",
|
|
" freq=ds.freq\n",
|
|
")\n",
|
|
"\n",
|
|
"i=99\n",
|
|
"(batch_x, batch_y, batch_x_mark, batch_y_mark) = ds2[i]\n",
|
|
"batch_x[-1]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f0a63562",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2022-11-28T07:36:25.268Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"dt = ds2.index.iloc[i-1]\n",
|
|
"df.loc[dt]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "65c56ec5",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2022-11-28T07:36:25.268Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"assert df.loc[dt].close == batch_x[-1, 0], 'index should be right'\n",
|
|
"'OK'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "692aadfa",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2dd11010",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e4413025",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "deeptime",
|
|
"language": "python",
|
|
"name": "deeptime"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.13"
|
|
},
|
|
"toc": {
|
|
"base_numbering": 1,
|
|
"nav_menu": {},
|
|
"number_sections": true,
|
|
"sideBar": true,
|
|
"skip_h1_title": false,
|
|
"title_cell": "Table of Contents",
|
|
"title_sidebar": "Contents",
|
|
"toc_cell": false,
|
|
"toc_position": {},
|
|
"toc_section_display": true,
|
|
"toc_window_display": false
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|