misc bugfixes

This commit is contained in:
wassname
2022-11-28 15:38:43 +08:00
parent 0bf1606713
commit 04f19e5649
7 changed files with 946 additions and 1700 deletions
+3
View File
@@ -1,5 +1,8 @@
/notebooks/checkpoints/
/checkpoints/
/notebooks/results/
/results/
*.npy
# Byte-compiled / optimized / DLL files
__pycache__/
-13
View File
@@ -129,7 +129,6 @@ class Exp_Main(Exp_Basic):
train_loss = []
self.model.train()
epoch_time = time.time()
for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(tqdm(train_loader, leave=False, desc='train')):
iter_count += 1
model_optim.zero_grad()
@@ -151,22 +150,10 @@ class Exp_Main(Exp_Basic):
loss = criterion(outputs, batch_y)
train_loss.append(loss.item())
# if (i + 1) % 10 == 0:
# print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
# speed = (time.time() - time_now) / iter_count
# left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
# print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
# iter_count = 0
# time_now = time.time()
# p.desc = f'loss: {loss.item():.7f}'
loss.backward()
torch.nn.utils.clip_grad_norm(self.model.parameters(), 1.0)
model_optim.step()
print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
train_loss = np.average(train_loss)
vali_loss = self.vali(vali_data, vali_loader, criterion)
test_loss = self.vali(test_data, test_loader, criterion)
+460
View File
@@ -0,0 +1,460 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "0c996f31",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:35:38.923189Z",
"start_time": "2022-11-28T07:35:38.921272Z"
}
},
"outputs": [],
"source": [
"import os\n",
"os.sys.path.append('..')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "23a27691",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:35:38.952520Z",
"start_time": "2022-11-28T07:35:38.923991Z"
}
},
"outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")\n",
"\n",
"# autoreload import your package\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e60bfdb0",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:35:38.970942Z",
"start_time": "2022-11-28T07:35:38.953975Z"
}
},
"outputs": [],
"source": [
"from loguru import logger\n",
"logger.remove()\n",
"logger.add(os.sys.stdout, level=\"ERROR\", colorize=True, format=\"<level>{time} | {message}</level>\")\n",
"# import_dir(ta_dir, verbose=False)\n",
"warnings.simplefilter(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d52338d2",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:35:39.765300Z",
"start_time": "2022-11-28T07:35:38.972353Z"
}
},
"outputs": [],
"source": [
"import torch\n",
"import pandas as pd\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"id": "20d78524",
"metadata": {},
"source": [
"# Args"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "750cdcc3",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:35:40.047980Z",
"start_time": "2022-11-28T07:35:39.766356Z"
}
},
"outputs": [],
"source": [
"from run import set_seed, get_args, Exp_Main"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3a40300",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:20.551Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Args in experiment:\n",
"Namespace(K=0, activation='sigmoid', batch_size=32, c_out=1, checkpoints='./checkpoints/', d_ff=2048, d_layers=2, d_model=512, damping_learning_rate=0, data='custom', data_path='OXY_2019.csv.gz', dec_in=7, des=\"'Exp'\", devices='0,1,2,3', dropout=0.2, e_layers=2, embed='timeF', enc_in=1, features='S', freq='h', gpu=0, itr=1, label_len=0, learning_rate=0.001, lradj='exponential_with_warmup', min_lr=1e-30, model='ETSformer', model_id='Exchange', n_heads=8, num_workers=0, optim='adam', output_attention=False, patience=5, pred_len=48, root_path='../dataset/stocks/', seq_len=128, smoothing_learning_rate=0, std=0.2, target='RSMKs_18_144_72_2ref_2ref', train_epochs=15, use_gpu=True, use_multi_gpu=False, warmup_epochs=3)\n",
"Use GPU: cuda:0\n",
">>>>>>>start training : Exchange_ETSformer_custom_ftS_sl128_pl48_dm512_nh8_el2_dl2_df2048_K0_lr0.001_'Exp'_0>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
"train 37203\n",
"val 5294\n",
"test 10632\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "60de03af10a54c008b02a13e686d7a8e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/15 [00:00<?, ?epoch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train: 0%| | 0/1162 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# mimic cli args to avoid code duplication\n",
"argv = \"\"\"python -u run.py \\\n",
" --root_path ../dataset/stocks/ \\\n",
" --data_path OXY_2019.csv.gz \\\n",
" --checkpoints ./checkpoints/ \\\n",
" --model_id Exchange \\\n",
" --model ETSformer \\\n",
" --data custom \\\n",
" --features S \\\n",
" --seq_len 128 \\\n",
" --pred_len 48 \\\n",
" --e_layers 2 \\\n",
" --d_layers 2 \\\n",
" --enc_in 1 \\\n",
" --c_out 1 \\\n",
" --num_workers 0 \\\n",
" --des 'Exp' \\\n",
" --K 0 \\\n",
" --learning_rate 1e-3 \\\n",
" --target RSMKs_18_144_72_2ref_2ref \\\n",
" --itr 1\n",
"\"\"\"\n",
"argv = argv.replace(\"\\\\n\", \"\").split()[3:]\n",
"args = get_args(argv)\n",
"args\n",
"\n",
"\n",
"Exp = Exp_Main\n",
"ii=0\n",
"set_seed(ii)\n",
"# setting record of experiments\n",
"setting = '{}_{}_{}_ft{}_sl{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_K{}_lr{}_{}_{}'.format(\n",
" args.model_id,\n",
" args.model,\n",
" args.data,\n",
" args.features,\n",
" args.seq_len,\n",
" args.pred_len,\n",
" args.d_model,\n",
" args.n_heads,\n",
" args.e_layers,\n",
" args.d_layers,\n",
" args.d_ff,\n",
" args.K,\n",
" args.learning_rate,\n",
" args.des, ii)\n",
"\n",
"# if os.path.exists(os.path.join(args.checkpoints, setting)):\n",
"# print('skipping exists')\n",
"# continue\n",
"\n",
"exp = Exp(args) # set experiments\n",
"print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))\n",
"exp.train(setting)\n",
"\n",
"print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))\n",
"exp.test(setting, data='val')\n",
"exp.test(setting, data='test')\n",
"\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "markdown",
"id": "72811788",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T05:28:20.356399Z",
"start_time": "2022-11-28T05:28:20.339961Z"
}
},
"source": [
"# Plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92b67ed7",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:25.268Z"
}
},
"outputs": [],
"source": [
"setting"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ebea79f7",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:25.268Z"
}
},
"outputs": [],
"source": [
"ds, dl = exp._get_data('test')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9833f218",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:25.268Z"
}
},
"outputs": [],
"source": [
"preds, trues = exp.test(setting, data='test')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c88dbf7",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:25.268Z"
}
},
"outputs": [],
"source": [
"i=990\n",
"(batch_x, batch_y, batch_x_mark, batch_y_mark) = ds[i]\n",
"l1 = batch_x.shape[0]\n",
"l2 = batch_y.shape[0]\n",
"plt.plot(range(l1), batch_x[:, -1], label='past')\n",
"plt.plot(range(l1, l1+l2), batch_y[:, -1], color='blue', ls='--', label='true')\n",
"plt.plot(range(l1, l1+l2), preds[i], label='pred');\n",
"plt.legend(loc='lower left')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a25923b",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:33:57.703528Z",
"start_time": "2022-11-28T07:33:57.689258Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "665b98c3",
"metadata": {},
"source": [
"# TODO check index"
]
},
{
"cell_type": "markdown",
"id": "fc6ff0cd",
"metadata": {},
"source": [
"- s_end = index + self.seq_len\n",
"- r_begin = index + self.seq_len - self.label_len\n",
"- r_end = index + self.seq_len + self.pred_len"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f482f601",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:25.268Z"
},
"scrolled": true
},
"outputs": [],
"source": [
"f = ds.root_path + '/' + ds.data_path\n",
"df = pd.read_csv(f).set_index('date', drop=False)\n",
"df = df[ds.cols[1:]]\n",
"# df[:] = ds.scaler.transform(df.values)\n",
"df\n",
"# df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1893beae",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:25.268Z"
}
},
"outputs": [],
"source": [
"from data_provider.data_loader import Dataset_Custom\n",
"ds2 = Dataset_Custom(\n",
" ds.root_path,\n",
" flag=\"test\",\n",
"# size=ds.size,\n",
" size=[ds.seq_len, ds.label_len, ds.pred_len],\n",
" features=ds.features,\n",
" data_path=ds.data_path,\n",
" target=ds.target,\n",
" scale=False,\n",
" timeenc=ds.timeenc,\n",
" freq=ds.freq\n",
")\n",
"\n",
"i=99\n",
"(batch_x, batch_y, batch_x_mark, batch_y_mark) = ds2[i]\n",
"batch_x[-1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0a63562",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:25.268Z"
}
},
"outputs": [],
"source": [
"dt = ds2.index.iloc[i-1]\n",
"df.loc[dt]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65c56ec5",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:25.268Z"
}
},
"outputs": [],
"source": [
"assert df.loc[dt].close == batch_x[-1, 0], 'index should be right'\n",
"'OK'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "692aadfa",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "2dd11010",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4413025",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "deeptime",
"language": "python",
"name": "deeptime"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}
File diff suppressed because one or more lines are too long
+483
View File
@@ -0,0 +1,483 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "0658b47c",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:36:13.547954Z",
"start_time": "2022-11-28T07:36:13.546254Z"
}
},
"outputs": [],
"source": [
"import os\n",
"os.sys.path.append('..')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6257d783",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:36:13.578925Z",
"start_time": "2022-11-28T07:36:13.548651Z"
}
},
"outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")\n",
"\n",
"# autoreload import your package\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9e215049",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:36:13.597024Z",
"start_time": "2022-11-28T07:36:13.580545Z"
}
},
"outputs": [],
"source": [
"from loguru import logger\n",
"logger.remove()\n",
"logger.add(os.sys.stdout, level=\"ERROR\", colorize=True, format=\"<level>{time} | {message}</level>\")\n",
"# import_dir(ta_dir, verbose=False)\n",
"warnings.simplefilter(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d12e975b",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:36:14.269826Z",
"start_time": "2022-11-28T07:36:13.597868Z"
}
},
"outputs": [],
"source": [
"import torch\n",
"import pandas as pd\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"id": "26c0133f",
"metadata": {},
"source": [
"# Args"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "13445106",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:36:14.530041Z",
"start_time": "2022-11-28T07:36:14.270844Z"
}
},
"outputs": [],
"source": [
"from run import set_seed, get_args, Exp_Main"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c9a419a",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:13.551Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Args in experiment:\n",
"Namespace(K=0, activation='sigmoid', batch_size=32, c_out=1, checkpoints='./checkpoints/', d_ff=2048, d_layers=2, d_model=512, damping_learning_rate=0, data='custom', data_path='OXY_2019.csv.gz', dec_in=7, des=\"'Exp'\", devices='0,1,2,3', dropout=0.2, e_layers=2, embed='timeF', enc_in=12, features='M', freq='h', gpu=0, itr=1, label_len=0, learning_rate=0.0001, lradj='exponential_with_warmup', min_lr=1e-30, model='ETSformer', model_id='Exchange', n_heads=8, num_workers=0, optim='adam', output_attention=False, patience=5, pred_len=48, root_path='../dataset/stocks/', seq_len=128, smoothing_learning_rate=0, std=0.2, target='RSMKs_18_144_72_2ref_2ref', train_epochs=15, use_gpu=True, use_multi_gpu=False, warmup_epochs=3)\n",
"Use GPU: cuda:0\n",
">>>>>>>start training : Exchange_ETSformer_custom_ftM_sl128_pl48_dm512_nh8_el2_dl2_df2048_K0_lr0.0001_'Exp'_0>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
"train 37203\n",
"val 5294\n",
"test 10632\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f7c396e54a674bb5abfb80976d10fe3c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/15 [00:00<?, ?epoch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train: 0%| | 0/1162 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1, Steps: 1162 | Train Loss: 0.2852713 Vali Loss: 0.0847221 Test Loss: 0.1228571\n",
"Validation loss decreased (inf --> 0.084722). Saving model ...\n",
"Updating learning rate to 2.5e-05\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e7d71c496b224619b4bae1c73e47c68d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train: 0%| | 0/1162 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# mimic cli args to avoid code duplication\n",
"argv = \"\"\"python -u run.py \\\n",
" --root_path ../dataset/stocks/ \\\n",
" --data_path OXY_2019.csv.gz \\\n",
" --checkpoints ./checkpoints/ \\\n",
" --model_id Exchange \\\n",
" --model ETSformer \\\n",
" --data custom \\\n",
" --features M \\\n",
" --seq_len 128 \\\n",
" --pred_len 48 \\\n",
" --e_layers 2 \\\n",
" --d_layers 2 \\\n",
" --enc_in 12 \\\n",
" --c_out 1 \\\n",
" --num_workers 0 \\\n",
" --des 'Exp' \\\n",
" --K 0 \\\n",
" --learning_rate 1e-4 \\\n",
" --target RSMKs_18_144_72_2ref_2ref \\\n",
" --itr 1\n",
"\"\"\"\n",
"argv = argv.replace(\"\\\\n\", \"\").split()[3:]\n",
"args = get_args(argv)\n",
"args\n",
"\n",
"\n",
"Exp = Exp_Main\n",
"ii=0\n",
"set_seed(ii)\n",
"# setting record of experiments\n",
"setting = '{}_{}_{}_ft{}_sl{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_K{}_lr{}_{}_{}'.format(\n",
" args.model_id,\n",
" args.model,\n",
" args.data,\n",
" args.features,\n",
" args.seq_len,\n",
" args.pred_len,\n",
" args.d_model,\n",
" args.n_heads,\n",
" args.e_layers,\n",
" args.d_layers,\n",
" args.d_ff,\n",
" args.K,\n",
" args.learning_rate,\n",
" args.des, ii)\n",
"\n",
"# if os.path.exists(os.path.join(args.checkpoints, setting)):\n",
"# print('skipping exists')\n",
"# continue\n",
"\n",
"exp = Exp(args) # set experiments\n",
"print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))\n",
"exp.train(setting)\n",
"\n",
"print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))\n",
"exp.test(setting, data='val')\n",
"exp.test(setting, data='test')\n",
"\n",
"torch.cuda.empty_cache()\n"
]
},
{
"cell_type": "markdown",
"id": "c9142164",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T05:28:20.356399Z",
"start_time": "2022-11-28T05:28:20.339961Z"
}
},
"source": [
"# Plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7636db93",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:13.551Z"
}
},
"outputs": [],
"source": [
"setting"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f0476ef",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:13.551Z"
}
},
"outputs": [],
"source": [
"ds, dl = exp._get_data('test')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d8e2c75",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:13.551Z"
}
},
"outputs": [],
"source": [
"preds, trues = exp.test(setting, data='test')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fe2952f",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:13.551Z"
}
},
"outputs": [],
"source": [
"i=990\n",
"(batch_x, batch_y, batch_x_mark, batch_y_mark) = ds[i]\n",
"l1 = batch_x.shape[0]\n",
"l2 = batch_y.shape[0]\n",
"plt.plot(range(l1), batch_x[:, -1], label='past')\n",
"plt.plot(range(l1, l1+l2), batch_y[:, -1], color='blue', ls='--', label='true')\n",
"plt.plot(range(l1, l1+l2), preds[i], label='pred');\n",
"plt.legend(loc='lower left')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e501b7a",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-28T07:33:57.703528Z",
"start_time": "2022-11-28T07:33:57.689258Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "6e6f02ba",
"metadata": {},
"source": [
"# TODO check index"
]
},
{
"cell_type": "markdown",
"id": "54e1b101",
"metadata": {},
"source": [
"- s_end = index + self.seq_len\n",
"- r_begin = index + self.seq_len - self.label_len\n",
"- r_end = index + self.seq_len + self.pred_len"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8075058e",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:13.568Z"
},
"scrolled": true
},
"outputs": [],
"source": [
"f = ds.root_path + '/' + ds.data_path\n",
"df = pd.read_csv(f).set_index('date', drop=False)\n",
"df = df[ds.cols[1:]]\n",
"# df[:] = ds.scaler.transform(df.values)\n",
"df\n",
"# df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08f288f6",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:13.568Z"
}
},
"outputs": [],
"source": [
"from data_provider.data_loader import Dataset_Custom\n",
"ds2 = Dataset_Custom(\n",
" ds.root_path,\n",
" flag=\"test\",\n",
"# size=ds.size,\n",
" size=[ds.seq_len, ds.label_len, ds.pred_len],\n",
" features=ds.features,\n",
" data_path=ds.data_path,\n",
" target=ds.target,\n",
" scale=False,\n",
" timeenc=ds.timeenc,\n",
" freq=ds.freq\n",
")\n",
"\n",
"i=99\n",
"(batch_x, batch_y, batch_x_mark, batch_y_mark) = ds2[i]\n",
"batch_x[-1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ea2be5f",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:13.568Z"
}
},
"outputs": [],
"source": [
"dt = ds2.index.iloc[i-1]\n",
"df.loc[dt]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07e898c8",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-28T07:36:13.568Z"
}
},
"outputs": [],
"source": [
"assert df.loc[dt].close == batch_x[-1, 0], 'index should be right'\n",
"'OK'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf1b0a6d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6039c3f0",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "368b0ace",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "deeptime",
"language": "python",
"name": "deeptime"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}