mirror of
https://github.com/wassname/DeepTime.git
synced 2026-06-27 20:19:35 +08:00
serialize
This commit is contained in:
@@ -20,6 +20,7 @@ from utils.checkpoint import Checkpoint
|
||||
from utils.ops import default_device, to_tensor
|
||||
from utils.losses import get_loss_fn
|
||||
from utils.metrics import calc_metrics
|
||||
from utils.serialize import serialize
|
||||
|
||||
|
||||
class ForecastExperiment(Experiment):
|
||||
@@ -33,7 +34,7 @@ class ForecastExperiment(Experiment):
|
||||
test_set, test_loader = get_data(flag='test')
|
||||
|
||||
dim_size=train_set.data_x.shape[1]
|
||||
seq_len = train_set[0][1].shape
|
||||
seq_len = train_set[0][1].shape[0]
|
||||
model = get_model(model_type,
|
||||
dim_size=dim_size,
|
||||
seq_len=seq_len,
|
||||
@@ -47,8 +48,10 @@ class ForecastExperiment(Experiment):
|
||||
val_metrics = validate(model, loader=val_loader, report_metrics=True)
|
||||
test_metrics = validate(model, loader=test_loader, report_metrics=True,
|
||||
save_path=self.root if save_vals else None)
|
||||
metrics = {'val': val_metrics, 'test': test_metrics}
|
||||
# np.save(join(self.root, 'metrics.npy'), {'val': val_metrics, 'test': test_metrics})
|
||||
json.dump({'val': val_metrics, 'test': test_metrics}, open(join(self.root, 'metrics.npy', 'w')))
|
||||
metrics = serialize(metrics)
|
||||
json.dump(metrics, open(join(self.root, 'metrics.npy'), 'w'))
|
||||
|
||||
val_metrics = {f'ValMetric/{k}': v for k, v in val_metrics.items()}
|
||||
test_metrics = {f'TestMetric/{k}': v for k, v in test_metrics.items()}
|
||||
|
||||
+96
-92
@@ -21,8 +21,8 @@
|
||||
"id": "7f9e3d73",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:32:18.927915Z",
|
||||
"start_time": "2022-11-23T03:32:18.918871Z"
|
||||
"end_time": "2022-11-23T05:42:46.869262Z",
|
||||
"start_time": "2022-11-23T05:42:46.859832Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -41,8 +41,8 @@
|
||||
"id": "4e09086b",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:32:20.144762Z",
|
||||
"start_time": "2022-11-23T03:32:18.928928Z"
|
||||
"end_time": "2022-11-23T05:42:48.073268Z",
|
||||
"start_time": "2022-11-23T05:42:46.870484Z"
|
||||
},
|
||||
"lines_to_next_cell": 0
|
||||
},
|
||||
@@ -84,8 +84,8 @@
|
||||
"id": "66d7f095",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:32:20.171177Z",
|
||||
"start_time": "2022-11-23T03:32:20.146544Z"
|
||||
"end_time": "2022-11-23T05:42:48.097114Z",
|
||||
"start_time": "2022-11-23T05:42:48.074670Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -123,8 +123,8 @@
|
||||
"id": "04499bef",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:32:20.191130Z",
|
||||
"start_time": "2022-11-23T03:32:20.172479Z"
|
||||
"end_time": "2022-11-23T05:42:48.115157Z",
|
||||
"start_time": "2022-11-23T05:42:48.098029Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -179,7 +179,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "09dd5ebd",
|
||||
"id": "6615bb09",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-22T13:28:35.849491Z",
|
||||
@@ -193,11 +193,11 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "dc530891",
|
||||
"id": "ec47ae46",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:32:20.213967Z",
|
||||
"start_time": "2022-11-23T03:32:20.192307Z"
|
||||
"end_time": "2022-11-23T05:42:48.138242Z",
|
||||
"start_time": "2022-11-23T05:42:48.115973Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -218,7 +218,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "da8f502f",
|
||||
"id": "0a298ce2",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-22T13:38:15.786656Z",
|
||||
@@ -231,7 +231,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ccc9fdb0",
|
||||
"id": "40002390",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# run exps"
|
||||
@@ -240,7 +240,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e09f2823",
|
||||
"id": "2e1e58e7",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:31:08.583641Z",
|
||||
@@ -253,75 +253,80 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "9d13d779",
|
||||
"id": "91b92e80",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:32:20.232318Z",
|
||||
"start_time": "2022-11-23T03:32:20.215075Z"
|
||||
"end_time": "2022-11-23T05:42:48.154921Z",
|
||||
"start_time": "2022-11-23T05:42:48.139190Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=lstm,repeat=0/config.gin')]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# list the models we have run...\n",
|
||||
"configs=sorted(Path(\"storage/experiments/Stocks\").glob(\"**/config.gin\"))\n",
|
||||
"import random\n",
|
||||
"random.shuffle(configs)\n",
|
||||
"print(configs)"
|
||||
"# print(configs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "4178d85e",
|
||||
"id": "8bb001f7",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:32:20.248919Z",
|
||||
"start_time": "2022-11-23T03:32:20.233964Z"
|
||||
"end_time": "2022-11-23T05:42:48.171079Z",
|
||||
"start_time": "2022-11-23T05:42:48.155764Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from experiments.forecast import ForecastExperiment"
|
||||
"from experiments.forecast import ForecastExperiment\n",
|
||||
"from tqdm.auto import tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8eadaa48",
|
||||
"id": "029641dd",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T04:00:02.202Z"
|
||||
}
|
||||
"start_time": "2022-11-23T05:42:46.859Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "49eccaa5eb6942e29ed4943d5b92cf89",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/42 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=mlp,repeat=0/config.gin\n"
|
||||
"storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=none,repeat=0/config.gin\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:root:epochs: 1, iters: 100 | training loss: 2.41\n",
|
||||
"INFO:root:Validation loss decreased (inf --> 1.004). Saving model ...\n",
|
||||
"INFO:root:epochs: 2, iters: 100 | training loss: 0.20\n",
|
||||
"INFO:root:Validation loss decreased (1.004 --> 0.115). Saving model ...\n"
|
||||
"INFO:root:epochs: 1, iters: 100 | training loss: 1.84\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for config in configs:\n",
|
||||
"for config in tqdm(configs):\n",
|
||||
" save_path = config.parent\n",
|
||||
"\n",
|
||||
" exp = ForecastExperiment(config_path=config)\n",
|
||||
@@ -331,7 +336,7 @@
|
||||
" except KeyboardInterrupt:\n",
|
||||
" raise\n",
|
||||
" except Exception as e:\n",
|
||||
"# raise\n",
|
||||
" raise\n",
|
||||
" print(e)\n",
|
||||
" pass"
|
||||
]
|
||||
@@ -339,11 +344,24 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e7a3c151",
|
||||
"id": "0472a84b",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.312188Z",
|
||||
"start_time": "2022-11-23T03:36:04.312180Z"
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%debug"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "03a28a90",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -354,11 +372,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8b73c64e",
|
||||
"id": "a31ae85a",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.313334Z",
|
||||
"start_time": "2022-11-23T03:36:04.313322Z"
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -369,11 +386,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6af1bb1e",
|
||||
"id": "1a9d823d",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.313962Z",
|
||||
"start_time": "2022-11-23T03:36:04.313954Z"
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -388,11 +404,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "50a92a7e",
|
||||
"id": "5e4983e7",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.314594Z",
|
||||
"start_time": "2022-11-23T03:36:04.314588Z"
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -405,11 +420,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9ef76b76",
|
||||
"id": "cdd0fbbf",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.315217Z",
|
||||
"start_time": "2022-11-23T03:36:04.315210Z"
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -433,7 +447,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3637e87d",
|
||||
"id": "321106ef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# view all"
|
||||
@@ -445,8 +459,7 @@
|
||||
"id": "768530be",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.315891Z",
|
||||
"start_time": "2022-11-23T03:36:04.315884Z"
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -463,7 +476,7 @@
|
||||
" model_name = gin.query_parameter(\"instance.model_type\")\n",
|
||||
"\n",
|
||||
" train_set, train_loader = get_data(flag='test', batch_size=3)\n",
|
||||
" seq_len = train_set[0][1].shape\n",
|
||||
" seq_len = train_set[0][1].shape[0]\n",
|
||||
" model = get_model(model_name,\n",
|
||||
" dim_size=train_set.data_x.shape[1],\n",
|
||||
" seq_len=seq_len,\n",
|
||||
@@ -522,8 +535,7 @@
|
||||
"id": "739ee5e3",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.316469Z",
|
||||
"start_time": "2022-11-23T03:36:04.316463Z"
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -536,11 +548,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "48e9175b",
|
||||
"id": "f1ae652d",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.316979Z",
|
||||
"start_time": "2022-11-23T03:36:04.316973Z"
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -556,11 +567,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4b966ef7",
|
||||
"id": "79bda4bf",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.317585Z",
|
||||
"start_time": "2022-11-23T03:36:04.317578Z"
|
||||
"start_time": "2022-11-23T05:42:46.875Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -572,11 +582,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0b696e85",
|
||||
"id": "b06e53a9",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.318277Z",
|
||||
"start_time": "2022-11-23T03:36:04.318270Z"
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -599,11 +608,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ac9e5759",
|
||||
"id": "25b9b9c5",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.319520Z",
|
||||
"start_time": "2022-11-23T03:36:04.319512Z"
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -620,11 +628,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b2e27d42",
|
||||
"id": "e173ea6f",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.320031Z",
|
||||
"start_time": "2022-11-23T03:36:04.320024Z"
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -635,7 +642,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "938f06db",
|
||||
"id": "9ca6f142",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-22T13:17:31.585029Z",
|
||||
@@ -647,7 +654,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9b4e817c",
|
||||
"id": "fce6db58",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# check positions in dl"
|
||||
@@ -656,11 +663,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "df6c9b28",
|
||||
"id": "5d574727",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.320752Z",
|
||||
"start_time": "2022-11-23T03:36:04.320745Z"
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
@@ -675,11 +681,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fa54042e",
|
||||
"id": "74ebc120",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.321671Z",
|
||||
"start_time": "2022-11-23T03:36:04.321662Z"
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -691,11 +696,10 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2482f0f7",
|
||||
"id": "d289f6bd",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-23T03:36:04.322543Z",
|
||||
"start_time": "2022-11-23T03:36:04.322535Z"
|
||||
"start_time": "2022-11-23T05:42:46.892Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -710,7 +714,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e8fe355e",
|
||||
"id": "879d9d45",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
@@ -718,7 +722,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3a978978",
|
||||
"id": "37d7641b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
@@ -726,7 +730,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "045d3fd5",
|
||||
"id": "061b3b77",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from json_tricks import dump, dumps, load, loads, strip_comments
|
||||
|
||||
def torch_encode(obj, primitives=False):
|
||||
from torch import Tensor
|
||||
if isinstance(obj, Tensor):
|
||||
if primitives:
|
||||
return obj.numpy().tolist()
|
||||
raise NotImplemented()
|
||||
return obj
|
||||
|
||||
def serialize(o):
|
||||
s = dumps(o, extra_obj_encoders=[torch_encode], primitives=True)
|
||||
return loads(s)
|
||||
Reference in New Issue
Block a user