mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 19:32:05 +08:00
use history_length - context_length for past_time_feat in deepvar
This commit is contained in:
@@ -238,13 +238,18 @@ class DeepVARTrainingNetwork(nn.Module):
|
||||
)
|
||||
|
||||
if future_time_feat is None or future_target_cdf is None:
|
||||
time_feat = past_time_feat[:, -self.context_length :, ...]
|
||||
time_feat = past_time_feat[
|
||||
:, self.history_length - self.context_length :, ...
|
||||
]
|
||||
sequence = past_target_cdf
|
||||
sequence_length = self.history_length
|
||||
subsequences_length = self.context_length
|
||||
else:
|
||||
time_feat = torch.cat(
|
||||
(past_time_feat[:, -self.context_length :, ...], future_time_feat),
|
||||
(
|
||||
past_time_feat[:, self.history_length - self.context_length :, ...],
|
||||
future_time_feat,
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)
|
||||
@@ -401,7 +406,7 @@ class DeepVARTrainingNetwork(nn.Module):
|
||||
|
||||
# mask the loss at one time step if one or more observations is missing
|
||||
# in the target dimensions (batch_size, subseq_length, 1)
|
||||
loss_weights,_ = observed_values.min(dim=-1, keepdim=True)
|
||||
loss_weights, _ = observed_values.min(dim=-1, keepdim=True)
|
||||
|
||||
# assert_shape(loss_weights, (-1, seq_len, 1))
|
||||
|
||||
@@ -430,7 +435,7 @@ class DeepVARPredictionNetwork(DeepVARTrainingNetwork):
|
||||
target_dimension_indicator: torch.Tensor,
|
||||
time_feat: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
begin_states: Union[List[torch.Tensor], torch.Tensor]
|
||||
begin_states: Union[List[torch.Tensor], torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes sample paths by unrolling the RNN starting with a initial
|
||||
|
||||
Reference in New Issue
Block a user