mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-30 01:23:05 +08:00
597 lines
21 KiB
Python
597 lines
21 KiB
Python
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from gluonts.core.component import validated
|
|
|
|
from pts.model import weighted_average
|
|
from pts.modules import RealNVP, MAF, FlowOutput, MeanScaler, NOPScaler
|
|
|
|
|
|
class TempFlowTrainingNetwork(nn.Module):
    """Training-time TempFlow network.

    An RNN (LSTM or GRU) is unrolled over scaled lagged target values,
    per-target-dimension embeddings and time features. Its outputs are mapped by
    the projection from ``FlowOutput.get_args_proj`` to the conditioning input
    of a normalizing flow (``RealNVP`` or ``MAF``). The loss is the flow's
    negative log-likelihood of the target, weighted by the observation mask.
    """

    @validated()
    def __init__(
        self,
        input_size: int,
        num_layers: int,
        num_cells: int,
        cell_type: str,
        history_length: int,
        context_length: int,
        prediction_length: int,
        dropout_rate: float,
        lags_seq: List[int],
        target_dim: int,
        conditioning_length: int,
        flow_type: str,
        n_blocks: int,
        hidden_size: int,
        n_hidden: int,
        dequantize: bool,
        cardinality: List[int] = [1],
        embedding_dimension: int = 1,
        scaling: bool = True,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        input_size
            Number of RNN input features per time step. It must equal the width
            of the concatenation built in ``unroll``:
            ``len(lags_seq) * target_dim + target_dim + num_time_features``
            (the embedding dimension is fixed to 1 here).
        num_layers, num_cells, cell_type, dropout_rate
            RNN depth, hidden size, type (``"LSTM"`` or ``"GRU"``) and
            inter-layer dropout.
        history_length
            Length of the past target window given to ``forward``.
        context_length
            Number of most recent past steps the RNN is unrolled over.
        prediction_length
            Number of future steps.
        lags_seq
            Distinct, non-negative lag indices. A sorted copy is stored; the
            caller's list is not modified.
        target_dim
            Number of target dimensions (flow input size and number of
            embeddings).
        conditioning_length
            Size of the flow's conditioning vector.
        flow_type
            ``"RealNVP"`` or ``"MAF"``.
        n_blocks, hidden_size, n_hidden
            Flow architecture parameters, forwarded to the flow class.
        dequantize
            If True, ``torch.rand_like`` noise is added to the target before
            its likelihood is computed.
        cardinality, embedding_dimension
            Accepted for interface compatibility; not used by this class.
        scaling
            If True, a ``MeanScaler`` scale divides the RNN input lags and is
            assigned to ``flow.scale``; otherwise ``NOPScaler`` is used.
        """
        super().__init__(**kwargs)

        self.target_dim = target_dim
        self.prediction_length = prediction_length
        self.context_length = context_length
        self.history_length = history_length
        self.scaling = scaling

        assert len(set(lags_seq)) == len(lags_seq), "no duplicated lags allowed!"
        # Keep a sorted copy rather than sorting the caller's list in place.
        self.lags_seq = sorted(lags_seq)

        self.cell_type = cell_type
        rnn_cls = {"LSTM": nn.LSTM, "GRU": nn.GRU}[cell_type]
        self.rnn = rnn_cls(
            input_size=input_size,
            hidden_size=num_cells,
            num_layers=num_layers,
            dropout=dropout_rate,
            batch_first=True,
        )

        flow_cls = {
            "RealNVP": RealNVP,
            "MAF": MAF,
        }[flow_type]
        self.flow = flow_cls(
            input_size=target_dim,
            n_blocks=n_blocks,
            n_hidden=n_hidden,
            hidden_size=hidden_size,
            cond_label_size=conditioning_length,
        )
        self.dequantize = dequantize

        self.distr_output = FlowOutput(
            self.flow, input_size=target_dim, cond_size=conditioning_length
        )

        # maps num_cells-wide RNN outputs to the flow's conditioning arguments
        self.proj_dist_args = self.distr_output.get_args_proj(num_cells)

        # one learned embedding of size 1 per target dimension
        self.embed_dim = 1
        self.embed = nn.Embedding(
            num_embeddings=self.target_dim, embedding_dim=self.embed_dim
        )

        if self.scaling:
            self.scaler = MeanScaler(keepdim=True)
        else:
            self.scaler = NOPScaler(keepdim=True)

    @staticmethod
    def get_lagged_subsequences(
        sequence: torch.Tensor,
        sequence_length: int,
        indices: List[int],
        subsequences_length: int = 1,
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.

        Parameters
        ----------
        sequence
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        sequence_length
            length of sequence in the T (time) dimension (axis = 1).
        indices
            list of non-negative lag indices to be used.
        subsequences_length
            length of the subsequences to be extracted.

        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I),
            where S = subsequences_length and I = len(indices),
            containing lagged subsequences.
            Specifically, lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
        """
        # The earliest position read is -max(indices) - subsequences_length,
        # so we need sequence_length - max(indices) - subsequences_length >= 0.
        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag "
            f"{max(indices)} while history length is only {sequence_length}"
        )
        assert all(lag_index >= 0 for lag_index in indices)

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            # a slice end of -0 would select nothing, so lag 0 reads to the end
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...].unsqueeze(1))
        # (N, I, S, C) -> (N, S, C, I)
        return torch.cat(lagged_values, dim=1).permute(0, 2, 3, 1)

    def unroll(
        self,
        lags: torch.Tensor,
        scale: torch.Tensor,
        time_feat: torch.Tensor,
        target_dimension_indicator: torch.Tensor,
        unroll_length: int,
        begin_state: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
    ) -> Tuple[
        torch.Tensor,
        Union[List[torch.Tensor], torch.Tensor],
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Runs the RNN for ``unroll_length`` steps.

        Parameters
        ----------
        lags
            Lagged target values (batch_size, unroll_length, target_dim,
            num_lags).
        scale
            Scale per series (batch_size, 1, target_dim); the lags are divided
            by it.
        time_feat
            Time features (batch_size, unroll_length, num_features).
        target_dimension_indicator
            Indices of the target dimensions (batch_size, target_dim).
        unroll_length
            Number of time steps to unroll.
        begin_state
            Initial RNN state, or None to let the RNN use its default.

        Returns
        -------
        outputs
            RNN outputs (batch_size, unroll_length, num_cells).
        state
            Final RNN state, as returned by the RNN module.
        lags_scaled
            Scaled lags (batch_size, unroll_length, target_dim, num_lags).
        inputs
            The RNN inputs (batch_size, unroll_length, input_size).
        """
        # (batch_size, sub_seq_len, target_dim, num_lags)
        lags_scaled = lags / scale.unsqueeze(-1)

        # (batch_size, sub_seq_len, num_lags * target_dim)
        input_lags = lags_scaled.reshape(
            (-1, unroll_length, len(self.lags_seq) * self.target_dim)
        )

        # (batch_size, target_dim, embed_dim)
        index_embeddings = self.embed(target_dimension_indicator)

        # (batch_size, sub_seq_len, target_dim * embed_dim)
        repeated_index_embeddings = (
            index_embeddings.unsqueeze(1)
            .expand(-1, unroll_length, -1, -1)
            .reshape((-1, unroll_length, self.target_dim * self.embed_dim))
        )

        # (batch_size, sub_seq_len, input_dim)
        inputs = torch.cat((input_lags, repeated_index_embeddings, time_feat), dim=-1)

        # unroll encoder; outputs: (batch_size, sub_seq_len, num_cells)
        outputs, state = self.rnn(inputs, begin_state)

        return outputs, state, lags_scaled, inputs

    def unroll_encoder(
        self,
        past_time_feat: torch.Tensor,
        past_target_cdf: torch.Tensor,
        past_observed_values: torch.Tensor,
        past_is_pad: torch.Tensor,
        future_time_feat: Optional[torch.Tensor],
        future_target_cdf: Optional[torch.Tensor],
        target_dimension_indicator: torch.Tensor,
    ) -> Tuple[
        torch.Tensor,
        Union[List[torch.Tensor], torch.Tensor],
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Unrolls the RNN encoder over past data and, if present, future data.

        Returns the outputs and state of the encoder, the scale of
        past_target_cdf, the scaled lags, and the inputs that were built and
        fed to the encoder. All tensor arguments should have NTC layout.

        Parameters
        ----------
        past_time_feat
            Past time features (batch_size, history_length, num_features)
        past_target_cdf
            Past marginal CDF transformed target values (batch_size,
            history_length, target_dim)
        past_observed_values
            Indicator whether or not the values were observed (batch_size,
            history_length, target_dim)
        past_is_pad
            Indicator whether the past target values have been padded
            (batch_size, history_length)
        future_time_feat
            Future time features (batch_size, prediction_length, num_features),
            or None for prediction mode
        future_target_cdf
            Future marginal CDF transformed target values (batch_size,
            prediction_length, target_dim), or None for prediction mode
        target_dimension_indicator
            Indices of the target dimensions (batch_size, target_dim)

        Returns
        -------
        outputs
            RNN outputs (batch_size, seq_len, num_cells)
        states
            Final RNN state, as returned by the RNN module
        scale
            Scales for the time series (batch_size, 1, target_dim)
        lags_scaled
            Scaled lags (batch_size, sub_seq_len, target_dim, num_lags)
        inputs
            Inputs to the RNN
        """
        # padded steps count as unobserved
        past_observed_values = torch.min(
            past_observed_values, 1 - past_is_pad.unsqueeze(-1)
        )

        if future_time_feat is None or future_target_cdf is None:
            # prediction mode: unroll over the last context_length past steps
            time_feat = past_time_feat[:, -self.context_length :, ...]
            sequence = past_target_cdf
            sequence_length = self.history_length
            subsequences_length = self.context_length
        else:
            # training mode: context window followed by the future window
            time_feat = torch.cat(
                (past_time_feat[:, -self.context_length :, ...], future_time_feat),
                dim=1,
            )
            sequence = torch.cat((past_target_cdf, future_target_cdf), dim=1)
            sequence_length = self.history_length + self.prediction_length
            subsequences_length = self.context_length + self.prediction_length

        # (batch_size, sub_seq_len, target_dim, num_lags)
        lags = self.get_lagged_subsequences(
            sequence=sequence,
            sequence_length=sequence_length,
            indices=self.lags_seq,
            subsequences_length=subsequences_length,
        )

        # scale is computed on the last context_length steps of the past target
        # scale shape is (batch_size, 1, target_dim)
        _, scale = self.scaler(
            past_target_cdf[:, -self.context_length :, ...],
            past_observed_values[:, -self.context_length :, ...],
        )

        outputs, states, lags_scaled, inputs = self.unroll(
            lags=lags,
            scale=scale,
            time_feat=time_feat,
            target_dimension_indicator=target_dimension_indicator,
            unroll_length=subsequences_length,
            begin_state=None,
        )

        return outputs, states, scale, lags_scaled, inputs

    def distr_args(self, rnn_outputs: torch.Tensor):
        """
        Projects RNN outputs to the flow's conditioning arguments.

        Parameters
        ----------
        rnn_outputs
            Outputs of the unrolled RNN (batch_size, seq_len, num_cells)

        Returns
        -------
        distr_args
            The single conditioning tensor produced by ``self.proj_dist_args``.
        """
        (distr_args,) = self.proj_dist_args(rnn_outputs)
        return distr_args

    def forward(
        self,
        target_dimension_indicator: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target_cdf: torch.Tensor,
        past_observed_values: torch.Tensor,
        past_is_pad: torch.Tensor,
        future_time_feat: torch.Tensor,
        future_target_cdf: torch.Tensor,
        future_observed_values: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Computes the training loss. All input tensors representing time series
        have NTC layout.

        Parameters
        ----------
        target_dimension_indicator
            Indices of the target dimension (batch_size, target_dim)
        past_time_feat
            Dynamic features of past time series (batch_size, history_length,
            num_features)
        past_target_cdf
            Past marginal CDF transformed target values (batch_size,
            history_length, target_dim)
        past_observed_values
            Indicator whether or not the values were observed (batch_size,
            history_length, target_dim)
        past_is_pad
            Indicator whether the past target values have been padded
            (batch_size, history_length)
        future_time_feat
            Future time features (batch_size, prediction_length, num_features)
        future_target_cdf
            Future marginal CDF transformed target values (batch_size,
            prediction_length, target_dim)
        future_observed_values
            Indicator whether or not the future values were observed
            (batch_size, prediction_length, target_dim)

        Returns
        -------
        loss
            Scalar: the per-series ``weighted_average`` of the negative
            log-likelihoods over the time axis, averaged over the batch.
        likelihoods
            Negative log-likelihood for each time step
            (batch_size, context_length + prediction_length, 1)
        distr_args
            Flow conditioning arguments returned by ``distr_args``.
        """
        # unroll the encoder in "training mode", i.e. by providing future data
        # as well
        rnn_outputs, _, scale, _, _ = self.unroll_encoder(
            past_time_feat=past_time_feat,
            past_target_cdf=past_target_cdf,
            past_observed_values=past_observed_values,
            past_is_pad=past_is_pad,
            future_time_feat=future_time_feat,
            future_target_cdf=future_target_cdf,
            target_dimension_indicator=target_dimension_indicator,
        )

        # put together target sequence
        # (batch_size, seq_len, target_dim)
        target = torch.cat(
            (past_target_cdf[:, -self.context_length :, ...], future_target_cdf),
            dim=1,
        )

        distr_args = self.distr_args(rnn_outputs=rnn_outputs)
        if self.scaling:
            self.flow.scale = scale

        # negative log-likelihood per step: (batch_size, seq_len, 1)
        if self.dequantize:
            target += torch.rand_like(target)
        likelihoods = -self.flow.log_prob(target, distr_args).unsqueeze(-1)

        # unroll_encoder applies this padding mask internally but does not
        # return it, so it is recomputed here.
        past_observed_values = torch.min(
            past_observed_values, 1 - past_is_pad.unsqueeze(-1)
        )

        # (batch_size, seq_len, target_dim)
        observed_values = torch.cat(
            (
                past_observed_values[:, -self.context_length :, ...],
                future_observed_values,
            ),
            dim=1,
        )

        # mask the loss at one time step if one or more observations is missing
        # in the target dimensions (batch_size, seq_len, 1)
        loss_weights, _ = observed_values.min(dim=-1, keepdim=True)

        loss = weighted_average(likelihoods, weights=loss_weights, dim=1)

        return (loss.mean(), likelihoods, distr_args)
|
|
|
|
|
|
class TempFlowPredictionNetwork(TempFlowTrainingNetwork):
    """TempFlow network for inference.

    Encodes the past with the RNN, then samples future values one step at a
    time from the flow, appending each sample to the target history before
    the next step.
    """

    def __init__(self, num_parallel_samples: int, **kwargs) -> None:
        """
        Parameters
        ----------
        num_parallel_samples
            Number of sample paths drawn per series.
        **kwargs
            Forwarded to ``TempFlowTrainingNetwork.__init__``.
        """
        super().__init__(**kwargs)
        self.num_parallel_samples = num_parallel_samples

        # for decoding the lags are shifted by one,
        # at the first time-step of the decoder a lag of one corresponds to
        # the last target value
        self.shifted_lags = [lag - 1 for lag in self.lags_seq]

    def sampling_decoder(
        self,
        past_target_cdf: torch.Tensor,
        target_dimension_indicator: torch.Tensor,
        time_feat: torch.Tensor,
        scale: torch.Tensor,
        begin_states: Union[List[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """
        Computes sample paths by unrolling the RNN, starting from an initial
        state.

        Parameters
        ----------
        past_target_cdf
            Past marginal CDF transformed target values (batch_size,
            history_length, target_dim)
        target_dimension_indicator
            Indices of the target dimension (batch_size, target_dim)
        time_feat
            Dynamic features of the future time steps (batch_size,
            prediction_length, num_features)
        scale
            Scale for each time series (batch_size, 1, target_dim)
        begin_states
            Initial RNN state from the encoder: an (h, c) pair for LSTM or a
            single tensor for GRU, with the batch on dim 1.

        Returns
        --------
        sample_paths : Tensor
            A tensor containing sampled paths. Shape: (batch_size,
            num_parallel_samples, prediction_length, target_dim).
        """

        def repeat(tensor, dim=0):
            return tensor.repeat_interleave(repeats=self.num_parallel_samples, dim=dim)

        # blow up the batch dimension of each tensor to
        # batch_size * num_parallel_samples for increased parallelism
        repeated_past_target_cdf = repeat(past_target_cdf)
        repeated_time_feat = repeat(time_feat)
        repeated_scale = repeat(scale)
        if self.scaling:
            self.flow.scale = repeated_scale
        repeated_target_dimension_indicator = repeat(target_dimension_indicator)

        # RNN states keep the batch on dim 1. nn.LSTM documents its state as
        # an (h, c) tuple, so the repeated LSTM state is rebuilt as a tuple.
        if self.cell_type == "LSTM":
            repeated_states = tuple(repeat(s, dim=1) for s in begin_states)
        else:
            repeated_states = repeat(begin_states, dim=1)

        future_samples = []

        # for each future time step draw new samples and update the state
        for k in range(self.prediction_length):
            # the sequence grows by one sampled step per iteration
            lags = self.get_lagged_subsequences(
                sequence=repeated_past_target_cdf,
                sequence_length=self.history_length + k,
                indices=self.shifted_lags,
                subsequences_length=1,
            )

            rnn_outputs, repeated_states, _, _ = self.unroll(
                begin_state=repeated_states,
                lags=lags,
                scale=repeated_scale,
                time_feat=repeated_time_feat[:, k : k + 1, ...],
                target_dimension_indicator=repeated_target_dimension_indicator,
                unroll_length=1,
            )

            distr_args = self.distr_args(rnn_outputs=rnn_outputs)

            # (batch_size * num_samples, 1, target_dim)
            new_samples = self.flow.sample(cond=distr_args)

            future_samples.append(new_samples)
            repeated_past_target_cdf = torch.cat(
                (repeated_past_target_cdf, new_samples), dim=1
            )

        # (batch_size * num_samples, prediction_length, target_dim)
        samples = torch.cat(future_samples, dim=1)

        # repeat_interleave kept each series' copies contiguous, so this
        # reshape yields (batch_size, num_samples, prediction_length, target_dim)
        return samples.reshape(
            (
                -1,
                self.num_parallel_samples,
                self.prediction_length,
                self.target_dim,
            )
        )

    def forward(
        self,
        target_dimension_indicator: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target_cdf: torch.Tensor,
        past_observed_values: torch.Tensor,
        past_is_pad: torch.Tensor,
        future_time_feat: torch.Tensor,
    ) -> torch.Tensor:
        """
        Draws sample paths from the trained model.
        All tensors should have NTC layout.

        Parameters
        ----------
        target_dimension_indicator
            Indices of the target dimension (batch_size, target_dim)
        past_time_feat
            Dynamic features of past time series (batch_size, history_length,
            num_features)
        past_target_cdf
            Past marginal CDF transformed target values (batch_size,
            history_length, target_dim)
        past_observed_values
            Indicator whether or not the values were observed (batch_size,
            history_length, target_dim)
        past_is_pad
            Indicator whether the past target values have been padded
            (batch_size, history_length)
        future_time_feat
            Future time features (batch_size, prediction_length, num_features)

        Returns
        -------
        sample_paths : Tensor
            A tensor containing sampled paths (batch_size,
            num_parallel_samples, prediction_length, target_dim).
        """
        # unroll the encoder in "prediction mode", i.e. with past data only.
        # unroll_encoder itself marks padded steps as unobserved
        # (min with 1 - past_is_pad), so no pre-masking is needed here.
        _, begin_states, scale, _, _ = self.unroll_encoder(
            past_time_feat=past_time_feat,
            past_target_cdf=past_target_cdf,
            past_observed_values=past_observed_values,
            past_is_pad=past_is_pad,
            future_time_feat=None,
            future_target_cdf=None,
            target_dimension_indicator=target_dimension_indicator,
        )

        return self.sampling_decoder(
            past_target_cdf=past_target_cdf,
            target_dimension_indicator=target_dimension_indicator,
            time_feat=future_time_feat,
            scale=scale,
            begin_states=begin_states,
        )
|