mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 18:06:19 +08:00
fixes
This commit is contained in:
@@ -5,8 +5,24 @@
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:08:26.228737Z",
|
||||
"start_time": "2022-12-23T07:08:24.806396Z"
|
||||
"end_time": "2022-12-23T09:57:39.738715Z",
|
||||
"start_time": "2022-12-23T09:57:39.730077Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# autoreload import your package\n",
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T09:57:40.738250Z",
|
||||
"start_time": "2022-12-23T09:57:39.739677Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -22,23 +38,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:08:26.417012Z",
|
||||
"start_time": "2022-12-23T07:08:26.230461Z"
|
||||
"end_time": "2022-12-23T09:57:40.909476Z",
|
||||
"start_time": "2022-12-23T09:57:40.739741Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/wassname/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/json.py:101: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n",
|
||||
"from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n",
|
||||
@@ -48,11 +55,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:08:26.469299Z",
|
||||
"start_time": "2022-12-23T07:08:26.419051Z"
|
||||
"end_time": "2022-12-23T09:57:40.966206Z",
|
||||
"start_time": "2022-12-23T09:57:40.910477Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -65,11 +72,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:08:26.477478Z",
|
||||
"start_time": "2022-12-23T07:08:26.470628Z"
|
||||
"end_time": "2022-12-23T09:57:40.980883Z",
|
||||
"start_time": "2022-12-23T09:57:40.967157Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -79,11 +86,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:08:26.494456Z",
|
||||
"start_time": "2022-12-23T07:08:26.478496Z"
|
||||
"end_time": "2022-12-23T09:57:41.003468Z",
|
||||
"start_time": "2022-12-23T09:57:40.981830Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -147,11 +154,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:08:26.516023Z",
|
||||
"start_time": "2022-12-23T07:08:26.495676Z"
|
||||
"end_time": "2022-12-23T09:57:41.025224Z",
|
||||
"start_time": "2022-12-23T09:57:41.004330Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -169,11 +176,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:08:26.532361Z",
|
||||
"start_time": "2022-12-23T07:08:26.517210Z"
|
||||
"end_time": "2022-12-23T09:57:41.041326Z",
|
||||
"start_time": "2022-12-23T09:57:41.026069Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
@@ -185,11 +192,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:08:26.554669Z",
|
||||
"start_time": "2022-12-23T07:08:26.534354Z"
|
||||
"end_time": "2022-12-23T09:57:41.057694Z",
|
||||
"start_time": "2022-12-23T09:57:41.042892Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -199,7 +206,7 @@
|
||||
"MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='370')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -210,67 +217,51 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:12:03.693420Z",
|
||||
"start_time": "2022-12-23T07:12:02.845649Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"7"
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:12:11.716051Z",
|
||||
"start_time": "2022-12-23T07:12:11.712672Z"
|
||||
"end_time": "2022-12-23T09:57:41.073075Z",
|
||||
"start_time": "2022-12-23T09:57:41.058526Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_grouper = MultivariateGrouper(max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))\n",
|
||||
"\n",
|
||||
"test_grouper = MultivariateGrouper(\n",
|
||||
"# num_test_dates=int(len(dataset.test)/len(dataset.train)*2),\n",
|
||||
" num_test_dates=7,\n",
|
||||
" max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))"
|
||||
"train_grouper = MultivariateGrouper(max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-12-23T07:12:12.382Z"
|
||||
"end_time": "2022-12-23T09:57:44.134871Z",
|
||||
"start_time": "2022-12-23T09:57:41.073880Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset_train = train_grouper(dataset.train)\n",
|
||||
"dataset_test = test_grouper(dataset.test)"
|
||||
"dataset_train = train_grouper(dataset.train)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:09:49.342107Z",
|
||||
"start_time": "2022-12-23T07:09:49.342097Z"
|
||||
"end_time": "2022-12-23T09:57:44.153700Z",
|
||||
"start_time": "2022-12-23T09:57:44.136099Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -300,22 +291,265 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:09:49.342839Z",
|
||||
"start_time": "2022-12-23T07:09:49.342830Z"
|
||||
"start_time": "2022-12-23T09:57:39.746Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "744b148440674ef3902ca56622cb001b",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "88abd67eb98d4569b276f98dadd73205",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "f5e08c9306c647cd8589f29d8652177b",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "10a4ac2006db4f6090dc3f08cc3e3e9e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "6458a98361854786bc4fa81da37471d2",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "676d661e21ec42acbb544a8848264674",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "36010cebcd884d248de158890d2ba324",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "f7487fae1ae54dae8c71d23ea5721e7f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "6a995bf552bb487cbca2a0f3e45098b3",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "3ee8b7cdd0b94465a78c43d0917910ae",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "958283558c6a4e958c5c67188e735ed9",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/99 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"predictor = estimator.train(dataset_train, num_workers=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:09:49.343440Z",
|
||||
"start_time": "2022-12-23T07:09:49.343431Z"
|
||||
"start_time": "2022-12-23T09:57:39.748Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"test_grouper = MultivariateGrouper(\n",
|
||||
" num_test_dates=int(len(dataset.test)/len(dataset.train)*2),\n",
|
||||
"# num_test_dates=7,\n",
|
||||
" max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-12-23T09:57:39.749Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset_test = test_grouper(dataset.test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-12-23T09:57:39.750Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"next(iter(dataset.test.iter_sequential()))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-12-23T09:57:39.751Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"len(dataset.test)\n",
|
||||
"x = [x['target'].shape for x in dataset.test.iter_sequential()]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-12-23T09:57:39.752Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pd.Series(x).value_counts()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-12-23T09:57:39.753Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%debug"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2022-12-23T09:57:39.754Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -330,8 +564,7 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:09:49.344138Z",
|
||||
"start_time": "2022-12-23T07:09:49.344129Z"
|
||||
"start_time": "2022-12-23T09:57:39.756Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -345,8 +578,7 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:09:49.344864Z",
|
||||
"start_time": "2022-12-23T07:09:49.344855Z"
|
||||
"start_time": "2022-12-23T09:57:39.756Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -364,8 +596,7 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:08:38.402632Z",
|
||||
"start_time": "2022-12-23T07:08:38.402624Z"
|
||||
"start_time": "2022-12-23T09:57:39.757Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -379,8 +610,7 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:08:38.403367Z",
|
||||
"start_time": "2022-12-23T07:08:38.403359Z"
|
||||
"start_time": "2022-12-23T09:57:39.758Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
@@ -394,8 +624,7 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-12-23T07:08:38.404060Z",
|
||||
"start_time": "2022-12-23T07:08:38.404051Z"
|
||||
"start_time": "2022-12-23T09:57:39.759Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -419,9 +648,9 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "gluonts10.0",
|
||||
"display_name": "glounts",
|
||||
"language": "python",
|
||||
"name": "gluonts10.0"
|
||||
"name": "glounts"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -251,7 +251,7 @@ class TimeGradEstimator(PyTorchEstimator):
|
||||
input_names=input_names,
|
||||
prediction_net=prediction_network,
|
||||
batch_size=self.trainer.batch_size,
|
||||
freq=self.freq,
|
||||
# freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@@ -124,7 +124,7 @@ class EpsilonTheta2(nn.Module):
|
||||
# nn.init.zeros_(self.output_projection.weight)
|
||||
|
||||
self.output_projection = nn.Sequential(
|
||||
nn.Conv1d(residual_channels, residual_channels, 3),
|
||||
nn.Conv1d(residual_channels, target_dim, 3),
|
||||
# nn.LeakyReLU(),
|
||||
# nn.Conv1d(residual_channels, target_dim, 3, padding="same"),
|
||||
)
|
||||
|
||||
@@ -44,13 +44,18 @@ def get_ou(shape, mu=0, theta=0.1, sigma=.1):
|
||||
return OrnsteinUhlenbeckProcess(shape, mu=mu, theta=theta, sigma=sigma)
|
||||
|
||||
def mk_noise(shape, device):
|
||||
repeats = shape[1]
|
||||
shape = (shape[0], shape[2])
|
||||
ou = get_ou(shape)
|
||||
b = ou.sample()
|
||||
b = [torch.from_numpy(bb).to(device).float() for bb in b]
|
||||
C, noise_rand, noise_ou = b
|
||||
noise_rand = noise_rand[:, None]
|
||||
noise_ou = noise_ou[:, None]
|
||||
ns = []
|
||||
for _ in range(repeats):
|
||||
ou = get_ou(shape)
|
||||
b = ou.sample()
|
||||
b = [torch.from_numpy(bb).to(device).float() for bb in b]
|
||||
C, noise_rand, noise_ou = b
|
||||
noise_rand = noise_rand[:, None]
|
||||
noise_ou = noise_ou[:, None]
|
||||
ns.append((C, noise_rand, noise_ou))
|
||||
C, noise_rand, noise_ou = [torch.concat(n, 1) for n in zip(*ns)]
|
||||
return C, noise_rand, noise_ou
|
||||
|
||||
def noise_like(x):
|
||||
@@ -222,7 +227,7 @@ class GaussianDiffusionOU(nn.Module):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device)
|
||||
C, noise_rand, img = mk_noise(shape, device=device)
|
||||
|
||||
for i in reversed(range(0, self.num_timesteps)):
|
||||
img = self.p_sample(
|
||||
@@ -236,10 +241,10 @@ class GaussianDiffusionOU(nn.Module):
|
||||
shape = cond.shape[:-1] + (self.input_size,)
|
||||
# TODO reshape cond to (B*T, 1, -1)
|
||||
B, T, pred_len = cond.shape
|
||||
shape = (B, 1, pred_len)
|
||||
shape = (B, self.input_size, pred_len)
|
||||
else:
|
||||
shape = sample_shape
|
||||
x_hat = self.p_sample_loop(shape, cond) # TODO reshape x_hat to (B,T,-1)
|
||||
x_hat = self.p_sample_loop(shape, cond)
|
||||
|
||||
if self.scale is not None:
|
||||
x_hat *= self.scale
|
||||
@@ -263,8 +268,8 @@ class GaussianDiffusionOU(nn.Module):
|
||||
|
||||
return img
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: noise_like(x_start)[2])
|
||||
def q_sample(self, x_start, t, noise):
|
||||
# noise = default(noise, lambda: noise_like(x_start)[2])
|
||||
|
||||
return (
|
||||
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
|
||||
@@ -255,19 +255,19 @@ class TimeGradTrainingNetwork2(nn.Module):
|
||||
past_observed_values, 1 - past_is_pad.unsqueeze(-1)
|
||||
)
|
||||
|
||||
if future_time_feat is None or future_target_cdf is None:
|
||||
time_feat = past_time_feat[:, -self.context_length :, ...]
|
||||
sequence = past_target_cdf
|
||||
sequence_length = self.history_length
|
||||
subsequences_length = self.context_length
|
||||
else:
|
||||
time_feat = torch.cat(
|
||||
(past_time_feat[:, -self.context_length :, ...], future_time_feat),
|
||||
dim=1,
|
||||
)
|
||||
sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)
|
||||
sequence_length = self.history_length + self.prediction_length
|
||||
subsequences_length = self.context_length + self.prediction_length
|
||||
# if future_time_feat is None or future_target_cdf is None:
|
||||
time_feat = past_time_feat[:, -self.context_length:, ...]
|
||||
sequence = past_target_cdf
|
||||
sequence_length = self.history_length
|
||||
subsequences_length = self.context_length
|
||||
# else:
|
||||
# time_feat = torch.cat(
|
||||
# (past_time_feat[:, -self.context_length :, ...], future_time_feat),
|
||||
# dim=1,
|
||||
# )
|
||||
# sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)
|
||||
# sequence_length = self.history_length + self.prediction_length
|
||||
# subsequences_length = self.context_length + self.prediction_length
|
||||
|
||||
# (batch_size, sub_seq_len, target_dim, num_lags)
|
||||
lags = self.get_lagged_subsequences(
|
||||
@@ -389,16 +389,17 @@ class TimeGradTrainingNetwork2(nn.Module):
|
||||
|
||||
# put together target sequence
|
||||
# (batch_size, seq_len, target_dim)
|
||||
target = torch.cat(
|
||||
(past_target_cdf[:, -self.context_length :, ...], future_target_cdf),
|
||||
dim=1,
|
||||
)
|
||||
# target = torch.cat(
|
||||
# (past_target_cdf[:, -self.context_length :, ...], future_target_cdf),
|
||||
# dim=1,
|
||||
# )
|
||||
target = future_target_cdf
|
||||
|
||||
# assert_shape(target, (-1, seq_len, self.target_dim))
|
||||
|
||||
distr_args = self.distr_args(rnn_outputs=rnn_outputs)
|
||||
if self.scaling:
|
||||
self.diffusion.scale = scale
|
||||
self.diffusion.scale = scale.permute(0, 2, 1)
|
||||
|
||||
# we sum the last axis to have the same shape for all likelihoods
|
||||
# (batch_size, subseq_length, 1)
|
||||
@@ -448,11 +449,8 @@ class TimeGradPredictionNetwork2(TimeGradTrainingNetwork2):
|
||||
|
||||
def sampling_decoder(
|
||||
self,
|
||||
past_target_cdf: torch.Tensor,
|
||||
target_dimension_indicator: torch.Tensor,
|
||||
time_feat: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
begin_states: Union[List[torch.Tensor], torch.Tensor],
|
||||
rnn_outputs: Union[List[torch.Tensor], torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes sample paths by unrolling the RNN starting with a initial
|
||||
@@ -470,8 +468,8 @@ class TimeGradPredictionNetwork2(TimeGradTrainingNetwork2):
|
||||
num_features)
|
||||
scale
|
||||
Mean scale for each time series (batch_size, 1, target_dim)
|
||||
begin_states
|
||||
List of initial states for the RNN layers (batch_size, num_cells)
|
||||
rnn_outputs
|
||||
Outputs of the unrolled RNN (batch_size, seq_len, num_cells)
|
||||
Returns
|
||||
--------
|
||||
sample_paths : Tensor
|
||||
@@ -482,56 +480,35 @@ class TimeGradPredictionNetwork2(TimeGradTrainingNetwork2):
|
||||
def repeat(tensor, dim=0):
|
||||
return tensor.repeat_interleave(repeats=self.num_parallel_samples, dim=dim)
|
||||
|
||||
# blows-up the dimension of each tensor to
|
||||
# batch_size * self.num_sample_paths for increasing parallelism
|
||||
repeated_past_target_cdf = repeat(past_target_cdf)
|
||||
repeated_time_feat = repeat(time_feat)
|
||||
repeated_scale = repeat(scale)
|
||||
if self.scaling:
|
||||
self.diffusion.scale = repeated_scale
|
||||
repeated_target_dimension_indicator = repeat(target_dimension_indicator)
|
||||
|
||||
if self.cell_type == "LSTM":
|
||||
repeated_states = [repeat(s, dim=1) for s in begin_states]
|
||||
else:
|
||||
repeated_states = repeat(begin_states, dim=1)
|
||||
|
||||
|
||||
future_samples = []
|
||||
|
||||
# for each future time-units we draw new samples for this time-unit
|
||||
# and update the state
|
||||
for k in range(self.prediction_length):
|
||||
lags = self.get_lagged_subsequences(
|
||||
sequence=repeated_past_target_cdf,
|
||||
sequence_length=self.history_length + k,
|
||||
indices=self.shifted_lags,
|
||||
subsequences_length=1,
|
||||
)
|
||||
|
||||
rnn_outputs, repeated_states, _, _ = self.unroll(
|
||||
begin_state=repeated_states,
|
||||
lags=lags,
|
||||
scale=repeated_scale,
|
||||
time_feat=repeated_time_feat[:, k : k + 1, ...],
|
||||
target_dimension_indicator=repeated_target_dimension_indicator,
|
||||
unroll_length=1,
|
||||
)
|
||||
|
||||
for _ in range(self.num_parallel_samples):
|
||||
distr_args = self.distr_args(rnn_outputs=rnn_outputs)
|
||||
|
||||
# (batch_size, 1, target_dim)
|
||||
new_samples = self.diffusion.sample(cond=distr_args)
|
||||
distr_args = distr_args.permute(0, 2, 1)
|
||||
samples = self.diffusion.sample(cond=distr_args)
|
||||
future_samples.append(samples)
|
||||
# import pdb; pdb.set_trace()
|
||||
# rnn_outputs = torch.Size([7, 24, 40])
|
||||
# torch.Size([7, 100, 370, 24])
|
||||
# torch.Size([7, 370, 24])
|
||||
samples = torch.stack(future_samples, dim=1).permute(0, 1, 3, 2)
|
||||
|
||||
# (batch_size, seq_len, target_dim)
|
||||
future_samples.append(new_samples)
|
||||
repeated_past_target_cdf = torch.cat(
|
||||
(repeated_past_target_cdf, new_samples), dim=1
|
||||
)
|
||||
# # (batch_size, seq_len, target_dim)
|
||||
# future_samples.append(new_samples)
|
||||
# repeated_past_target_cdf = torch.cat(
|
||||
# (repeated_past_target_cdf, new_samples), dim=1
|
||||
# )
|
||||
|
||||
# (batch_size * num_samples, prediction_length, target_dim)
|
||||
samples = torch.cat(future_samples, dim=1)
|
||||
# samples = torch.cat(future_samples, dim=1)
|
||||
|
||||
# (batch_size, num_samples, prediction_length, target_dim)
|
||||
# print('samples.shape', samples.shape)
|
||||
return samples
|
||||
return samples.reshape(
|
||||
(
|
||||
-1,
|
||||
@@ -586,21 +563,20 @@ class TimeGradPredictionNetwork2(TimeGradTrainingNetwork2):
|
||||
past_observed_values, 1 - past_is_pad.unsqueeze(-1)
|
||||
)
|
||||
|
||||
# unroll the decoder in "prediction mode", i.e. with past data only
|
||||
_, begin_states, scale, _, _ = self.unroll_encoder(
|
||||
|
||||
# unroll the decoder in "training mode", i.e. by providing future data
|
||||
# as well
|
||||
rnn_outputs, _, scale, _, _ = self.unroll_encoder(
|
||||
past_time_feat=past_time_feat,
|
||||
past_target_cdf=past_target_cdf,
|
||||
past_observed_values=past_observed_values,
|
||||
past_is_pad=past_is_pad,
|
||||
future_time_feat=None,
|
||||
future_time_feat=future_time_feat,
|
||||
future_target_cdf=None,
|
||||
target_dimension_indicator=target_dimension_indicator,
|
||||
)
|
||||
|
||||
return self.sampling_decoder(
|
||||
past_target_cdf=past_target_cdf,
|
||||
target_dimension_indicator=target_dimension_indicator,
|
||||
time_feat=future_time_feat,
|
||||
scale=scale,
|
||||
begin_states=begin_states,
|
||||
rnn_outputs=rnn_outputs,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user