diff --git a/experiments/forecast.py b/experiments/forecast.py index 9b9afe2..04ca761 100644 --- a/experiments/forecast.py +++ b/experiments/forecast.py @@ -20,6 +20,7 @@ from utils.checkpoint import Checkpoint from utils.ops import default_device, to_tensor from utils.losses import get_loss_fn from utils.metrics import calc_metrics +from utils.serialize import serialize class ForecastExperiment(Experiment): @@ -33,7 +34,7 @@ class ForecastExperiment(Experiment): test_set, test_loader = get_data(flag='test') dim_size=train_set.data_x.shape[1] - seq_len = train_set[0][1].shape + seq_len = train_set[0][1].shape[0] model = get_model(model_type, dim_size=dim_size, seq_len=seq_len, @@ -47,8 +48,10 @@ class ForecastExperiment(Experiment): val_metrics = validate(model, loader=val_loader, report_metrics=True) test_metrics = validate(model, loader=test_loader, report_metrics=True, save_path=self.root if save_vals else None) + metrics = {'val': val_metrics, 'test': test_metrics} # np.save(join(self.root, 'metrics.npy'), {'val': val_metrics, 'test': test_metrics}) - json.dump({'val': val_metrics, 'test': test_metrics}, open(join(self.root, 'metrics.npy', 'w'))) + metrics = serialize(metrics) + json.dump(metrics, open(join(self.root, 'metrics.npy'), 'w')) val_metrics = {f'ValMetric/{k}': v for k, v in val_metrics.items()} test_metrics = {f'TestMetric/{k}': v for k, v in test_metrics.items()} diff --git a/scratch-run_exp.ipynb b/scratch-run_exp.ipynb index a81baa6..bb7fa66 100644 --- a/scratch-run_exp.ipynb +++ b/scratch-run_exp.ipynb @@ -21,8 +21,8 @@ "id": "7f9e3d73", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:32:18.927915Z", - "start_time": "2022-11-23T03:32:18.918871Z" + "end_time": "2022-11-23T05:42:46.869262Z", + "start_time": "2022-11-23T05:42:46.859832Z" } }, "outputs": [], @@ -41,8 +41,8 @@ "id": "4e09086b", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:32:20.144762Z", - "start_time": "2022-11-23T03:32:18.928928Z" + "end_time": "2022-11-23T05:42:48.073268Z", + "start_time": "2022-11-23T05:42:46.870484Z" }, "lines_to_next_cell": 0 }, @@ -84,8 +84,8 @@ "id": "66d7f095", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:32:20.171177Z", - "start_time": "2022-11-23T03:32:20.146544Z" + "end_time": "2022-11-23T05:42:48.097114Z", + "start_time": "2022-11-23T05:42:48.074670Z" } }, "outputs": [ @@ -123,8 +123,8 @@ "id": "04499bef", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:32:20.191130Z", - "start_time": "2022-11-23T03:32:20.172479Z" + "end_time": "2022-11-23T05:42:48.115157Z", + "start_time": "2022-11-23T05:42:48.098029Z" } }, "outputs": [], @@ -179,7 +179,7 @@ { "cell_type": "code", "execution_count": null, - "id": "09dd5ebd", + "id": "6615bb09", "metadata": { "ExecuteTime": { "end_time": "2022-11-22T13:28:35.849491Z", @@ -193,11 +193,11 @@ { "cell_type": "code", "execution_count": 5, - "id": "dc530891", + "id": "ec47ae46", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:32:20.213967Z", - "start_time": "2022-11-23T03:32:20.192307Z" + "end_time": "2022-11-23T05:42:48.138242Z", + "start_time": "2022-11-23T05:42:48.115973Z" } }, "outputs": [ @@ -218,7 +218,7 @@ }, { "cell_type": "markdown", - "id": "da8f502f", + "id": "0a298ce2", "metadata": { "ExecuteTime": { "end_time": "2022-11-22T13:38:15.786656Z", @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "ccc9fdb0", + "id": "40002390", "metadata": {}, "source": [ "# run exps" @@ -240,7 +240,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e09f2823", + "id": "2e1e58e7", "metadata": { "ExecuteTime": { "end_time": "2022-11-23T03:31:08.583641Z", @@ -253,75 +253,80 @@ { "cell_type": "code", "execution_count": 6, - "id": "9d13d779", + "id": "91b92e80", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:32:20.232318Z", - "start_time": "2022-11-23T03:32:20.215075Z" + "end_time": "2022-11-23T05:42:48.154921Z", + "start_time": "2022-11-23T05:42:48.139190Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=transformer2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=mlp,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=lstm,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INR,encoder=inception,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=None,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=lstm2,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INR,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=transformer,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Transformer,inr=INRPlus2,encoder=none,repeat=0/config.gin'), Path('storage/experiments/Stocks/96M2S/base_learner=Ridge,inr=INRPlus2,encoder=lstm,repeat=0/config.gin')]\n" - ] - } - ], + "outputs": [], "source": [ "# list the models we have run...\n", "configs=sorted(Path(\"storage/experiments/Stocks\").glob(\"**/config.gin\"))\n", "import random\n", "random.shuffle(configs)\n", - "print(configs)" + "# print(configs)" ] }, { "cell_type": "code", "execution_count": 7, - "id": "4178d85e", + "id": "8bb001f7", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:32:20.248919Z", - "start_time": "2022-11-23T03:32:20.233964Z" + "end_time": "2022-11-23T05:42:48.171079Z", + "start_time": "2022-11-23T05:42:48.155764Z" } }, "outputs": [], "source": [ - "from experiments.forecast import ForecastExperiment" + "from experiments.forecast import ForecastExperiment\n", + "from tqdm.auto import tqdm" ] }, { "cell_type": "code", "execution_count": null, - "id": "8eadaa48", + "id": "029641dd", "metadata": { "ExecuteTime": { - "start_time": "2022-11-23T04:00:02.202Z" - } + "start_time": "2022-11-23T05:42:46.859Z" + }, + "scrolled": true }, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "49eccaa5eb6942e29ed4943d5b92cf89", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/42 [00:00 1.004). Saving model ...\n", - "INFO:root:epochs: 2, iters: 100 | training loss: 0.20\n", - "INFO:root:Validation loss decreased (1.004 --> 0.115). Saving model ...\n" + "INFO:root:epochs: 1, iters: 100 | training loss: 1.84\n" ] } ], "source": [ - "for config in configs:\n", + "for config in tqdm(configs):\n", " save_path = config.parent\n", "\n", " exp = ForecastExperiment(config_path=config)\n", @@ -331,7 +336,7 @@ " except KeyboardInterrupt:\n", " raise\n", " except Exception as e:\n", - "# raise\n", + " raise\n", " print(e)\n", " pass" ] @@ -339,11 +344,24 @@ { "cell_type": "code", "execution_count": null, - "id": "e7a3c151", + "id": "0472a84b", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.312188Z", - "start_time": "2022-11-23T03:36:04.312180Z" + "start_time": "2022-11-23T05:42:46.875Z" + } + }, + "outputs": [], + "source": [ + "%debug" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03a28a90", + "metadata": { + "ExecuteTime": { + "start_time": "2022-11-23T05:42:46.875Z" } }, "outputs": [], @@ -354,11 +372,10 @@ { "cell_type": "code", "execution_count": null, - "id": "8b73c64e", + "id": "a31ae85a", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.313334Z", - "start_time": "2022-11-23T03:36:04.313322Z" + "start_time": "2022-11-23T05:42:46.875Z" } }, "outputs": [], @@ -369,11 +386,10 @@ { "cell_type": "code", "execution_count": null, - "id": "6af1bb1e", + "id": "1a9d823d", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.313962Z", - "start_time": "2022-11-23T03:36:04.313954Z" + "start_time": "2022-11-23T05:42:46.875Z" } }, "outputs": [], @@ -388,11 +404,10 @@ { "cell_type": "code", "execution_count": null, - "id": "50a92a7e", + "id": "5e4983e7", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.314594Z", - "start_time": "2022-11-23T03:36:04.314588Z" + "start_time": "2022-11-23T05:42:46.875Z" } }, "outputs": [], @@ -405,11 +420,10 @@ { "cell_type": "code", "execution_count": null, - "id": "9ef76b76", + "id": "cdd0fbbf", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.315217Z", - "start_time": "2022-11-23T03:36:04.315210Z" + "start_time": "2022-11-23T05:42:46.875Z" } }, "outputs": [], @@ -433,7 +447,7 @@ }, { "cell_type": "markdown", - "id": "3637e87d", + "id": "321106ef", "metadata": {}, "source": [ "# view all" @@ -445,8 +459,7 @@ "id": "768530be", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.315891Z", - "start_time": "2022-11-23T03:36:04.315884Z" + "start_time": "2022-11-23T05:42:46.875Z" } }, "outputs": [], @@ -463,7 +476,7 @@ " model_name = gin.query_parameter(\"instance.model_type\")\n", "\n", " train_set, train_loader = get_data(flag='test', batch_size=3)\n", - " seq_len = train_set[0][1].shape\n", + " seq_len = train_set[0][1].shape[0]\n", " model = get_model(model_name,\n", " dim_size=train_set.data_x.shape[1],\n", " seq_len=seq_len,\n", @@ -522,8 +535,7 @@ "id": "739ee5e3", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.316469Z", - "start_time": "2022-11-23T03:36:04.316463Z" + "start_time": "2022-11-23T05:42:46.875Z" } }, "outputs": [], @@ -536,11 +548,10 @@ { "cell_type": "code", "execution_count": null, - "id": "48e9175b", + "id": "f1ae652d", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.316979Z", - "start_time": "2022-11-23T03:36:04.316973Z" + "start_time": "2022-11-23T05:42:46.875Z" } }, "outputs": [], @@ -556,11 +567,10 @@ { "cell_type": "code", "execution_count": null, - "id": "4b966ef7", + "id": "79bda4bf", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.317585Z", - "start_time": "2022-11-23T03:36:04.317578Z" + "start_time": "2022-11-23T05:42:46.875Z" } }, "outputs": [], @@ -572,11 +582,10 @@ { "cell_type": "code", "execution_count": null, - "id": "0b696e85", + "id": "b06e53a9", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.318277Z", - "start_time": "2022-11-23T03:36:04.318270Z" + "start_time": "2022-11-23T05:42:46.892Z" } }, "outputs": [], @@ -599,11 +608,10 @@ { "cell_type": "code", "execution_count": null, - "id": "ac9e5759", + "id": "25b9b9c5", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.319520Z", - "start_time": "2022-11-23T03:36:04.319512Z" + "start_time": "2022-11-23T05:42:46.892Z" } }, "outputs": [], @@ -620,11 +628,10 @@ { "cell_type": "code", "execution_count": null, - "id": "b2e27d42", + "id": "e173ea6f", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.320031Z", - "start_time": "2022-11-23T03:36:04.320024Z" + "start_time": "2022-11-23T05:42:46.892Z" } }, "outputs": [], @@ -635,7 +642,7 @@ { "cell_type": "code", "execution_count": null, - "id": "938f06db", + "id": "9ca6f142", "metadata": { "ExecuteTime": { "end_time": "2022-11-22T13:17:31.585029Z", @@ -647,7 +654,7 @@ }, { "cell_type": "markdown", - "id": "9b4e817c", + "id": "fce6db58", "metadata": {}, "source": [ "# check positions in dl" @@ -656,11 +663,10 @@ { "cell_type": "code", "execution_count": null, - "id": "df6c9b28", + "id": "5d574727", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.320752Z", - "start_time": "2022-11-23T03:36:04.320745Z" + "start_time": "2022-11-23T05:42:46.892Z" }, "scrolled": true }, @@ -675,11 +681,10 @@ { "cell_type": "code", "execution_count": null, - "id": "fa54042e", + "id": "74ebc120", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.321671Z", - "start_time": "2022-11-23T03:36:04.321662Z" + "start_time": "2022-11-23T05:42:46.892Z" } }, "outputs": [], @@ -691,11 +696,10 @@ { "cell_type": "code", "execution_count": null, - "id": "2482f0f7", + "id": "d289f6bd", "metadata": { "ExecuteTime": { - "end_time": "2022-11-23T03:36:04.322543Z", - "start_time": "2022-11-23T03:36:04.322535Z" + "start_time": "2022-11-23T05:42:46.892Z" } }, "outputs": [], @@ -710,7 +714,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e8fe355e", + "id": "879d9d45", "metadata": {}, "outputs": [], "source": [] @@ -718,7 +722,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3a978978", + "id": "37d7641b", "metadata": {}, "outputs": [], "source": [] @@ -726,7 +730,7 @@ { "cell_type": "code", "execution_count": null, - "id": "045d3fd5", + "id": "061b3b77", "metadata": {}, "outputs": [], "source": [] diff --git a/utils/serialize.py b/utils/serialize.py new file mode 100644 index 0000000..4d00e89 --- /dev/null +++ b/utils/serialize.py @@ -0,0 +1,13 @@ +from json_tricks import dump, dumps, load, loads, strip_comments + +def torch_encode(obj, primitives=False): + from torch import Tensor + if isinstance(obj, Tensor): + if primitives: + return obj.numpy().tolist() + raise NotImplemented() + return obj + +def serialize(o): + s = dumps(o, extra_obj_encoders=[torch_encode], primitives=True) + return loads(s)