This commit is contained in:
wassname
2022-12-23 19:00:56 +08:00
parent f6d1884c4c
commit 6c84953ac1
6 changed files with 806 additions and 740 deletions
+309 -80
View File
@@ -5,8 +5,24 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:08:26.228737Z",
"start_time": "2022-12-23T07:08:24.806396Z"
"end_time": "2022-12-23T09:57:39.738715Z",
"start_time": "2022-12-23T09:57:39.730077Z"
}
},
"outputs": [],
"source": [
"# autoreload import your package\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T09:57:40.738250Z",
"start_time": "2022-12-23T09:57:39.739677Z"
}
},
"outputs": [],
@@ -22,23 +38,14 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:08:26.417012Z",
"start_time": "2022-12-23T07:08:26.230461Z"
"end_time": "2022-12-23T09:57:40.909476Z",
"start_time": "2022-12-23T09:57:40.739741Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/wassname/miniforge3/envs/gluonts10.0/lib/python3.9/site-packages/gluonts/json.py:101: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n",
" warnings.warn(\n"
]
}
],
"outputs": [],
"source": [
"from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n",
"from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n",
@@ -48,11 +55,11 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:08:26.469299Z",
"start_time": "2022-12-23T07:08:26.419051Z"
"end_time": "2022-12-23T09:57:40.966206Z",
"start_time": "2022-12-23T09:57:40.910477Z"
}
},
"outputs": [],
@@ -65,11 +72,11 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:08:26.477478Z",
"start_time": "2022-12-23T07:08:26.470628Z"
"end_time": "2022-12-23T09:57:40.980883Z",
"start_time": "2022-12-23T09:57:40.967157Z"
}
},
"outputs": [],
@@ -79,11 +86,11 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:08:26.494456Z",
"start_time": "2022-12-23T07:08:26.478496Z"
"end_time": "2022-12-23T09:57:41.003468Z",
"start_time": "2022-12-23T09:57:40.981830Z"
}
},
"outputs": [],
@@ -147,11 +154,11 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:08:26.516023Z",
"start_time": "2022-12-23T07:08:26.495676Z"
"end_time": "2022-12-23T09:57:41.025224Z",
"start_time": "2022-12-23T09:57:41.004330Z"
}
},
"outputs": [
@@ -169,11 +176,11 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:08:26.532361Z",
"start_time": "2022-12-23T07:08:26.517210Z"
"end_time": "2022-12-23T09:57:41.041326Z",
"start_time": "2022-12-23T09:57:41.026069Z"
},
"scrolled": true
},
@@ -185,11 +192,11 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:08:26.554669Z",
"start_time": "2022-12-23T07:08:26.534354Z"
"end_time": "2022-12-23T09:57:41.057694Z",
"start_time": "2022-12-23T09:57:41.042892Z"
}
},
"outputs": [
@@ -199,7 +206,7 @@
"MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='370')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)"
]
},
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -210,67 +217,51 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:12:03.693420Z",
"start_time": "2022-12-23T07:12:02.845649Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"7"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:12:11.716051Z",
"start_time": "2022-12-23T07:12:11.712672Z"
"end_time": "2022-12-23T09:57:41.073075Z",
"start_time": "2022-12-23T09:57:41.058526Z"
}
},
"outputs": [],
"source": [
"train_grouper = MultivariateGrouper(max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))\n",
"\n",
"test_grouper = MultivariateGrouper(\n",
"# num_test_dates=int(len(dataset.test)/len(dataset.train)*2),\n",
" num_test_dates=7,\n",
" max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))"
"train_grouper = MultivariateGrouper(max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T07:12:12.382Z"
"end_time": "2022-12-23T09:57:44.134871Z",
"start_time": "2022-12-23T09:57:41.073880Z"
}
},
"outputs": [],
"source": [
"dataset_train = train_grouper(dataset.train)\n",
"dataset_test = test_grouper(dataset.test)"
"dataset_train = train_grouper(dataset.train)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:09:49.342107Z",
"start_time": "2022-12-23T07:09:49.342097Z"
"end_time": "2022-12-23T09:57:44.153700Z",
"start_time": "2022-12-23T09:57:44.136099Z"
}
},
"outputs": [],
@@ -300,22 +291,265 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:09:49.342839Z",
"start_time": "2022-12-23T07:09:49.342830Z"
"start_time": "2022-12-23T09:57:39.746Z"
}
},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "744b148440674ef3902ca56622cb001b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "88abd67eb98d4569b276f98dadd73205",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f5e08c9306c647cd8589f29d8652177b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "10a4ac2006db4f6090dc3f08cc3e3e9e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6458a98361854786bc4fa81da37471d2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "676d661e21ec42acbb544a8848264674",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "36010cebcd884d248de158890d2ba324",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f7487fae1ae54dae8c71d23ea5721e7f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6a995bf552bb487cbca2a0f3e45098b3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3ee8b7cdd0b94465a78c43d0917910ae",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "958283558c6a4e958c5c67188e735ed9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/99 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"predictor = estimator.train(dataset_train, num_workers=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:09:49.343440Z",
"start_time": "2022-12-23T07:09:49.343431Z"
"start_time": "2022-12-23T09:57:39.748Z"
}
},
"outputs": [],
"source": [
"\n",
"test_grouper = MultivariateGrouper(\n",
" num_test_dates=int(len(dataset.test)/len(dataset.train)*2),\n",
"# num_test_dates=7,\n",
" max_target_dim=min(2000, int(dataset.metadata.feat_static_cat[0].cardinality)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.749Z"
}
},
"outputs": [],
"source": [
"dataset_test = test_grouper(dataset.test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.750Z"
}
},
"outputs": [],
"source": [
"next(iter(dataset.test.iter_sequential()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.751Z"
}
},
"outputs": [],
"source": [
"len(dataset.test)\n",
"x = [x['target'].shape for x in dataset.test.iter_sequential()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.752Z"
}
},
"outputs": [],
"source": [
"pd.Series(x).value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.753Z"
}
},
"outputs": [],
"source": [
"%debug"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-12-23T09:57:39.754Z"
}
},
"outputs": [],
@@ -330,8 +564,7 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:09:49.344138Z",
"start_time": "2022-12-23T07:09:49.344129Z"
"start_time": "2022-12-23T09:57:39.756Z"
}
},
"outputs": [],
@@ -345,8 +578,7 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:09:49.344864Z",
"start_time": "2022-12-23T07:09:49.344855Z"
"start_time": "2022-12-23T09:57:39.756Z"
}
},
"outputs": [],
@@ -364,8 +596,7 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:08:38.402632Z",
"start_time": "2022-12-23T07:08:38.402624Z"
"start_time": "2022-12-23T09:57:39.757Z"
}
},
"outputs": [],
@@ -379,8 +610,7 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:08:38.403367Z",
"start_time": "2022-12-23T07:08:38.403359Z"
"start_time": "2022-12-23T09:57:39.758Z"
},
"scrolled": true
},
@@ -394,8 +624,7 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-12-23T07:08:38.404060Z",
"start_time": "2022-12-23T07:08:38.404051Z"
"start_time": "2022-12-23T09:57:39.759Z"
}
},
"outputs": [],
@@ -419,9 +648,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "gluonts10.0",
"display_name": "glounts",
"language": "python",
"name": "gluonts10.0"
"name": "glounts"
},
"language_info": {
"codemirror_mode": {
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -251,7 +251,7 @@ class TimeGradEstimator(PyTorchEstimator):
input_names=input_names,
prediction_net=prediction_network,
batch_size=self.trainer.batch_size,
freq=self.freq,
# freq=self.freq,
prediction_length=self.prediction_length,
device=device,
)
+1 -1
View File
@@ -124,7 +124,7 @@ class EpsilonTheta2(nn.Module):
# nn.init.zeros_(self.output_projection.weight)
self.output_projection = nn.Sequential(
nn.Conv1d(residual_channels, residual_channels, 3),
nn.Conv1d(residual_channels, target_dim, 3),
# nn.LeakyReLU(),
# nn.Conv1d(residual_channels, target_dim, 3, padding="same"),
)
+16 -11
View File
@@ -44,13 +44,18 @@ def get_ou(shape, mu=0, theta=0.1, sigma=.1):
return OrnsteinUhlenbeckProcess(shape, mu=mu, theta=theta, sigma=sigma)
def mk_noise(shape, device):
repeats = shape[1]
shape = (shape[0], shape[2])
ou = get_ou(shape)
b = ou.sample()
b = [torch.from_numpy(bb).to(device).float() for bb in b]
C, noise_rand, noise_ou = b
noise_rand = noise_rand[:, None]
noise_ou = noise_ou[:, None]
ns = []
for _ in range(repeats):
ou = get_ou(shape)
b = ou.sample()
b = [torch.from_numpy(bb).to(device).float() for bb in b]
C, noise_rand, noise_ou = b
noise_rand = noise_rand[:, None]
noise_ou = noise_ou[:, None]
ns.append((C, noise_rand, noise_ou))
C, noise_rand, noise_ou = [torch.concat(n, 1) for n in zip(*ns)]
return C, noise_rand, noise_ou
def noise_like(x):
@@ -222,7 +227,7 @@ class GaussianDiffusionOU(nn.Module):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
C, noise_rand, img = mk_noise(shape, device=device)
for i in reversed(range(0, self.num_timesteps)):
img = self.p_sample(
@@ -236,10 +241,10 @@ class GaussianDiffusionOU(nn.Module):
shape = cond.shape[:-1] + (self.input_size,)
# TODO reshape cond to (B*T, 1, -1)
B, T, pred_len = cond.shape
shape = (B, 1, pred_len)
shape = (B, self.input_size, pred_len)
else:
shape = sample_shape
x_hat = self.p_sample_loop(shape, cond) # TODO reshape x_hat to (B,T,-1)
x_hat = self.p_sample_loop(shape, cond)
if self.scale is not None:
x_hat *= self.scale
@@ -263,8 +268,8 @@ class GaussianDiffusionOU(nn.Module):
return img
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: noise_like(x_start)[2])
def q_sample(self, x_start, t, noise):
# noise = default(noise, lambda: noise_like(x_start)[2])
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+46 -70
View File
@@ -255,19 +255,19 @@ class TimeGradTrainingNetwork2(nn.Module):
past_observed_values, 1 - past_is_pad.unsqueeze(-1)
)
if future_time_feat is None or future_target_cdf is None:
time_feat = past_time_feat[:, -self.context_length :, ...]
sequence = past_target_cdf
sequence_length = self.history_length
subsequences_length = self.context_length
else:
time_feat = torch.cat(
(past_time_feat[:, -self.context_length :, ...], future_time_feat),
dim=1,
)
sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)
sequence_length = self.history_length + self.prediction_length
subsequences_length = self.context_length + self.prediction_length
# if future_time_feat is None or future_target_cdf is None:
time_feat = past_time_feat[:, -self.context_length:, ...]
sequence = past_target_cdf
sequence_length = self.history_length
subsequences_length = self.context_length
# else:
# time_feat = torch.cat(
# (past_time_feat[:, -self.context_length :, ...], future_time_feat),
# dim=1,
# )
# sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)
# sequence_length = self.history_length + self.prediction_length
# subsequences_length = self.context_length + self.prediction_length
# (batch_size, sub_seq_len, target_dim, num_lags)
lags = self.get_lagged_subsequences(
@@ -389,16 +389,17 @@ class TimeGradTrainingNetwork2(nn.Module):
# put together target sequence
# (batch_size, seq_len, target_dim)
target = torch.cat(
(past_target_cdf[:, -self.context_length :, ...], future_target_cdf),
dim=1,
)
# target = torch.cat(
# (past_target_cdf[:, -self.context_length :, ...], future_target_cdf),
# dim=1,
# )
target = future_target_cdf
# assert_shape(target, (-1, seq_len, self.target_dim))
distr_args = self.distr_args(rnn_outputs=rnn_outputs)
if self.scaling:
self.diffusion.scale = scale
self.diffusion.scale = scale.permute(0, 2, 1)
# we sum the last axis to have the same shape for all likelihoods
# (batch_size, subseq_length, 1)
@@ -448,11 +449,8 @@ class TimeGradPredictionNetwork2(TimeGradTrainingNetwork2):
def sampling_decoder(
self,
past_target_cdf: torch.Tensor,
target_dimension_indicator: torch.Tensor,
time_feat: torch.Tensor,
scale: torch.Tensor,
begin_states: Union[List[torch.Tensor], torch.Tensor],
rnn_outputs: Union[List[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
"""
Computes sample paths by unrolling the RNN starting with a initial
@@ -470,8 +468,8 @@ class TimeGradPredictionNetwork2(TimeGradTrainingNetwork2):
num_features)
scale
Mean scale for each time series (batch_size, 1, target_dim)
begin_states
List of initial states for the RNN layers (batch_size, num_cells)
rnn_outputs
Outputs of the unrolled RNN (batch_size, seq_len, num_cells)
Returns
--------
sample_paths : Tensor
@@ -482,56 +480,35 @@ class TimeGradPredictionNetwork2(TimeGradTrainingNetwork2):
def repeat(tensor, dim=0):
return tensor.repeat_interleave(repeats=self.num_parallel_samples, dim=dim)
# blows-up the dimension of each tensor to
# batch_size * self.num_sample_paths for increasing parallelism
repeated_past_target_cdf = repeat(past_target_cdf)
repeated_time_feat = repeat(time_feat)
repeated_scale = repeat(scale)
if self.scaling:
self.diffusion.scale = repeated_scale
repeated_target_dimension_indicator = repeat(target_dimension_indicator)
if self.cell_type == "LSTM":
repeated_states = [repeat(s, dim=1) for s in begin_states]
else:
repeated_states = repeat(begin_states, dim=1)
future_samples = []
# for each future time-units we draw new samples for this time-unit
# and update the state
for k in range(self.prediction_length):
lags = self.get_lagged_subsequences(
sequence=repeated_past_target_cdf,
sequence_length=self.history_length + k,
indices=self.shifted_lags,
subsequences_length=1,
)
rnn_outputs, repeated_states, _, _ = self.unroll(
begin_state=repeated_states,
lags=lags,
scale=repeated_scale,
time_feat=repeated_time_feat[:, k : k + 1, ...],
target_dimension_indicator=repeated_target_dimension_indicator,
unroll_length=1,
)
for _ in range(self.num_parallel_samples):
distr_args = self.distr_args(rnn_outputs=rnn_outputs)
# (batch_size, 1, target_dim)
new_samples = self.diffusion.sample(cond=distr_args)
distr_args = distr_args.permute(0, 2, 1)
samples = self.diffusion.sample(cond=distr_args)
future_samples.append(samples)
# import pdb; pdb.set_trace()
# rnn_outputs = torch.Size([7, 24, 40])
# torch.Size([7, 100, 370, 24])
# torch.Size([7, 370, 24])
samples = torch.stack(future_samples, dim=1).permute(0, 1, 3, 2)
# (batch_size, seq_len, target_dim)
future_samples.append(new_samples)
repeated_past_target_cdf = torch.cat(
(repeated_past_target_cdf, new_samples), dim=1
)
# # (batch_size, seq_len, target_dim)
# future_samples.append(new_samples)
# repeated_past_target_cdf = torch.cat(
# (repeated_past_target_cdf, new_samples), dim=1
# )
# (batch_size * num_samples, prediction_length, target_dim)
samples = torch.cat(future_samples, dim=1)
# samples = torch.cat(future_samples, dim=1)
# (batch_size, num_samples, prediction_length, target_dim)
# print('samples.shape', samples.shape)
return samples
return samples.reshape(
(
-1,
@@ -586,21 +563,20 @@ class TimeGradPredictionNetwork2(TimeGradTrainingNetwork2):
past_observed_values, 1 - past_is_pad.unsqueeze(-1)
)
# unroll the decoder in "prediction mode", i.e. with past data only
_, begin_states, scale, _, _ = self.unroll_encoder(
# unroll the decoder in "training mode", i.e. by providing future data
# as well
rnn_outputs, _, scale, _, _ = self.unroll_encoder(
past_time_feat=past_time_feat,
past_target_cdf=past_target_cdf,
past_observed_values=past_observed_values,
past_is_pad=past_is_pad,
future_time_feat=None,
future_time_feat=future_time_feat,
future_target_cdf=None,
target_dimension_indicator=target_dimension_indicator,
)
return self.sampling_decoder(
past_target_cdf=past_target_cdf,
target_dimension_indicator=target_dimension_indicator,
time_feat=future_time_feat,
scale=scale,
begin_states=begin_states,
rnn_outputs=rnn_outputs,
)