mirror of
https://github.com/wassname/ETSformer.git
synced 2026-06-27 15:13:42 +08:00
nbs
This commit is contained in:
@@ -163,6 +163,9 @@ class Exp_Main(Exp_Basic):
|
||||
print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
|
||||
epoch + 1, train_steps, train_loss, vali_loss, test_loss))
|
||||
early_stopping(vali_loss, self.model, path)
|
||||
|
||||
|
||||
open(path + '/metrics.csv', 'a').write("{train_loss},{vali_loss},{test_loss}\n")
|
||||
if early_stopping.early_stop:
|
||||
print("Early stopping")
|
||||
break
|
||||
|
||||
@@ -0,0 +1,680 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "0c996f31",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T10:34:55.581023Z",
|
||||
"start_time": "2022-11-28T10:34:55.579242Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.sys.path.append('..')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "23a27691",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T10:34:55.611369Z",
|
||||
"start_time": "2022-11-28T10:34:55.581779Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import warnings\n",
|
||||
"warnings.simplefilter(\"ignore\")\n",
|
||||
"\n",
|
||||
"# autoreload import your package\n",
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "e60bfdb0",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T10:34:55.633532Z",
|
||||
"start_time": "2022-11-28T10:34:55.612568Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from loguru import logger\n",
|
||||
"logger.remove()\n",
|
||||
"logger.add(os.sys.stdout, level=\"ERROR\", colorize=True, format=\"<level>{time} | {message}</level>\")\n",
|
||||
"# import_dir(ta_dir, verbose=False)\n",
|
||||
"warnings.simplefilter(\"ignore\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "d52338d2",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T10:34:56.339147Z",
|
||||
"start_time": "2022-11-28T10:34:55.634385Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import pandas as pd\n",
|
||||
"from matplotlib import pyplot as plt\n",
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "20d78524",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Args"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "750cdcc3",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T10:34:56.600001Z",
|
||||
"start_time": "2022-11-28T10:34:56.340104Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from run import set_seed, get_args, Exp_Main"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b3a40300",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-28T10:35:02.166Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Args in experiment:\n",
|
||||
"Namespace(K=0, activation='sigmoid', batch_size=32, c_out=1, checkpoints='./checkpoints/', d_ff=2048, d_layers=2, d_model=512, damping_learning_rate=0, data='custom', data_path='exchange_rate.csv', dec_in=7, des=\"'Exp'\", devices='0,1,2,3', dropout=0.2, e_layers=2, embed='timeF', enc_in=1, features='S', freq='h', gpu=0, itr=1, label_len=0, learning_rate=0.0003, lradj='exponential_with_warmup', min_lr=1e-30, model='ETSformer', model_id='Exchange', n_heads=16, num_workers=0, optim='adam', output_attention=False, patience=5, pred_len=96, root_path='../dataset/exchange_rate/', seq_len=512, smoothing_learning_rate=0, std=0.2, target='OT', train_epochs=15, use_gpu=True, use_multi_gpu=False, warmup_epochs=3)\n",
|
||||
"Use GPU: cuda:0\n",
|
||||
">>>>>>>start training : Exchange_ETSformer_custom_ftS_sl512_pl96_dm512_nh16_el2_dl2_df2048_K0_lr0.0003_'Exp'_0>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
|
||||
"train 4704\n",
|
||||
"val 665\n",
|
||||
"test 1422\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "db46d457567743ffb1475b58fba74214",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/15 [00:00<?, ?epoch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"train: 0%| | 0/147 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 1, Steps: 147 | Train Loss: 0.3748532 Vali Loss: 0.1615645 Test Loss: 0.0973958\n",
|
||||
"Validation loss decreased (inf --> 0.161564). Saving model ...\n",
|
||||
"Updating learning rate to 7.5e-05\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"train: 0%| | 0/147 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 2, Steps: 147 | Train Loss: 0.2502263 Vali Loss: 0.1444328 Test Loss: 0.0987149\n",
|
||||
"Validation loss decreased (0.161564 --> 0.144433). Saving model ...\n",
|
||||
"Updating learning rate to 0.00015\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"train: 0%| | 0/147 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 3, Steps: 147 | Train Loss: 0.2464321 Vali Loss: 0.1438236 Test Loss: 0.0948109\n",
|
||||
"Validation loss decreased (0.144433 --> 0.143824). Saving model ...\n",
|
||||
"Updating learning rate to 0.000225\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"train: 0%| | 0/147 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 4, Steps: 147 | Train Loss: 0.2674658 Vali Loss: 0.1382710 Test Loss: 0.1000110\n",
|
||||
"Validation loss decreased (0.143824 --> 0.138271). Saving model ...\n",
|
||||
"Updating learning rate to 0.0003\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"train: 0%| | 0/147 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 5, Steps: 147 | Train Loss: 0.2459421 Vali Loss: 0.1538529 Test Loss: 0.0921839\n",
|
||||
"EarlyStopping counter: 1 out of 5\n",
|
||||
"Updating learning rate to 0.00015\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"train: 0%| | 0/147 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 6, Steps: 147 | Train Loss: 0.2457922 Vali Loss: 0.1407998 Test Loss: 0.0951631\n",
|
||||
"EarlyStopping counter: 2 out of 5\n",
|
||||
"Updating learning rate to 7.5e-05\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"train: 0%| | 0/147 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 7, Steps: 147 | Train Loss: 0.2461400 Vali Loss: 0.1385264 Test Loss: 0.0990372\n",
|
||||
"EarlyStopping counter: 3 out of 5\n",
|
||||
"Updating learning rate to 3.75e-05\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"train: 0%| | 0/147 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 8, Steps: 147 | Train Loss: 0.2535853 Vali Loss: 0.1401822 Test Loss: 0.0943378\n",
|
||||
"EarlyStopping counter: 4 out of 5\n",
|
||||
"Updating learning rate to 1.875e-05\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "b99b57bf3d3844c7a693f364d6d25a7b",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"train: 0%| | 0/147 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# mimic cli args to avoid code duplication\n",
|
||||
"argv = \"\"\"python -u run.py \\\n",
|
||||
" --root_path ../dataset/exchange_rate/ \\\n",
|
||||
" --data_path exchange_rate.csv \\\n",
|
||||
" --data custom \\\n",
|
||||
" --checkpoints ./checkpoints/ \\\n",
|
||||
" --model_id Exchange \\\n",
|
||||
" --model ETSformer \\\n",
|
||||
" --data custom \\\n",
|
||||
" --features S \\\n",
|
||||
" --seq_len 512 \\\n",
|
||||
" --pred_len 96 \\\n",
|
||||
" --e_layers 2 \\\n",
|
||||
" --d_layers 2 \\\n",
|
||||
" --d_model 512 \\\n",
|
||||
" --d_ff 2048 \\\n",
|
||||
" --n_heads 16 \\\n",
|
||||
" --dropout 0.2 \\\n",
|
||||
" --enc_in 1 \\\n",
|
||||
" --c_out 1 \\\n",
|
||||
" --num_workers 0 \\\n",
|
||||
" --des 'Exp' \\\n",
|
||||
" --K 0 \\\n",
|
||||
" --learning_rate 3e-4 \\\n",
|
||||
" --itr 1\n",
|
||||
"\"\"\"\n",
|
||||
"argv = argv.replace(\"\\\\n\", \"\").split()[3:]\n",
|
||||
"args = get_args(argv)\n",
|
||||
"args\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Exp = Exp_Main\n",
|
||||
"ii=0\n",
|
||||
"set_seed(ii)\n",
|
||||
"# setting record of experiments\n",
|
||||
"setting = '{}_{}_{}_ft{}_sl{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_K{}_lr{}_{}_{}'.format(\n",
|
||||
" args.model_id,\n",
|
||||
" args.model,\n",
|
||||
" args.data,\n",
|
||||
" args.features,\n",
|
||||
" args.seq_len,\n",
|
||||
" args.pred_len,\n",
|
||||
" args.d_model,\n",
|
||||
" args.n_heads,\n",
|
||||
" args.e_layers,\n",
|
||||
" args.d_layers,\n",
|
||||
" args.d_ff,\n",
|
||||
" args.K,\n",
|
||||
" args.learning_rate,\n",
|
||||
" args.des, ii)\n",
|
||||
"\n",
|
||||
"# if os.path.exists(os.path.join(args.checkpoints, setting)):\n",
|
||||
"# print('skipping exists')\n",
|
||||
"# continue\n",
|
||||
"\n",
|
||||
"exp = Exp(args) # set experiments\n",
|
||||
"print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))\n",
|
||||
"exp.train(setting)\n",
|
||||
"\n",
|
||||
"print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))\n",
|
||||
"exp.test(setting, data='val')\n",
|
||||
"exp.test(setting, data='test')\n",
|
||||
"\n",
|
||||
"torch.cuda.empty_cache()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "563c1c4c",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-28T10:35:02.332Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exp.model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "72811788",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T05:28:20.356399Z",
|
||||
"start_time": "2022-11-28T05:28:20.339961Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# Plot"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "92b67ed7",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-28T10:35:02.599Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"setting"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ebea79f7",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-28T10:35:02.716Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds, dl = exp._get_data('test')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9833f218",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-28T10:35:02.866Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"preds, trues = exp.test(setting, data='test')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "658546f2",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-28T10:35:02.999Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5c88dbf7",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-28T10:35:03.149Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"i=990\n",
|
||||
"(batch_x, batch_y, batch_x_mark, batch_y_mark) = ds[i]\n",
|
||||
"l1 = batch_x.shape[0]\n",
|
||||
"l2 = batch_y.shape[0]\n",
|
||||
"plt.plot(range(l1), batch_x[:, -1], label='past')\n",
|
||||
"plt.plot(range(l1, l1+l2), batch_y[:, -1], color='blue', ls='--', label='true')\n",
|
||||
"plt.plot(range(l1, l1+l2), preds[i], label='pred');\n",
|
||||
"plt.legend(loc='lower left')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0a25923b",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T07:33:57.703528Z",
|
||||
"start_time": "2022-11-28T07:33:57.689258Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "665b98c3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# TODO check index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fc6ff0cd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- s_end = index + self.seq_len\n",
|
||||
"- r_begin = index + self.seq_len - self.label_len\n",
|
||||
"- r_end = index + self.seq_len + self.pred_len"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f482f601",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T10:34:56.644236Z",
|
||||
"start_time": "2022-11-28T10:34:56.644230Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"f = ds.root_path + '/' + ds.data_path\n",
|
||||
"df = pd.read_csv(f).set_index('date', drop=False)\n",
|
||||
"df = df[ds.cols[1:]]\n",
|
||||
"# df[:] = ds.scaler.transform(df.values)\n",
|
||||
"df\n",
|
||||
"# df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1893beae",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T10:34:56.644919Z",
|
||||
"start_time": "2022-11-28T10:34:56.644913Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from data_provider.data_loader import Dataset_Custom\n",
|
||||
"ds2 = Dataset_Custom(\n",
|
||||
" ds.root_path,\n",
|
||||
" flag=\"test\",\n",
|
||||
"# size=ds.size,\n",
|
||||
" size=[ds.seq_len, ds.label_len, ds.pred_len],\n",
|
||||
" features=ds.features,\n",
|
||||
" data_path=ds.data_path,\n",
|
||||
" target=ds.target,\n",
|
||||
" scale=False,\n",
|
||||
" timeenc=ds.timeenc,\n",
|
||||
" freq=ds.freq\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"i=99\n",
|
||||
"(batch_x, batch_y, batch_x_mark, batch_y_mark) = ds2[i]\n",
|
||||
"batch_x[-1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f0a63562",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T10:34:56.645465Z",
|
||||
"start_time": "2022-11-28T10:34:56.645459Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dt = ds2.index.iloc[i-1]\n",
|
||||
"df.loc[dt]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "65c56ec5",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-28T10:34:56.646111Z",
|
||||
"start_time": "2022-11-28T10:34:56.646104Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"assert df.loc[dt].close == batch_x[-1, 0], 'index should be right'\n",
|
||||
"'OK'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "692aadfa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2dd11010",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e4413025",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "deeptime",
|
||||
"language": "python",
|
||||
"name": "deeptime"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.13"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
"nav_menu": {},
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"skip_h1_title": false,
|
||||
"title_cell": "Table of Contents",
|
||||
"title_sidebar": "Contents",
|
||||
"toc_cell": false,
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": false
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
+864
-45
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user