starting on timegrad 2

This commit is contained in:
wassname
2022-12-23 12:18:12 +08:00
parent a4a3e953de
commit f31af603e4
5 changed files with 1449 additions and 0 deletions
+449
View File
@@ -0,0 +1,449 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n",
"from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n",
"from gluonts.evaluation.backtest import make_evaluation_predictions\n",
"from gluonts.evaluation import MultivariateEvaluator"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from pts.model.tempflow import TempFlowEstimator\n",
"from pts.model.time_grad import TimeGradEstimator\n",
"from pts.model.transformer_tempflow import TransformerTempFlowEstimator\n",
"from pts import Trainer"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def plot(target, forecast, prediction_length, prediction_intervals=(50.0, 90.0), color='g', fname=None):\n",
" label_prefix = \"\"\n",
" rows = 4\n",
" cols = 4\n",
" fig, axs = plt.subplots(rows, cols, figsize=(24, 24))\n",
" axx = axs.ravel()\n",
" seq_len, target_dim = target.shape\n",
" \n",
" ps = [50.0] + [\n",
" 50.0 + f * c / 2.0 for c in prediction_intervals for f in [-1.0, +1.0]\n",
" ]\n",
" \n",
" percentiles_sorted = sorted(set(ps))\n",
" \n",
" def alpha_for_percentile(p):\n",
" return (p / 100.0) ** 0.3\n",
" \n",
" for dim in range(0, min(rows * cols, target_dim)):\n",
" ax = axx[dim]\n",
"\n",
" target[-2 * prediction_length :][dim].plot(ax=ax)\n",
" \n",
" ps_data = [forecast.quantile(p / 100.0)[:,dim] for p in percentiles_sorted]\n",
" i_p50 = len(percentiles_sorted) // 2\n",
" \n",
" p50_data = ps_data[i_p50]\n",
" p50_series = pd.Series(data=p50_data, index=forecast.index)\n",
" p50_series.plot(color=color, ls=\"-\", label=f\"{label_prefix}median\", ax=ax)\n",
" \n",
" for i in range(len(percentiles_sorted) // 2):\n",
" ptile = percentiles_sorted[i]\n",
" alpha = alpha_for_percentile(ptile)\n",
" ax.fill_between(\n",
" forecast.index,\n",
" ps_data[i],\n",
" ps_data[-i - 1],\n",
" facecolor=color,\n",
" alpha=alpha,\n",
" interpolate=True,\n",
" )\n",
" # Hack to create labels for the error intervals.\n",
" # Doesn't actually plot anything, because we only pass a single data point\n",
" pd.Series(data=p50_data[:1], index=forecast.index[:1]).plot(\n",
" color=color,\n",
" alpha=alpha,\n",
" linewidth=10,\n",
" label=f\"{label_prefix}{100 - ptile * 2}%\",\n",
" ax=ax,\n",
" )\n",
"\n",
" legend = [\"observations\", \"median prediction\"] + [f\"{k}% prediction interval\" for k in prediction_intervals][::-1] \n",
" axx[0].legend(legend, loc=\"upper left\")\n",
" \n",
" if fname is not None:\n",
" plt.savefig(fname, bbox_inches='tight', pad_inches=0.05)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Available datasets: ['constant', 'exchange_rate', 'solar-energy', 'electricity', 'traffic', 'exchange_rate_nips', 'electricity_nips', 'traffic_nips', 'solar_nips', 'wiki-rolling_nips', 'taxi_30min', 'kaggle_web_traffic_with_missing', 'kaggle_web_traffic_without_missing', 'kaggle_web_traffic_weekly', 'm1_yearly', 'm1_quarterly', 'm1_monthly', 'nn5_daily_with_missing', 'nn5_daily_without_missing', 'nn5_weekly', 'tourism_monthly', 'tourism_quarterly', 'tourism_yearly', 'cif_2016', 'london_smart_meters_without_missing', 'wind_farms_without_missing', 'car_parts_without_missing', 'dominick', 'fred_md', 'pedestrian_counts', 'hospital', 'covid_deaths', 'kdd_cup_2018_without_missing', 'weather', 'm3_monthly', 'm3_quarterly', 'm3_yearly', 'm3_other', 'm4_hourly', 'm4_daily', 'm4_weekly', 'm4_monthly', 'm4_quarterly', 'm4_yearly', 'm5', 'uber_tlc_daily', 'uber_tlc_hourly', 'airpassengers']\n"
]
}
],
"source": [
"print(f\"Available datasets: {list(dataset_recipes.keys())}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# exchange_rate_nips, electricity_nips, traffic_nips, solar_nips, wiki-rolling_nips, ## taxi_30min is buggy still\n",
"dataset = get_dataset(\"electricity_nips\", regenerate=False)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='370')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset.metadata"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"train_grouper = MultivariateGrouper(max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))\n",
"\n",
"test_grouper = MultivariateGrouper(num_test_dates=int(len(dataset.test)/len(dataset.train)), \n",
" max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/wassname/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:191: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
" return {FieldName.TARGET: np.array([funcs(data) for data in dataset])}\n"
]
},
{
"ename": "ValueError",
"evalue": "array split does not result in an equal division",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Time-Grad2-Electricity.ipynb Cell 10\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Time-Grad2-Electricity.ipynb#X12sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m dataset_train \u001b[39m=\u001b[39m train_grouper(dataset\u001b[39m.\u001b[39mtrain)\n\u001b[0;32m----> <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Time-Grad2-Electricity.ipynb#X12sZmlsZQ%3D%3D?line=1'>2</a>\u001b[0m dataset_test \u001b[39m=\u001b[39m test_grouper(dataset\u001b[39m.\u001b[39;49mtest)\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:87\u001b[0m, in \u001b[0;36mMultivariateGrouper.__call__\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, dataset: Dataset) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dataset:\n\u001b[1;32m 86\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_preprocess(dataset)\n\u001b[0;32m---> 87\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_group_all(dataset)\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:125\u001b[0m, in \u001b[0;36mMultivariateGrouper._group_all\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 123\u001b[0m grouped_dataset \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_prepare_train_data(dataset)\n\u001b[1;32m 124\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 125\u001b[0m grouped_dataset \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_prepare_test_data(dataset)\n\u001b[1;32m 126\u001b[0m \u001b[39mreturn\u001b[39;00m grouped_dataset\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:148\u001b[0m, in \u001b[0;36mMultivariateGrouper._prepare_test_data\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 145\u001b[0m grouped_data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_transform_target(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_left_pad_data, dataset)\n\u001b[1;32m 146\u001b[0m \u001b[39m# splits test dataset with rolling date into N R^d time series where\u001b[39;00m\n\u001b[1;32m 147\u001b[0m \u001b[39m# N is the number of rolling evaluation dates\u001b[39;00m\n\u001b[0;32m--> 148\u001b[0m split_dataset \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39;49msplit(\n\u001b[1;32m 149\u001b[0m grouped_data[FieldName\u001b[39m.\u001b[39;49mTARGET], \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mnum_test_dates\n\u001b[1;32m 150\u001b[0m )\n\u001b[1;32m 152\u001b[0m all_entries \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m()\n\u001b[1;32m 153\u001b[0m \u001b[39mfor\u001b[39;00m dataset_at_test_date \u001b[39min\u001b[39;00m split_dataset:\n",
"File \u001b[0;32m<__array_function__ internals>:180\u001b[0m, in \u001b[0;36msplit\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/numpy/lib/shape_base.py:872\u001b[0m, in \u001b[0;36msplit\u001b[0;34m(ary, indices_or_sections, axis)\u001b[0m\n\u001b[1;32m 870\u001b[0m N \u001b[39m=\u001b[39m ary\u001b[39m.\u001b[39mshape[axis]\n\u001b[1;32m 871\u001b[0m \u001b[39mif\u001b[39;00m N \u001b[39m%\u001b[39m sections:\n\u001b[0;32m--> 872\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 873\u001b[0m \u001b[39m'\u001b[39m\u001b[39marray split does not result in an equal division\u001b[39m\u001b[39m'\u001b[39m) \u001b[39mfrom\u001b[39;00m \u001b[39mNone\u001b[39m\n\u001b[1;32m 874\u001b[0m \u001b[39mreturn\u001b[39;00m array_split(ary, indices_or_sections, axis)\n",
"\u001b[0;31mValueError\u001b[0m: array split does not result in an equal division"
]
}
],
"source": [
"dataset_train = train_grouper(dataset.train)\n",
"dataset_test = test_grouper(dataset.test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mRunning cells with 'Python 3.10.6 64-bit' requires ipykernel package.\n",
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
"\u001b[1;31mCommand: '/bin/python3 -m pip install ipykernel -U --user --force-reinstall'"
]
}
],
"source": [
"estimator = TimeGradEstimator(\n",
" target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n",
" prediction_length=dataset.metadata.prediction_length,\n",
" context_length=dataset.metadata.prediction_length,\n",
" cell_type='GRU',\n",
" input_size=1484,\n",
" freq=dataset.metadata.freq,\n",
" loss_type='l2',\n",
" scaling=True,\n",
" diff_steps=100,\n",
" beta_end=0.1,\n",
" beta_schedule=\"linear\",\n",
" trainer=Trainer(device=device,\n",
" epochs=20,\n",
" learning_rate=1e-3,\n",
" num_batches_per_epoch=100,\n",
" batch_size=64,)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mRunning cells with 'Python 3.10.6 64-bit' requires ipykernel package.\n",
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
"\u001b[1;31mCommand: '/bin/python3 -m pip install ipykernel -U --user --force-reinstall'"
]
}
],
"source": [
"predictor = estimator.train(dataset_train, num_workers=8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mRunning cells with 'Python 3.10.6 64-bit' requires ipykernel package.\n",
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
"\u001b[1;31mCommand: '/bin/python3 -m pip install ipykernel -U --user --force-reinstall'"
]
}
],
"source": [
"forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test,\n",
" predictor=predictor,\n",
" num_samples=100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mRunning cells with 'Python 3.10.6 64-bit' requires ipykernel package.\n",
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
"\u001b[1;31mCommand: '/bin/python3 -m pip install ipykernel -U --user --force-reinstall'"
]
}
],
"source": [
"forecasts = list(forecast_it)\n",
"targets = list(ts_it)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mRunning cells with 'Python 3.10.6 64-bit' requires ipykernel package.\n",
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
"\u001b[1;31mCommand: '/bin/python3 -m pip install ipykernel -U --user --force-reinstall'"
]
}
],
"source": [
"plot(\n",
" target=targets[0],\n",
" forecast=forecasts[0],\n",
" prediction_length=dataset.metadata.prediction_length,\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mRunning cells with 'Python 3.10.6 64-bit' requires ipykernel package.\n",
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
"\u001b[1;31mCommand: '/bin/python3 -m pip install ipykernel -U --user --force-reinstall'"
]
}
],
"source": [
"evaluator = MultivariateEvaluator(quantiles=(np.arange(20)/20.0)[1:], \n",
" target_agg_funcs={'sum': np.sum})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mRunning cells with 'Python 3.10.6 64-bit' requires ipykernel package.\n",
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
"\u001b[1;31mCommand: '/bin/python3 -m pip install ipykernel -U --user --force-reinstall'"
]
}
],
"source": [
"agg_metric, item_metrics = evaluator(targets, forecasts, num_series=len(dataset_test))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mRunning cells with 'Python 3.10.6 64-bit' requires ipykernel package.\n",
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
"\u001b[1;31mCommand: '/bin/python3 -m pip install ipykernel -U --user --force-reinstall'"
]
}
],
"source": [
"print(\"CRPS:\", agg_metric[\"mean_wQuantileLoss\"])\n",
"print(\"ND:\", agg_metric[\"ND\"])\n",
"print(\"NRMSE:\", agg_metric[\"NRMSE\"])\n",
"print(\"\")\n",
"print(\"CRPS-Sum:\", agg_metric[\"m_sum_mean_wQuantileLoss\"])\n",
"print(\"ND-Sum:\", agg_metric[\"m_sum_ND\"])\n",
"print(\"NRMSE-Sum:\", agg_metric[\"m_sum_NRMSE\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mRunning cells with 'Python 3.10.6 64-bit' requires ipykernel package.\n",
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
"\u001b[1;31mCommand: '/bin/python3 -m pip install ipykernel -U --user --force-reinstall'"
]
}
],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.15 ('glounts')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"vscode": {
"interpreter": {
"hash": "7f25a1f13147a60511cf6766827402baf95cbe50d53a241197155306ee38fe70"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
+3
View File
@@ -0,0 +1,3 @@
from .time_grad_estimator import TimeGradEstimator
from .time_grad_network import TimeGradTrainingNetwork, TimeGradPredictionNetwork
from .epsilon_theta import EpsilonTheta
+136
View File
@@ -0,0 +1,136 @@
import math
import torch
from torch import nn
import torch.nn.functional as F
class DiffusionEmbedding(nn.Module):
    """Turn integer diffusion-step indices into projected embedding vectors.

    A fixed sinusoidal lookup table with one row per step is built at
    construction time and registered as a non-persistent buffer (it is not
    written to the ``state_dict``).  ``forward`` picks the rows for the
    requested steps and feeds them through two linear layers, each followed
    by a SiLU activation.

    Parameters
    ----------
    dim
        Number of frequencies in the table; every row holds ``2 * dim``
        values (sines followed by cosines).
    proj_dim
        Output width of both projection layers.
    max_steps
        Number of rows in the table; valid step indices are
        ``0 .. max_steps - 1``.
    """

    def __init__(self, dim, proj_dim, max_steps=500):
        super().__init__()
        sinusoid_table = self._build_embedding(dim, max_steps)
        self.register_buffer("embedding", sinusoid_table, persistent=False)
        self.projection1 = nn.Linear(dim * 2, proj_dim)
        self.projection2 = nn.Linear(proj_dim, proj_dim)

    def forward(self, diffusion_step):
        """Look up ``diffusion_step`` in the table and project the rows.

        ``diffusion_step`` is used directly as an index into the table, so it
        must be something a tensor can be indexed with (e.g. a long tensor).
        The result has ``proj_dim`` entries along its last axis.
        """
        hidden = F.silu(self.projection1(self.embedding[diffusion_step]))
        return F.silu(self.projection2(hidden))

    def _build_embedding(self, dim, max_steps):
        """Return the ``[max_steps, 2 * dim]`` table of sines and cosines."""
        step_column = torch.arange(max_steps).unsqueeze(1)  # [T, 1]
        freq_row = torch.arange(dim).unsqueeze(0)  # [1, dim]
        # Frequencies grow geometrically from 10**0 towards 10**4.
        angles = step_column * 10.0 ** (freq_row * 4.0 / dim)  # [T, dim]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
class ResidualBlock(nn.Module):
    """Gated, dilated 1-D convolution block with residual and skip outputs.

    Parameters
    ----------
    hidden_size
        Width of the diffusion-step embedding fed to ``forward``.
    residual_channels
        Channel count of the block's input and of both outputs.
    dilation
        Dilation of the 3-tap convolution; the same value is used as its
        circular padding, so the sequence length is preserved.
    """

    def __init__(self, hidden_size, residual_channels, dilation):
        super().__init__()
        doubled_channels = 2 * residual_channels
        self.dilated_conv = nn.Conv1d(
            in_channels=residual_channels,
            out_channels=doubled_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            padding_mode="circular",
        )
        self.diffusion_projection = nn.Linear(hidden_size, residual_channels)
        # Single-channel conditioner; circular padding of 2 lengthens it by 4.
        self.conditioner_projection = nn.Conv1d(
            in_channels=1,
            out_channels=doubled_channels,
            kernel_size=1,
            padding=2,
            padding_mode="circular",
        )
        self.output_projection = nn.Conv1d(
            in_channels=residual_channels,
            out_channels=doubled_channels,
            kernel_size=1,
        )

        for conv in (self.conditioner_projection, self.output_projection):
            nn.init.kaiming_normal_(conv.weight)

    def forward(self, x, conditioner, diffusion_step):
        """Apply the block and return ``(residual_output, skip)``.

        The projected diffusion step is broadcast over the last axis of ``x``
        and added to it; the dilated convolution of that sum plus the
        projected conditioner is split into a sigmoid gate and a tanh filter.
        Their product is projected, passed through a leaky ReLU (slope 0.4)
        and split into a residual part and a skip part.  The first return
        value is ``(x + residual) / sqrt(2)``.
        """
        step_bias = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        cond_features = self.conditioner_projection(conditioner)

        mixed = self.dilated_conv(x + step_bias) + cond_features
        gate_half, filter_half = torch.chunk(mixed, 2, dim=1)
        gated = torch.sigmoid(gate_half) * torch.tanh(filter_half)

        projected = F.leaky_relu(self.output_projection(gated), 0.4)
        residual, skip = torch.chunk(projected, 2, dim=1)
        return (x + residual) / math.sqrt(2.0), skip
class CondUpsampler(nn.Module):
    """Two-layer MLP that maps a conditioning vector to ``target_dim`` values.

    Parameters
    ----------
    cond_length
        Size of the last axis of the conditioning input.
    target_dim
        Size of the last axis of the output.  The hidden layer has
        ``target_dim // 2`` units, but never fewer than one.
    """

    def __init__(self, cond_length, target_dim):
        super().__init__()
        # ``target_dim // 2`` is 0 for a one-dimensional target, which would
        # create zero-width layers and make the output independent of the
        # conditioning input.  Keep at least one hidden unit.
        hidden_dim = max(1, target_dim // 2)
        self.linear1 = nn.Linear(cond_length, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, target_dim)

    def forward(self, x):
        """Return ``x`` passed through both layers, each with leaky ReLU (0.4)."""
        x = self.linear1(x)
        x = F.leaky_relu(x, 0.4)
        x = self.linear2(x)
        x = F.leaky_relu(x, 0.4)
        return x
class EpsilonTheta(nn.Module):
    """Denoising network built from a stack of ``ResidualBlock`` layers.

    Parameters
    ----------
    target_dim
        Output width of the conditioning up-sampler.
    cond_length
        Input width of the conditioning up-sampler.
    time_emb_dim
        ``dim`` argument of the ``DiffusionEmbedding``.
    residual_layers
        Number of residual blocks.
    residual_channels
        Channel count used inside the residual stack.
    dilation_cycle_length
        Block ``i`` uses dilation ``2 ** (i % dilation_cycle_length)``.
    residual_hidden
        Width of the projected diffusion-step embedding.
    """

    def __init__(
        self,
        target_dim,
        cond_length,
        time_emb_dim=16,
        residual_layers=8,
        residual_channels=8,
        dilation_cycle_length=2,
        residual_hidden=64,
    ):
        super().__init__()
        # Circular padding of 2 lengthens the sequence by 4; the two 3-tap
        # convolutions at the end of ``forward`` remove those 4 positions.
        self.input_projection = nn.Conv1d(
            in_channels=1,
            out_channels=residual_channels,
            kernel_size=1,
            padding=2,
            padding_mode="circular",
        )
        self.diffusion_embedding = DiffusionEmbedding(
            time_emb_dim, proj_dim=residual_hidden
        )
        self.cond_upsampler = CondUpsampler(
            target_dim=target_dim, cond_length=cond_length
        )

        self.residual_layers = nn.ModuleList()
        for layer_index in range(residual_layers):
            self.residual_layers.append(
                ResidualBlock(
                    residual_channels=residual_channels,
                    dilation=2 ** (layer_index % dilation_cycle_length),
                    hidden_size=residual_hidden,
                )
            )

        self.skip_projection = nn.Conv1d(residual_channels, residual_channels, 3)
        self.output_projection = nn.Conv1d(residual_channels, 1, 3)

        nn.init.kaiming_normal_(self.input_projection.weight)
        nn.init.kaiming_normal_(self.skip_projection.weight)
        # The final layer's weights start at zero.
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, inputs, time, cond):
        """Run the residual stack and return a single-channel output.

        ``inputs`` goes through the single-channel input projection, ``time``
        is embedded by the diffusion embedding and ``cond`` is up-sampled;
        every residual block receives the same embedded step and up-sampled
        conditioner.  The skip outputs of all blocks are summed, scaled by
        ``1 / sqrt(number of blocks)`` and projected down to one channel.
        """
        hidden = F.leaky_relu(self.input_projection(inputs), 0.4)
        step_embedding = self.diffusion_embedding(time)
        upsampled_cond = self.cond_upsampler(cond)

        skips = []
        for block in self.residual_layers:
            hidden, skip_out = block(hidden, upsampled_cond, step_embedding)
            skips.append(skip_out)

        merged = torch.stack(skips).sum(dim=0) / math.sqrt(len(self.residual_layers))
        merged = F.leaky_relu(self.skip_projection(merged), 0.4)
        return self.output_projection(merged)
+257
View File
@@ -0,0 +1,257 @@
from typing import List, Optional
import torch
from gluonts.dataset.field_names import FieldName
from gluonts.time_feature import TimeFeature
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.util import copy_parameters
from gluonts.model.predictor import Predictor
from gluonts.transform import (
Transformation,
Chain,
InstanceSplitter,
ExpectedNumInstanceSampler,
ValidationSplitSampler,
TestSplitSampler,
RenameFields,
AsNumpyArray,
ExpandDimArray,
AddObservedValuesIndicator,
AddTimeFeatures,
VstackFeatures,
SetFieldIfNotPresent,
TargetDimIndicator,
)
from pts import Trainer
from pts.feature import (
fourier_time_features_from_frequency,
lags_for_fourier_time_features_from_frequency,
)
from pts.model import PyTorchEstimator
from pts.model.utils import get_module_forward_input_names
from .time_grad_network import TimeGradTrainingNetwork, TimeGradPredictionNetwork
class TimeGradEstimator(PyTorchEstimator):
    """Estimator that assembles the TimeGrad training and prediction networks.

    The constructor stores the hyper-parameters, derives the default lags and
    time features from ``freq`` and creates the training / validation
    instance samplers.  The remaining methods build the data transformation,
    the instance splitter, the training network and the predictor from that
    configuration.

    Parameters
    ----------
    input_size
        Forwarded to the networks as ``input_size``.
    freq
        Frequency string; used to derive the default ``lags_seq`` and
        ``time_features`` and handed to the predictor.
    prediction_length
        Number of future steps per instance (``future_length`` of the
        splitter, ``min_future`` of the samplers).
    target_dim
        Forwarded to the networks as ``target_dim``.
    trainer
        Trainer passed to the base estimator.  ``None`` (default) creates a
        fresh ``Trainer()`` for this instance.
    context_length
        Defaults to ``prediction_length`` when ``None``.
    num_layers, num_cells, cell_type, dropout_rate, embedding_dimension,
    conditioning_length, diff_steps, loss_type, beta_end, beta_schedule,
    residual_layers, residual_channels, dilation_cycle_length, scaling
        Stored unchanged and forwarded to both networks.
    num_parallel_samples
        Forwarded to the prediction network only.
    cardinality
        Forwarded to both networks.  ``None`` (default) means ``[1]``.
    pick_incomplete
        When true, the training and validation samplers use ``min_past=0``
        instead of ``history_length``.
    lags_seq
        Lag indices; defaults to
        ``lags_for_fourier_time_features_from_frequency(freq_str=freq)``.
    time_features
        Time features; defaults to
        ``fourier_time_features_from_frequency(freq)``.
    **kwargs
        Passed through to ``PyTorchEstimator.__init__``.
    """

    def __init__(
        self,
        input_size: int,
        freq: str,
        prediction_length: int,
        target_dim: int,
        trainer: Optional[Trainer] = None,
        context_length: Optional[int] = None,
        num_layers: int = 2,
        num_cells: int = 40,
        cell_type: str = "LSTM",
        num_parallel_samples: int = 100,
        dropout_rate: float = 0.1,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: int = 5,
        conditioning_length: int = 100,
        diff_steps: int = 100,
        loss_type: str = "l2",
        beta_end=0.1,
        beta_schedule="linear",
        residual_layers=8,
        residual_channels=8,
        dilation_cycle_length=2,
        scaling: bool = True,
        pick_incomplete: bool = False,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        **kwargs,
    ) -> None:
        # Default arguments are evaluated once, at definition time.  Building
        # the Trainer and the cardinality list here gives every estimator its
        # own objects instead of one instance shared by all of them.
        if trainer is None:
            trainer = Trainer()
        if cardinality is None:
            cardinality = [1]

        super().__init__(trainer=trainer, **kwargs)

        self.freq = freq
        self.context_length = (
            context_length if context_length is not None else prediction_length
        )

        self.input_size = input_size
        self.prediction_length = prediction_length
        self.target_dim = target_dim
        self.num_layers = num_layers
        self.num_cells = num_cells
        self.cell_type = cell_type
        self.num_parallel_samples = num_parallel_samples
        self.dropout_rate = dropout_rate
        self.cardinality = cardinality
        self.embedding_dimension = embedding_dimension

        self.conditioning_length = conditioning_length
        self.diff_steps = diff_steps
        self.loss_type = loss_type
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule
        self.residual_layers = residual_layers
        self.residual_channels = residual_channels
        self.dilation_cycle_length = dilation_cycle_length

        self.lags_seq = (
            lags_seq
            if lags_seq is not None
            else lags_for_fourier_time_features_from_frequency(freq_str=freq)
        )

        self.time_features = (
            time_features
            if time_features is not None
            else fourier_time_features_from_frequency(self.freq)
        )

        # The past window must reach back far enough for the largest lag.
        self.history_length = self.context_length + max(self.lags_seq)
        self.pick_incomplete = pick_incomplete
        self.scaling = scaling

        self.train_sampler = ExpectedNumInstanceSampler(
            num_instances=1.0,
            min_past=0 if pick_incomplete else self.history_length,
            min_future=prediction_length,
        )

        self.validation_sampler = ValidationSplitSampler(
            min_past=0 if pick_incomplete else self.history_length,
            min_future=prediction_length,
        )

    def create_transformation(self) -> Transformation:
        """Return the chain of per-entry transformations applied to the data."""
        return Chain(
            [
                AsNumpyArray(
                    field=FieldName.TARGET,
                    expected_ndim=2,
                ),
                # maps the target to (1, T)
                # if the target data is uni dimensional
                ExpandDimArray(
                    field=FieldName.TARGET,
                    axis=None,
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME],
                ),
                SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]),
                TargetDimIndicator(
                    field_name="target_dimension_indicator",
                    target_field=FieldName.TARGET,
                ),
                AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),
            ]
        )

    def create_instance_splitter(self, mode: str):
        """Return the instance splitter for ``mode``.

        ``mode`` must be ``"training"``, ``"validation"`` or ``"test"`` and
        selects the instance sampler.  The splitter is followed by a rename
        of the past / future target fields to their ``*_cdf`` names.
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.history_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
        ) + (
            RenameFields(
                {
                    f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
                    f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
                }
            )
        )

    def _network_kwargs(self) -> dict:
        """Return the keyword arguments shared by both network constructors.

        Keeping them in one place guarantees that the training and the
        prediction network are built from the same configuration.
        """
        return dict(
            input_size=self.input_size,
            target_dim=self.target_dim,
            num_layers=self.num_layers,
            num_cells=self.num_cells,
            cell_type=self.cell_type,
            history_length=self.history_length,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            dropout_rate=self.dropout_rate,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            diff_steps=self.diff_steps,
            loss_type=self.loss_type,
            beta_end=self.beta_end,
            beta_schedule=self.beta_schedule,
            residual_layers=self.residual_layers,
            residual_channels=self.residual_channels,
            dilation_cycle_length=self.dilation_cycle_length,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
            conditioning_length=self.conditioning_length,
        )

    def create_training_network(self, device: torch.device) -> TimeGradTrainingNetwork:
        """Build the training network from the stored configuration on ``device``."""
        return TimeGradTrainingNetwork(**self._network_kwargs()).to(device)

    def create_predictor(
        self,
        transformation: Transformation,
        trained_network: TimeGradTrainingNetwork,
        device: torch.device,
    ) -> Predictor:
        """Build a predictor that reuses the parameters of ``trained_network``.

        A prediction network with the same configuration (plus
        ``num_parallel_samples``) is created on ``device``, the trained
        parameters are copied into it, and it is wrapped in a
        ``PyTorchPredictor`` whose input transform is ``transformation``
        followed by the ``"test"`` instance splitter.
        """
        prediction_network = TimeGradPredictionNetwork(
            num_parallel_samples=self.num_parallel_samples,
            **self._network_kwargs(),
        ).to(device)

        copy_parameters(trained_network, prediction_network)
        input_names = get_module_forward_input_names(prediction_network)
        prediction_splitter = self.create_instance_splitter("test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=input_names,
            prediction_net=prediction_network,
            batch_size=self.trainer.batch_size,
            freq=self.freq,
            prediction_length=self.prediction_length,
            device=device,
        )
+604
View File
@@ -0,0 +1,604 @@
from torch.nn.modules import loss
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from gluonts.core.component import validated
from pts.model import weighted_average
from pts.modules import GaussianDiffusion, DiffusionOutput, MeanScaler, NOPScaler
from .epsilon_theta import EpsilonTheta
class TimeGradTrainingNetwork(nn.Module):
@validated()
def __init__(
    self,
    input_size: int,
    num_layers: int,
    num_cells: int,
    cell_type: str,
    history_length: int,
    context_length: int,
    prediction_length: int,
    dropout_rate: float,
    lags_seq: List[int],
    target_dim: int,
    conditioning_length: int,
    diff_steps: int,
    loss_type: str,
    beta_end: float,
    beta_schedule: str,
    residual_layers: int,
    residual_channels: int,
    dilation_cycle_length: int,
    cardinality: List[int] = [1],
    embedding_dimension: int = 1,
    scaling: bool = True,
    **kwargs,
) -> None:
    """Build the RNN encoder, the diffusion model and the scaler.

    Parameters
    ----------
    input_size, num_layers, num_cells, dropout_rate
        Passed to the RNN as ``input_size``, ``num_layers``,
        ``hidden_size`` and ``dropout``.
    cell_type
        ``"LSTM"`` or ``"GRU"``; any other value raises ``KeyError``.
    history_length, context_length, prediction_length
        Stored on the instance.
    lags_seq
        Lag indices without duplicates; a sorted copy is stored.
    target_dim
        Number of target dimensions; sizes the denoiser, the diffusion
        model and the index embedding.
    conditioning_length
        ``cond_length`` of the denoiser and ``cond_size`` of the
        diffusion output.
    diff_steps, loss_type, beta_end, beta_schedule
        Passed to ``GaussianDiffusion``.
    residual_layers, residual_channels, dilation_cycle_length
        Passed to ``EpsilonTheta``.
    cardinality, embedding_dimension
        Accepted but not read by this constructor; the index embedding
        width is fixed to 1.
    scaling
        Selects ``MeanScaler`` when true, ``NOPScaler`` otherwise.
    **kwargs
        Passed through to ``nn.Module.__init__``.
    """
    super().__init__(**kwargs)

    self.target_dim = target_dim
    self.prediction_length = prediction_length
    self.context_length = context_length
    self.history_length = history_length
    self.scaling = scaling

    assert len(set(lags_seq)) == len(lags_seq), "no duplicated lags allowed!"
    # Store a sorted copy rather than sorting in place, so the list object
    # handed in by the caller is never modified.
    self.lags_seq = sorted(lags_seq)

    self.cell_type = cell_type
    rnn_cls = {"LSTM": nn.LSTM, "GRU": nn.GRU}[cell_type]
    self.rnn = rnn_cls(
        input_size=input_size,
        hidden_size=num_cells,
        num_layers=num_layers,
        dropout=dropout_rate,
        batch_first=True,
    )

    self.denoise_fn = EpsilonTheta(
        target_dim=target_dim,
        cond_length=conditioning_length,
        residual_layers=residual_layers,
        residual_channels=residual_channels,
        dilation_cycle_length=dilation_cycle_length,
    )

    self.diffusion = GaussianDiffusion(
        self.denoise_fn,
        input_size=target_dim,
        diff_steps=diff_steps,
        loss_type=loss_type,
        beta_end=beta_end,
        beta_schedule=beta_schedule,
    )

    self.distr_output = DiffusionOutput(
        self.diffusion, input_size=target_dim, cond_size=conditioning_length
    )

    self.proj_dist_args = self.distr_output.get_args_proj(num_cells)

    # One learned scalar per target dimension.
    self.embed_dim = 1
    self.embed = nn.Embedding(
        num_embeddings=self.target_dim, embedding_dim=self.embed_dim
    )

    if self.scaling:
        self.scaler = MeanScaler(keepdim=True)
    else:
        self.scaler = NOPScaler(keepdim=True)
@staticmethod
def get_lagged_subsequences(
sequence: torch.Tensor,
sequence_length: int,
indices: List[int],
subsequences_length: int = 1,
) -> torch.Tensor:
"""
Returns lagged subsequences of a given sequence.
Parameters
----------
sequence
the sequence from which lagged subsequences should be extracted.
Shape: (N, T, C).
sequence_length
length of sequence in the T (time) dimension (axis = 1).
indices
list of lag indices to be used.
subsequences_length
length of the subsequences to be extracted.
Returns
--------
lagged : Tensor
a tensor of shape (N, S, C, I),
where S = subsequences_length and I = len(indices),
containing lagged subsequences.
Specifically, lagged[i, :, j, k] = sequence[i, -indices[k]-S+j, :].
"""
# we must have: history_length + begin_index >= 0
# that is: history_length - lag_index - sequence_length >= 0
# hence the following assert
assert max(indices) + subsequences_length <= sequence_length, (
f"lags cannot go further than history length, found lag "
f"{max(indices)} while history length is only {sequence_length}"
)
assert all(lag_index >= 0 for lag_index in indices)
lagged_values = []
for lag_index in indices:
begin_index = -lag_index - subsequences_length
end_index = -lag_index if lag_index > 0 else None
lagged_values.append(sequence[:, begin_index:end_index, ...].unsqueeze(1))
return torch.cat(lagged_values, dim=1).permute(0, 2, 3, 1)
def unroll(
self,
lags: torch.Tensor,
scale: torch.Tensor,
time_feat: torch.Tensor,
target_dimension_indicator: torch.Tensor,
unroll_length: int,
begin_state: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
) -> Tuple[
torch.Tensor,
Union[List[torch.Tensor], torch.Tensor],
torch.Tensor,
torch.Tensor,
]:
# (batch_size, sub_seq_len, target_dim, num_lags)
lags_scaled = lags / scale.unsqueeze(-1)
# assert_shape(
# lags_scaled, (-1, unroll_length, self.target_dim, len(self.lags_seq)),
# )
input_lags = lags_scaled.reshape(
(-1, unroll_length, len(self.lags_seq) * self.target_dim)
)
# (batch_size, target_dim, embed_dim)
index_embeddings = self.embed(target_dimension_indicator)
# assert_shape(index_embeddings, (-1, self.target_dim, self.embed_dim))
# (batch_size, seq_len, target_dim * embed_dim)
repeated_index_embeddings = (
index_embeddings.unsqueeze(1)
.expand(-1, unroll_length, -1, -1)
.reshape((-1, unroll_length, self.target_dim * self.embed_dim))
)
# (batch_size, sub_seq_len, input_dim)
inputs = torch.cat((input_lags, repeated_index_embeddings, time_feat), dim=-1)
# unroll encoder
outputs, state = self.rnn(inputs, begin_state)
# assert_shape(outputs, (-1, unroll_length, self.num_cells))
# for s in state:
# assert_shape(s, (-1, self.num_cells))
# assert_shape(
# lags_scaled, (-1, unroll_length, self.target_dim, len(self.lags_seq)),
# )
return outputs, state, lags_scaled, inputs
def unroll_encoder(
    self,
    past_time_feat: torch.Tensor,
    past_target_cdf: torch.Tensor,
    past_observed_values: torch.Tensor,
    past_is_pad: torch.Tensor,
    future_time_feat: Optional[torch.Tensor],
    future_target_cdf: Optional[torch.Tensor],
    target_dimension_indicator: torch.Tensor,
) -> Tuple[
    torch.Tensor,
    Union[List[torch.Tensor], torch.Tensor],
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """
    Unroll the RNN over the context window and, when both future tensors
    are supplied, over the prediction window as well.

    Parameters
    ----------
    past_time_feat
        Past time features (batch_size, history_length, num_features)
    past_target_cdf
        Past target values (batch_size, history_length, target_dim)
    past_observed_values
        Indicator whether or not the values were observed (batch_size,
        history_length, target_dim)
    past_is_pad
        Indicator whether the past target values have been padded
        (batch_size, history_length)
    future_time_feat
        Future time features (batch_size, prediction_length, num_features),
        or None to unroll over the past only
    future_target_cdf
        Future target values (batch_size, prediction_length, target_dim),
        or None to unroll over the past only
    target_dimension_indicator
        Indices of the target dimensions (batch_size, target_dim)

    Returns
    -------
    outputs
        RNN outputs as returned by ``self.unroll``
    states
        RNN state as returned by ``self.unroll``
    scale
        Second value returned by ``self.scaler`` for the context window
    lags_scaled
        Scaled lags as returned by ``self.unroll``
    inputs
        The tensor that was fed to the RNN
    """
    context_length = self.context_length

    # padded positions are treated as unobserved
    not_padded = 1 - past_is_pad.unsqueeze(-1)
    past_observed_values = torch.min(past_observed_values, not_padded)

    context_time_feat = past_time_feat[:, -context_length:, ...]

    has_future = future_time_feat is not None and future_target_cdf is not None
    if has_future:
        # training: the window covers context + prediction range
        time_feat = torch.cat((context_time_feat, future_time_feat), dim=1)
        sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)
        sequence_length = self.history_length + self.prediction_length
        subsequences_length = context_length + self.prediction_length
    else:
        # prediction: only the context window is encoded
        time_feat = context_time_feat
        sequence = past_target_cdf
        sequence_length = self.history_length
        subsequences_length = context_length

    lags = self.get_lagged_subsequences(
        sequence=sequence,
        sequence_length=sequence_length,
        indices=self.lags_seq,
        subsequences_length=subsequences_length,
    )

    # the scale is always derived from the last context_length past steps,
    # never from future values
    _, scale = self.scaler(
        past_target_cdf[:, -context_length:, ...],
        past_observed_values[:, -context_length:, ...],
    )

    outputs, states, lags_scaled, inputs = self.unroll(
        lags=lags,
        scale=scale,
        time_feat=time_feat,
        target_dimension_indicator=target_dimension_indicator,
        unroll_length=subsequences_length,
        begin_state=None,
    )
    return outputs, states, scale, lags_scaled, inputs
def distr_args(self, rnn_outputs: torch.Tensor):
    """
    Project the RNN outputs through ``self.proj_dist_args``.

    Parameters
    ----------
    rnn_outputs
        Outputs of the unrolled RNN (batch_size, seq_len, num_cells)

    Returns
    -------
    distr_args
        The single element produced by ``self.proj_dist_args``
    """
    projected = self.proj_dist_args(rnn_outputs)
    # the projection must yield exactly one element; unpacking enforces it
    (args,) = projected
    return args
def forward(
    self,
    target_dimension_indicator: torch.Tensor,
    past_time_feat: torch.Tensor,
    past_target_cdf: torch.Tensor,
    past_observed_values: torch.Tensor,
    past_is_pad: torch.Tensor,
    future_time_feat: torch.Tensor,
    future_target_cdf: torch.Tensor,
    future_observed_values: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
    """
    Compute the training loss over the context and prediction windows.
    All time-series tensors have NTC layout.

    Parameters
    ----------
    target_dimension_indicator
        Indices of the target dimension (batch_size, target_dim)
    past_time_feat
        Dynamic features of past time series (batch_size, history_length,
        num_features)
    past_target_cdf
        Past target values (batch_size, history_length, target_dim)
    past_observed_values
        Indicator whether or not the values were observed (batch_size,
        history_length, target_dim)
    past_is_pad
        Indicator whether the past target values have been padded
        (batch_size, history_length)
    future_time_feat
        Future time features (batch_size, prediction_length, num_features)
    future_target_cdf
        Future target values (batch_size, prediction_length, target_dim)
    future_observed_values
        Indicator whether or not the future values were observed
        (batch_size, prediction_length, target_dim)

    Returns
    -------
    loss
        Mean over all entries of the weighted-average per-step loss
    likelihoods
        Per-step values returned by ``self.diffusion.log_prob`` with a
        trailing singleton axis added; these are used directly as the
        loss terms
    distr_args
        Conditioning arguments projected from the RNN outputs
    """
    context_length = self.context_length

    # training mode: the encoder is unrolled over past and future data
    rnn_outputs, _, scale, _, _ = self.unroll_encoder(
        past_time_feat=past_time_feat,
        past_target_cdf=past_target_cdf,
        past_observed_values=past_observed_values,
        past_is_pad=past_is_pad,
        future_time_feat=future_time_feat,
        future_target_cdf=future_target_cdf,
        target_dimension_indicator=target_dimension_indicator,
    )

    # target covering the same window the RNN was unrolled over
    # (batch_size, seq_len, target_dim)
    context_target = past_target_cdf[:, -context_length:, ...]
    target = torch.cat((context_target, future_target_cdf), dim=1)

    distr_args = self.distr_args(rnn_outputs=rnn_outputs)
    if self.scaling:
        self.diffusion.scale = scale

    # (batch_size, subseq_length, 1)
    likelihoods = self.diffusion.log_prob(target, distr_args).unsqueeze(-1)

    # padded past positions count as unobserved
    masked_past_observed = torch.min(
        past_observed_values, 1 - past_is_pad.unsqueeze(-1)
    )
    # (batch_size, subseq_length, target_dim)
    observed_values = torch.cat(
        (masked_past_observed[:, -context_length:, ...], future_observed_values),
        dim=1,
    )

    # a time step is masked out as soon as any target dimension is missing
    # (batch_size, subseq_length, 1)
    loss_weights = observed_values.min(dim=-1, keepdim=True).values

    loss = weighted_average(likelihoods, weights=loss_weights, dim=1)
    return (loss.mean(), likelihoods, distr_args)
class TimeGradPredictionNetwork(TimeGradTrainingNetwork):
    """
    Inference-time variant of the training network: encodes the past and
    then draws ``num_parallel_samples`` sample paths step by step.
    """

    def __init__(self, num_parallel_samples: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.num_parallel_samples = num_parallel_samples

        # While decoding, the sequence ends at the most recent value rather
        # than at the value being predicted, so every lag is shifted by one:
        # a lag of one then refers to the last available target value.
        self.shifted_lags = [lag - 1 for lag in self.lags_seq]

    def sampling_decoder(
        self,
        past_target_cdf: torch.Tensor,
        target_dimension_indicator: torch.Tensor,
        time_feat: torch.Tensor,
        scale: torch.Tensor,
        begin_states: Union[List[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """
        Draw sample paths by unrolling the RNN one step at a time, feeding
        each newly drawn sample back in as history for the next step.

        Parameters
        ----------
        past_target_cdf
            Past target values (batch_size, history_length, target_dim)
        target_dimension_indicator
            Indices of the target dimension (batch_size, target_dim)
        time_feat
            Time features for the steps being decoded; step ``k`` reads
            ``time_feat[:, k : k + 1]``
        scale
            Scale for each time series (batch_size, 1, target_dim)
        begin_states
            Initial RNN state; a sequence of tensors when
            ``self.cell_type == "LSTM"``, otherwise a single tensor. The
            batch axis is taken to be dim 1.

        Returns
        -------
        sample_paths : Tensor
            Sampled paths reshaped to (-1, num_parallel_samples,
            prediction_length, target_dim).
        """
        num_samples = self.num_parallel_samples

        def tile(tensor, dim=0):
            # each batch entry is repeated num_samples times, back to back
            return tensor.repeat_interleave(repeats=num_samples, dim=dim)

        # Expand the batch axis to batch_size * num_samples so that all
        # sample paths are decoded in parallel.
        history = tile(past_target_cdf)
        tiled_time_feat = tile(time_feat)
        tiled_scale = tile(scale)
        if self.scaling:
            self.diffusion.scale = tiled_scale
        tiled_indicator = tile(target_dimension_indicator)

        if self.cell_type == "LSTM":
            states = [tile(s, dim=1) for s in begin_states]
        else:
            states = tile(begin_states, dim=1)

        drawn_steps = []
        for step in range(self.prediction_length):
            # the history grows by one step per iteration
            lags = self.get_lagged_subsequences(
                sequence=history,
                sequence_length=self.history_length + step,
                indices=self.shifted_lags,
                subsequences_length=1,
            )

            rnn_outputs, states, _, _ = self.unroll(
                begin_state=states,
                lags=lags,
                scale=tiled_scale,
                time_feat=tiled_time_feat[:, step : step + 1, ...],
                target_dimension_indicator=tiled_indicator,
                unroll_length=1,
            )

            cond = self.distr_args(rnn_outputs=rnn_outputs)

            # (batch_size * num_samples, 1, target_dim)
            step_samples = self.diffusion.sample(cond=cond)
            drawn_steps.append(step_samples)

            # feed the new samples back in as the latest observations
            history = torch.cat((history, step_samples), dim=1)

        # (batch_size * num_samples, prediction_length, target_dim)
        samples = torch.cat(drawn_steps, dim=1)

        # (batch_size, num_samples, prediction_length, target_dim)
        path_shape = (
            -1,
            num_samples,
            self.prediction_length,
            self.target_dim,
        )
        return samples.reshape(path_shape)

    def forward(
        self,
        target_dimension_indicator: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target_cdf: torch.Tensor,
        past_observed_values: torch.Tensor,
        past_is_pad: torch.Tensor,
        future_time_feat: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode the past and sample future paths from the trained model.
        All tensors should have NTC layout.

        Parameters
        ----------
        target_dimension_indicator
            Indices of the target dimension (batch_size, target_dim)
        past_time_feat
            Dynamic features of past time series (batch_size,
            history_length, num_features)
        past_target_cdf
            Past target values (batch_size, history_length, target_dim)
        past_observed_values
            Indicator whether or not the values were observed (batch_size,
            history_length, target_dim)
        past_is_pad
            Indicator whether the past target values have been padded
            (batch_size, history_length)
        future_time_feat
            Future time features (batch_size, prediction_length,
            num_features)

        Returns
        -------
        sample_paths : Tensor
            Sampled paths as returned by ``sampling_decoder``.
        """
        # padded positions are marked as unobserved
        not_padded = 1 - past_is_pad.unsqueeze(-1)
        past_observed_values = torch.min(past_observed_values, not_padded)

        # prediction mode: the encoder only sees past data
        _, begin_states, scale, _, _ = self.unroll_encoder(
            past_time_feat=past_time_feat,
            past_target_cdf=past_target_cdf,
            past_observed_values=past_observed_values,
            past_is_pad=past_is_pad,
            future_time_feat=None,
            future_target_cdf=None,
            target_dimension_indicator=target_dimension_indicator,
        )

        sample_paths = self.sampling_decoder(
            past_target_cdf=past_target_cdf,
            target_dimension_indicator=target_dimension_indicator,
            time_feat=future_time_feat,
            scale=scale,
            begin_states=begin_states,
        )
        return sample_paths