serialize

This commit is contained in:
wassname
2022-11-23 13:43:15 +08:00
parent cb27d2ded3
commit b0fbc10dfd
3 changed files with 114 additions and 94 deletions
+5 -2
View File
@@ -20,6 +20,7 @@ from utils.checkpoint import Checkpoint
from utils.ops import default_device, to_tensor
from utils.losses import get_loss_fn
from utils.metrics import calc_metrics
from utils.serialize import serialize
class ForecastExperiment(Experiment):
@@ -33,7 +34,7 @@ class ForecastExperiment(Experiment):
test_set, test_loader = get_data(flag='test')
dim_size=train_set.data_x.shape[1]
seq_len = train_set[0][1].shape
seq_len = train_set[0][1].shape[0]
model = get_model(model_type,
dim_size=dim_size,
seq_len=seq_len,
@@ -47,8 +48,10 @@ class ForecastExperiment(Experiment):
val_metrics = validate(model, loader=val_loader, report_metrics=True)
test_metrics = validate(model, loader=test_loader, report_metrics=True,
save_path=self.root if save_vals else None)
metrics = {'val': val_metrics, 'test': test_metrics}
# np.save(join(self.root, 'metrics.npy'), {'val': val_metrics, 'test': test_metrics})
json.dump({'val': val_metrics, 'test': test_metrics}, open(join(self.root, 'metrics.npy', 'w')))
metrics = serialize(metrics)
json.dump(metrics, open(join(self.root, 'metrics.npy'), 'w'))
val_metrics = {f'ValMetric/{k}': v for k, v in val_metrics.items()}
test_metrics = {f'TestMetric/{k}': v for k, v in test_metrics.items()}
+96 -92
View File
@@ -21,8 +21,8 @@
"id": "7f9e3d73",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:18.927915Z",
"start_time": "2022-11-23T03:32:18.918871Z"
"end_time": "2022-11-23T05:42:46.869262Z",
"start_time": "2022-11-23T05:42:46.859832Z"
}
},
"outputs": [],
@@ -41,8 +41,8 @@
"id": "4e09086b",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.144762Z",
"start_time": "2022-11-23T03:32:18.928928Z"
"end_time": "2022-11-23T05:42:48.073268Z",
"start_time": "2022-11-23T05:42:46.870484Z"
},
"lines_to_next_cell": 0
},
@@ -84,8 +84,8 @@
"id": "66d7f095",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.171177Z",
"start_time": "2022-11-23T03:32:20.146544Z"
"end_time": "2022-11-23T05:42:48.097114Z",
"start_time": "2022-11-23T05:42:48.074670Z"
}
},
"outputs": [
@@ -123,8 +123,8 @@
"id": "04499bef",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.191130Z",
"start_time": "2022-11-23T03:32:20.172479Z"
"end_time": "2022-11-23T05:42:48.115157Z",
"start_time": "2022-11-23T05:42:48.098029Z"
}
},
"outputs": [],
@@ -179,7 +179,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "09dd5ebd",
"id": "6615bb09",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T13:28:35.849491Z",
@@ -193,11 +193,11 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "dc530891",
"id": "ec47ae46",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.213967Z",
"start_time": "2022-11-23T03:32:20.192307Z"
"end_time": "2022-11-23T05:42:48.138242Z",
"start_time": "2022-11-23T05:42:48.115973Z"
}
},
"outputs": [
@@ -218,7 +218,7 @@
},
{
"cell_type": "markdown",
"id": "da8f502f",
"id": "0a298ce2",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T13:38:15.786656Z",
@@ -231,7 +231,7 @@
},
{
"cell_type": "markdown",
"id": "ccc9fdb0",
"id": "40002390",
"metadata": {},
"source": [
"# run exps"
@@ -240,7 +240,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e09f2823",
"id": "2e1e58e7",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:31:08.583641Z",
@@ -253,75 +253,80 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "9d13d779",
"id": "91b92e80",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.232318Z",
"start_time": "2022-11-23T03:32:20.215075Z"
"end_time": "2022-11-23T05:42:48.154921Z",
"start_time": "2022-11-23T05:42:48.139190Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=lstm,repeat=0/config.gin')]\n"
]
}
],
"outputs": [],
"source": [
"# list the models we have run...\n",
"configs=sorted(Path(\"storage/experiments/Stocks\").glob(\"**/config.gin\"))\n",
"import random\n",
"random.shuffle(configs)\n",
"print(configs)"
"# print(configs)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "4178d85e",
"id": "8bb001f7",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:32:20.248919Z",
"start_time": "2022-11-23T03:32:20.233964Z"
"end_time": "2022-11-23T05:42:48.171079Z",
"start_time": "2022-11-23T05:42:48.155764Z"
}
},
"outputs": [],
"source": [
"from experiments.forecast import ForecastExperiment"
"from experiments.forecast import ForecastExperiment\n",
"from tqdm.auto import tqdm"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8eadaa48",
"id": "029641dd",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-23T04:00:02.202Z"
}
"start_time": "2022-11-23T05:42:46.859Z"
},
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "49eccaa5eb6942e29ed4943d5b92cf89",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/42 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=mlp,repeat=0/config.gin\n"
"storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=none,repeat=0/config.gin\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:epochs: 1, iters: 100 | training loss: 2.41\n",
"INFO:root:Validation loss decreased (inf --> 1.004). Saving model ...\n",
"INFO:root:epochs: 2, iters: 100 | training loss: 0.20\n",
"INFO:root:Validation loss decreased (1.004 --> 0.115). Saving model ...\n"
"INFO:root:epochs: 1, iters: 100 | training loss: 1.84\n"
]
}
],
"source": [
"for config in configs:\n",
"for config in tqdm(configs):\n",
" save_path = config.parent\n",
"\n",
" exp = ForecastExperiment(config_path=config)\n",
@@ -331,7 +336,7 @@
" except KeyboardInterrupt:\n",
" raise\n",
" except Exception as e:\n",
"# raise\n",
" raise\n",
" print(e)\n",
" pass"
]
@@ -339,11 +344,24 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e7a3c151",
"id": "0472a84b",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.312188Z",
"start_time": "2022-11-23T03:36:04.312180Z"
"start_time": "2022-11-23T05:42:46.875Z"
}
},
"outputs": [],
"source": [
"%debug"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03a28a90",
"metadata": {
"ExecuteTime": {
"start_time": "2022-11-23T05:42:46.875Z"
}
},
"outputs": [],
@@ -354,11 +372,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "8b73c64e",
"id": "a31ae85a",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.313334Z",
"start_time": "2022-11-23T03:36:04.313322Z"
"start_time": "2022-11-23T05:42:46.875Z"
}
},
"outputs": [],
@@ -369,11 +386,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6af1bb1e",
"id": "1a9d823d",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.313962Z",
"start_time": "2022-11-23T03:36:04.313954Z"
"start_time": "2022-11-23T05:42:46.875Z"
}
},
"outputs": [],
@@ -388,11 +404,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "50a92a7e",
"id": "5e4983e7",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.314594Z",
"start_time": "2022-11-23T03:36:04.314588Z"
"start_time": "2022-11-23T05:42:46.875Z"
}
},
"outputs": [],
@@ -405,11 +420,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "9ef76b76",
"id": "cdd0fbbf",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.315217Z",
"start_time": "2022-11-23T03:36:04.315210Z"
"start_time": "2022-11-23T05:42:46.875Z"
}
},
"outputs": [],
@@ -433,7 +447,7 @@
},
{
"cell_type": "markdown",
"id": "3637e87d",
"id": "321106ef",
"metadata": {},
"source": [
"# view all"
@@ -445,8 +459,7 @@
"id": "768530be",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.315891Z",
"start_time": "2022-11-23T03:36:04.315884Z"
"start_time": "2022-11-23T05:42:46.875Z"
}
},
"outputs": [],
@@ -463,7 +476,7 @@
" model_name = gin.query_parameter(\"instance.model_type\")\n",
"\n",
" train_set, train_loader = get_data(flag='test', batch_size=3)\n",
" seq_len = train_set[0][1].shape\n",
" seq_len = train_set[0][1].shape[0]\n",
" model = get_model(model_name,\n",
" dim_size=train_set.data_x.shape[1],\n",
" seq_len=seq_len,\n",
@@ -522,8 +535,7 @@
"id": "739ee5e3",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.316469Z",
"start_time": "2022-11-23T03:36:04.316463Z"
"start_time": "2022-11-23T05:42:46.875Z"
}
},
"outputs": [],
@@ -536,11 +548,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "48e9175b",
"id": "f1ae652d",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.316979Z",
"start_time": "2022-11-23T03:36:04.316973Z"
"start_time": "2022-11-23T05:42:46.875Z"
}
},
"outputs": [],
@@ -556,11 +567,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "4b966ef7",
"id": "79bda4bf",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.317585Z",
"start_time": "2022-11-23T03:36:04.317578Z"
"start_time": "2022-11-23T05:42:46.875Z"
}
},
"outputs": [],
@@ -572,11 +582,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0b696e85",
"id": "b06e53a9",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.318277Z",
"start_time": "2022-11-23T03:36:04.318270Z"
"start_time": "2022-11-23T05:42:46.892Z"
}
},
"outputs": [],
@@ -599,11 +608,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ac9e5759",
"id": "25b9b9c5",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.319520Z",
"start_time": "2022-11-23T03:36:04.319512Z"
"start_time": "2022-11-23T05:42:46.892Z"
}
},
"outputs": [],
@@ -620,11 +628,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "b2e27d42",
"id": "e173ea6f",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.320031Z",
"start_time": "2022-11-23T03:36:04.320024Z"
"start_time": "2022-11-23T05:42:46.892Z"
}
},
"outputs": [],
@@ -635,7 +642,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "938f06db",
"id": "9ca6f142",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T13:17:31.585029Z",
@@ -647,7 +654,7 @@
},
{
"cell_type": "markdown",
"id": "9b4e817c",
"id": "fce6db58",
"metadata": {},
"source": [
"# check positions in dl"
@@ -656,11 +663,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "df6c9b28",
"id": "5d574727",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.320752Z",
"start_time": "2022-11-23T03:36:04.320745Z"
"start_time": "2022-11-23T05:42:46.892Z"
},
"scrolled": true
},
@@ -675,11 +681,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "fa54042e",
"id": "74ebc120",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.321671Z",
"start_time": "2022-11-23T03:36:04.321662Z"
"start_time": "2022-11-23T05:42:46.892Z"
}
},
"outputs": [],
@@ -691,11 +696,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "2482f0f7",
"id": "d289f6bd",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-23T03:36:04.322543Z",
"start_time": "2022-11-23T03:36:04.322535Z"
"start_time": "2022-11-23T05:42:46.892Z"
}
},
"outputs": [],
@@ -710,7 +714,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e8fe355e",
"id": "879d9d45",
"metadata": {},
"outputs": [],
"source": []
@@ -718,7 +722,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "3a978978",
"id": "37d7641b",
"metadata": {},
"outputs": [],
"source": []
@@ -726,7 +730,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "045d3fd5",
"id": "061b3b77",
"metadata": {},
"outputs": [],
"source": []
+13
View File
@@ -0,0 +1,13 @@
from json_tricks import dump, dumps, load, loads, strip_comments
def torch_encode(obj, primitives=False):
from torch import Tensor
if isinstance(obj, Tensor):
if primitives:
return obj.numpy().tolist()
raise NotImplemented()
return obj
def serialize(o):
s = dumps(o, extra_obj_encoders=[torch_encode], primitives=True)
return loads(s)