use history_length - context_length for past_time_feat in deepvar

This commit is contained in:
Dr. Kashif Rasul
2020-02-03 14:51:50 +01:00
parent 85c9f61988
commit 637a6fb36e
+9 -4
View File
@@ -238,13 +238,18 @@ class DeepVARTrainingNetwork(nn.Module):
)
if future_time_feat is None or future_target_cdf is None:
time_feat = past_time_feat[:, -self.context_length :, ...]
time_feat = past_time_feat[
:, self.history_length - self.context_length :, ...
]
sequence = past_target_cdf
sequence_length = self.history_length
subsequences_length = self.context_length
else:
time_feat = torch.cat(
(past_time_feat[:, -self.context_length :, ...], future_time_feat),
(
past_time_feat[:, self.history_length - self.context_length :, ...],
future_time_feat,
),
dim=1,
)
sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)
@@ -401,7 +406,7 @@ class DeepVARTrainingNetwork(nn.Module):
# mask the loss at one time step if one or more observations is missing
# in the target dimensions (batch_size, subseq_length, 1)
loss_weights,_ = observed_values.min(dim=-1, keepdim=True)
loss_weights, _ = observed_values.min(dim=-1, keepdim=True)
# assert_shape(loss_weights, (-1, seq_len, 1))
@@ -430,7 +435,7 @@ class DeepVARPredictionNetwork(DeepVARTrainingNetwork):
target_dimension_indicator: torch.Tensor,
time_feat: torch.Tensor,
scale: torch.Tensor,
begin_states: Union[List[torch.Tensor], torch.Tensor]
begin_states: Union[List[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
"""
Computes sample paths by unrolling the RNN starting with a initial