mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:46:28 +08:00
initial std scaler
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
# Public entry points of the package. The sibling modules define the
# ``NSTransformer*`` classes; the un-prefixed names that used to be imported
# here do not exist in them, so importing the package failed.
from .estimator import NSTransformerEstimator
from .lightning_module import NSTransformerLightningModule
from .module import NSTransformerModel

# Backward-compatible aliases for the names this file previously advertised.
TransformerEstimator = NSTransformerEstimator
TransformerLightningModule = NSTransformerLightningModule
TransformerModel = NSTransformerModel

__all__ = [
    "NSTransformerModel",
    "NSTransformerLightningModule",
    "NSTransformerEstimator",
    "TransformerModel",
    "TransformerLightningModule",
    "TransformerEstimator",
]
|
||||
@@ -0,0 +1,309 @@
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
# +
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.dataset.common import Dataset
|
||||
from gluonts.dataset.field_names import FieldName
|
||||
from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
||||
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import IterableDataset
|
||||
from gluonts.transform import (
|
||||
AddAgeFeature,
|
||||
AddObservedValuesIndicator,
|
||||
AddTimeFeatures,
|
||||
AsNumpyArray,
|
||||
Chain,
|
||||
ExpectedNumInstanceSampler,
|
||||
InstanceSplitter,
|
||||
RemoveFields,
|
||||
SelectFields,
|
||||
SetField,
|
||||
TestSplitSampler,
|
||||
Transformation,
|
||||
ValidationSplitSampler,
|
||||
VstackFeatures,
|
||||
)
|
||||
from gluonts.transform.sampler import InstanceSampler
|
||||
|
||||
from lightning_module import NSTransformerLightningModule
|
||||
from module import NSTransformerModel
|
||||
# -
|
||||
|
||||
# Batch fields handed to the network at prediction time, in call order.
PREDICTION_INPUT_NAMES = [
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "past_target",
    "past_observed_values",
    "future_time_feat",
]

# Training and validation additionally need the future target and its
# observed-values indicator.
TRAINING_INPUT_NAMES = [
    *PREDICTION_INPUT_NAMES,
    "future_target",
    "future_observed_values",
]
|
||||
|
||||
|
||||
class NSTransformerEstimator(PyTorchLightningEstimator):
    """
    Estimator that builds, trains and wraps an ``NSTransformerModel``.

    Parameters
    ----------
    prediction_length
        Number of future time steps to forecast.
    nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
    activation, dropout
        Transformer hyper-parameters, forwarded unchanged to the model.
    freq
        Frequency string. Used to derive ``time_features`` when those are
        not given, and forwarded to the model.
    input_size
        Forwarded to the model as ``input_size``.
    context_length
        Length of the conditioning window; defaults to
        ``prediction_length``.
    num_feat_dynamic_real, num_feat_static_cat, num_feat_static_real
        Number of features of each kind present in the dataset.
    cardinality
        Cardinalities of the static categorical features. Replaced by
        ``[1]`` unless it is given and ``num_feat_static_cat > 0``.
    embedding_dimension
        Embedding sizes of the static categorical features.
    distr_output
        Output distribution of the model.
    loss
        Loss applied to the predicted distribution.
    lags_seq
        Lag indices forwarded to the model.
    time_features
        Time features to add; derived from ``freq`` when ``None``.
    num_parallel_samples
        Forwarded to the model.
    batch_size
        Batch size of the data loaders and of the predictor.
    num_batches_per_epoch
        Number of batches that make up one training epoch.
    trainer_kwargs
        Extra keyword arguments for the trainer. They are merged on top of
        ``{"max_epochs": 100}``; ``None`` means "no overrides".
    train_sampler, validation_sampler
        Instance samplers; defaults are created from ``prediction_length``.
    """

    @validated()
    def __init__(
        self,
        prediction_length: int,
        # Transformer arguments
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        freq: Optional[str] = None,
        input_size: int = 1,
        activation: str = "gelu",
        dropout: float = 0.1,
        context_length: Optional[int] = None,
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        loss: DistributionLoss = NegativeLogLikelihood(),
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        num_parallel_samples: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
    ) -> None:
        # ``trainer_kwargs`` is Optional: treat None as "no overrides" instead
        # of failing on ``**None``. User-supplied keys win over the default.
        trainer_kwargs = {
            "max_epochs": 100,
            **(trainer_kwargs or {}),
        }
        super().__init__(trainer_kwargs=trainer_kwargs)

        self.freq = freq
        self.context_length = (
            context_length if context_length is not None else prediction_length
        )
        self.prediction_length = prediction_length
        self.distr_output = distr_output
        self.loss = loss

        self.input_size = input_size
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.activation = activation
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout

        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.cardinality = (
            cardinality if cardinality and num_feat_static_cat > 0 else [1]
        )
        self.embedding_dimension = embedding_dimension
        self.lags_seq = lags_seq
        # NOTE(review): with ``freq=None`` and no explicit ``time_features``
        # this lookup receives ``None`` -- confirm callers always pass one.
        self.time_features = (
            time_features
            if time_features is not None
            else time_features_from_frequency_str(self.freq)
        )

        self.num_parallel_samples = num_parallel_samples
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )

    def create_transformation(self) -> Transformation:
        """Build the per-series preprocessing chain.

        Unused real-valued feature fields are removed, missing static
        features are replaced by constant dummies, and time, age and
        (optionally) dynamic real features are stacked into
        ``FieldName.FEAT_TIME``.
        """
        remove_field_names = []
        if self.num_feat_static_real == 0:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if self.num_feat_dynamic_real == 0:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

        return Chain(
            [RemoveFields(field_names=remove_field_names)]
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
                if not self.num_feat_static_cat > 0
                else []
            )
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
                if not self.num_feat_static_real > 0
                else []
            )
            + [
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_CAT,
                    expected_ndim=1,
                    dtype=int,
                ),
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_REAL,
                    expected_ndim=1,
                ),
                AsNumpyArray(
                    field=FieldName.TARGET,
                    # in the following line, we add 1 for the time dimension
                    expected_ndim=1 + len(self.distr_output.event_shape),
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                AddAgeFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_AGE,
                    pred_length=self.prediction_length,
                    log_scale=True,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                    + (
                        [FieldName.FEAT_DYNAMIC_REAL]
                        if self.num_feat_dynamic_real > 0
                        else []
                    ),
                ),
            ]
        )

    def _create_instance_splitter(
        self, module: NSTransformerLightningModule, mode: str
    ):
        """Return the splitter that cuts past/future windows for ``mode``.

        ``mode`` must be ``"training"``, ``"validation"`` or ``"test"``; it
        selects the instance sampler. The past window length is taken from
        the wrapped model.
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=module.model._past_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: NSTransformerLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        """Return an iterable yielding ``num_batches_per_epoch`` batches.

        The dataset is cycled (and pseudo-shuffled when
        ``shuffle_buffer_length`` is given) before training windows are
        sampled. Extra ``kwargs`` are passed to the ``DataLoader``.
        """
        transformation = self._create_instance_splitter(
            module, "training"
        ) + SelectFields(TRAINING_INPUT_NAMES)

        training_instances = transformation.apply(
            Cyclic(data)
            if shuffle_buffer_length is None
            else PseudoShuffled(
                Cyclic(data), shuffle_buffer_length=shuffle_buffer_length
            )
        )

        return IterableSlice(
            iter(
                DataLoader(
                    IterableDataset(training_instances),
                    batch_size=self.batch_size,
                    **kwargs,
                )
            ),
            self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: NSTransformerLightningModule,
        **kwargs,
    ) -> Iterable:
        """Return a ``DataLoader`` over the validation windows of ``data``."""
        transformation = self._create_instance_splitter(
            module, "validation"
        ) + SelectFields(TRAINING_INPUT_NAMES)

        validation_instances = transformation.apply(data)

        return DataLoader(
            IterableDataset(validation_instances),
            batch_size=self.batch_size,
            **kwargs,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module: NSTransformerLightningModule,
    ) -> PyTorchPredictor:
        """Wrap the trained model in a ``PyTorchPredictor``.

        The predictor runs on CUDA when available, otherwise on the CPU.
        """
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module.model,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )

    def create_lightning_module(self) -> NSTransformerLightningModule:
        """Instantiate the model and wrap it in its Lightning module."""
        model = NSTransformerModel(
            freq=self.freq,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            # One extra feature on top of the dynamic real and time features;
            # presumably the age feature stacked in create_transformation.
            num_feat_dynamic_real=1
            + self.num_feat_dynamic_real
            + len(self.time_features),
            # At least one static feature of each kind is always present,
            # because create_transformation inserts dummies when there is none.
            num_feat_static_real=max(1, self.num_feat_static_real),
            num_feat_static_cat=max(1, self.num_feat_static_cat),
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            # transformer arguments
            nhead=self.nhead,
            num_encoder_layers=self.num_encoder_layers,
            num_decoder_layers=self.num_decoder_layers,
            activation=self.activation,
            dropout=self.dropout,
            dim_feedforward=self.dim_feedforward,
            # univariate input
            input_size=self.input_size,
            distr_output=self.distr_output,
            lags_seq=self.lags_seq,
            num_parallel_samples=self.num_parallel_samples,
        )

        return NSTransformerLightningModule(model=model, loss=self.loss)
|
||||
@@ -0,0 +1,79 @@
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import weighted_average
|
||||
from module import NSTransformerModel
|
||||
|
||||
|
||||
class NSTransformerLightningModule(pl.LightningModule):
    """Lightning wrapper that trains an ``NSTransformerModel``.

    Parameters
    ----------
    model
        The network to optimise.
    loss
        Loss evaluated on the predicted distribution and the future target.
    lr
        Learning rate of the Adam optimiser.
    weight_decay
        Weight decay of the Adam optimiser.
    """

    def __init__(
        self,
        model: NSTransformerModel,
        loss: DistributionLoss = NegativeLogLikelihood(),
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model, self.loss = model, loss
        self.lr, self.weight_decay = lr, weight_decay

    def training_step(self, batch, batch_idx: int):
        """Execute training step"""
        loss_value = self(batch)
        log_options = dict(on_epoch=True, on_step=False, prog_bar=True)
        self.log("train_loss", loss_value, **log_options)
        return loss_value

    def validation_step(self, batch, batch_idx: int):
        """Execute validation step"""
        with torch.inference_mode():
            loss_value = self(batch)
        log_options = dict(on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_loss", loss_value, **log_options)
        return loss_value

    def configure_optimizers(self):
        """Returns the optimizer to use"""
        optimizer_options = dict(lr=self.lr, weight_decay=self.weight_decay)
        return torch.optim.Adam(self.model.parameters(), **optimizer_options)

    def forward(self, batch):
        """Return the observation-weighted average loss of one batch."""
        (
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            future_time_feat,
            future_target,
            past_observed_values,
            future_observed_values,
        ) = (
            batch[field]
            for field in (
                "feat_static_cat",
                "feat_static_real",
                "past_time_feat",
                "past_target",
                "future_time_feat",
                "future_target",
                "past_observed_values",
                "future_observed_values",
            )
        )

        network_inputs = self.model.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
            future_time_feat,
            future_target,
        )
        # Only the first three of the six returned values are needed here.
        transformer_inputs, loc, scale, _, _, _ = network_inputs

        distr = self.model.output_distribution(
            self.model.output_params(transformer_inputs), loc=loc, scale=scale
        )
        per_step_loss = self.loss(distr, future_target)

        if len(self.model.target_shape) != 0:
            # Non-scalar targets: reduce the indicator over the last axis
            # with ``min`` to obtain one weight per time step.
            step_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
        else:
            step_weights = future_observed_values

        return weighted_average(per_step_loss, weights=step_weights)
|
||||
@@ -0,0 +1,675 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
|
||||
|
||||
class StdScaler(nn.Module):
    """
    Standardises data along dimension ``dim``: the mean is subtracted and
    the result is divided by the standard deviation.

    Parameters
    ----------
    dim
        dimension along which to compute the mean and standard deviation
    keepdim
        controls whether to retain dimension ``dim`` (of length 1) in the
        returned location and scale tensors, or suppress it.
    minimum_scale
        small constant added to the variance before taking the square
        root, so that constant series do not produce a zero scale.
    """

    @validated()
    def __init__(
        self, dim: int, keepdim: bool = False, minimum_scale: float = 1e-10
    ):
        super().__init__()
        assert dim > 0, (
            "Cannot compute scale along dim = 0 (batch dimension), please"
            " provide dim > 0"
        )
        self.dim = dim
        self.keepdim = keepdim
        self.register_buffer("minimum_scale", torch.tensor(minimum_scale))

    def forward(
        self, data: torch.Tensor, weights: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Standardise ``data`` along ``self.dim``.

        Parameters
        ----------
        data
            tensor to standardise.
        weights
            accepted for interface compatibility; it is not used by the
            computation, so every entry of ``data`` contributes equally.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            the standardised data (same shape as ``data``), the mean and
            the standard deviation. Mean and standard deviation are
            detached from the graph and keep dimension ``dim`` only when
            ``keepdim`` is true.
        """
        # Statistics are always computed with the reduced dimension kept so
        # that they broadcast against ``data``; it is dropped only on return.
        loc = data.mean(self.dim, keepdim=True).detach()
        centered = data - loc
        scale = torch.sqrt(
            torch.var(centered, dim=self.dim, keepdim=True, unbiased=False)
            + self.minimum_scale
        ).detach()
        scaled_data = centered / scale

        if not self.keepdim:
            loc = loc.squeeze(dim=self.dim)
            scale = scale.squeeze(dim=self.dim)
        return scaled_data, loc, scale
|
||||
|
||||
|
||||
class Projector(nn.Module):
    """
    MLP to learn the De-stationary factors.

    A 1-D convolution first collapses the ``seq_len`` input channels into a
    single one; the result is concatenated with ``stats`` and fed through a
    ReLU MLP whose last layer has no bias.
    """

    def __init__(self, enc_in, seq_len, hidden_dims, hidden_layers, output_dim, kernel_size=3):
        super().__init__()

        self.series_conv = nn.Conv1d(
            in_channels=seq_len,
            out_channels=1,
            kernel_size=kernel_size,
            padding=1,
            padding_mode="circular",
            bias=False,
        )

        # Input layer, then ``hidden_layers - 1`` hidden layers, then the
        # bias-free output layer.
        stages = [nn.Linear(2 * enc_in, hidden_dims[0]), nn.ReLU()]
        for depth in range(hidden_layers - 1):
            stages.append(nn.Linear(hidden_dims[depth], hidden_dims[depth + 1]))
            stages.append(nn.ReLU())
        stages.append(nn.Linear(hidden_dims[-1], output_dim, bias=False))

        self.backbone = nn.Sequential(*stages)

    def forward(self, x, stats):
        """
        Map a window and its statistics to the learned factor.

        x:     B x S x E
        stats: B x 1 x E
        out:   B x O
        """
        summary = self.series_conv(x)  # B x 1 x E
        joined = torch.cat([summary, stats], dim=1)  # B x 2 x E
        flat = joined.view(x.shape[0], -1)  # B x 2E
        return self.backbone(flat)  # B x O
|
||||
|
||||
|
||||
class NSMultiheadAttention(nn.MultiheadAttention):
    """
    Multi-head attention with optional de-stationary factors.

    Behaves like :class:`torch.nn.MultiheadAttention` for batched inputs.
    When ``tau`` and/or ``delta`` are given, the pre-softmax attention
    scores are rescaled per batch element as
    ``softmax((Q K^T * tau + delta) / sqrt(head_dim)) V``.
    With both left as ``None`` the result equals plain attention.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        tau=None,
        delta=None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        r"""
        Note::
            Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for
            more information.

        Args:
            query, key, value: map a query and a set of key-value pairs to an
                output. Shapes are :math:`(L, N, E)`, :math:`(S, N, E)` and
                :math:`(S, N, E)`, or batch-first when ``batch_first`` is
                ``True`` (L: target length, S: source length, N: batch size,
                E: embedding dimension). Unbatched inputs are not supported.
            tau: optional tensor of shape :math:`(N, 1)`; multiplies the
                attention scores of each batch element.
            delta: optional tensor of shape :math:`(N, S)`; added to the
                attention scores of each batch element along the key axis.
            key_padding_mask: :math:`(N, S)` bool (or deprecated byte) mask;
                positions that are ``True`` / non-zero are ignored.
            need_weights: whether to return the attention weights.
            attn_mask: 2D :math:`(L, S)` or 3D :math:`(N*num_heads, L, S)`
                mask. Bool/byte masks mark positions that may not be
                attended; float masks are added to the scores.
            average_attn_weights: average the returned weights over heads.
                Only has an effect when ``need_weights`` is ``True``.

        Returns:
            ``attn_output`` of the same layout as ``query`` and, when
            ``need_weights`` is set, the attention weights of shape
            :math:`(N, L, S)` (averaged) or :math:`(N, num_heads, L, S)`;
            otherwise ``None``.
        """
        return self._forward_impl(
            query,
            key,
            value,
            tau=tau,
            delta=delta,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
        )

    def _forward_impl(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        tau=None,
        delta=None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Explicit attention computation behind :meth:`forward`."""
        # Imported locally so the block does not depend on module-level
        # imports that the file does not have.
        import warnings

        import torch.nn.functional as F

        if self.batch_first:
            query, key, value = (x.transpose(0, 1) for x in (query, key, value))

        tgt_len, bsz, embed_dim_to_check = query.size()
        assert self.embed_dim == embed_dim_to_check
        # allow MHA to have different sizes for the feature dimension
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = self.embed_dim // self.num_heads
        assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        scaling = float(head_dim) ** -0.5

        # Input projections, using the parameters nn.MultiheadAttention owns
        # (a packed weight when query/key/value share the embedding size).
        if self._qkv_same_embed_dim:
            w_q, w_k, w_v = self.in_proj_weight.chunk(3, dim=0)
        else:
            w_q, w_k, w_v = self.q_proj_weight, self.k_proj_weight, self.v_proj_weight
        if self.in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = self.in_proj_bias.chunk(3, dim=0)

        q = F.linear(query, w_q, b_q) * scaling
        k = F.linear(key, w_k, b_k)
        v = F.linear(value, w_v, b_v)

        if attn_mask is not None:
            assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
                attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
                'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
            if attn_mask.dtype == torch.uint8:
                warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
                attn_mask = attn_mask.to(torch.bool)

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 2D attn_mask is not correct.')
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 3D attn_mask is not correct.')
            else:
                raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
            # attn_mask's dim is 3 now.

        # convert ByteTensor key_padding_mask to bool
        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            key_padding_mask = key_padding_mask.to(torch.bool)

        # Number of key positions appended below (bias key, zero key);
        # ``delta`` is zero-padded by the same amount.
        extra_keys = 0

        if self.bias_k is not None and self.bias_v is not None:
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            extra_keys += 1
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))
        else:
            assert self.bias_k is None
            assert self.bias_v is None

        # Split heads: (len, N, E) -> (N * num_heads, len, head_dim).
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            extra_keys += 1
            # new_zeros keeps the device and dtype of the projected tensors.
            k = torch.cat([k, k.new_zeros((k.size(0), 1, head_dim))], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1, head_dim))], dim=1)
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        # De-stationary rescaling of the scores, applied before masking and
        # softmax. ``q`` is already multiplied by ``scaling``, so ``delta``
        # is scaled here to obtain (Q K^T * tau + delta) * scaling.
        if tau is not None or delta is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if tau is not None:
                # B x 1 -> B x 1 x 1 x 1
                attn_output_weights = attn_output_weights * tau.unsqueeze(1).unsqueeze(1)
            if delta is not None:
                if extra_keys:
                    delta = F.pad(delta, (0, extra_keys))
                # B x S -> B x 1 x 1 x S
                attn_output_weights = attn_output_weights + scaling * delta.unsqueeze(1).unsqueeze(1)
            attn_output_weights = attn_output_weights.reshape(bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights = attn_output_weights.masked_fill(attn_mask, float('-inf'))
            else:
                attn_output_weights = attn_output_weights + attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_output_weights = F.softmax(attn_output_weights, dim=-1)
        attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]

        # Merge heads back: (N * num_heads, L, head_dim) -> (L, N, E). The
        # transpose must come first, otherwise heads and time steps are mixed.
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        if self.batch_first:
            attn_output = attn_output.transpose(0, 1)

        if not need_weights:
            return attn_output, None

        attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
        if average_attn_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.mean(dim=1)
        return attn_output, attn_output_weights
|
||||
|
||||
|
||||
|
||||
class NSTransformerModel(nn.Module):
|
||||
@validated()
def __init__(
    self,
    context_length: int,
    prediction_length: int,
    num_feat_dynamic_real: int,
    num_feat_static_real: int,
    num_feat_static_cat: int,
    cardinality: List[int],
    # transformer arguments
    nhead: int,
    num_encoder_layers: int,
    num_decoder_layers: int,
    dim_feedforward: int,
    activation: str = "gelu",
    dropout: float = 0.1,
    # univariate input
    input_size: int = 1,
    embedding_dimension: Optional[List[int]] = None,
    distr_output: DistributionOutput = StudentTOutput(),
    lags_seq: Optional[List[int]] = None,
    freq: Optional[str] = None,
    num_parallel_samples: int = 100,
) -> None:
    """Build the embedder, scaler, factor learners and transformer.

    ``embedding_dimension`` defaults to ``min(50, (cat + 1) // 2)`` per
    entry of ``cardinality``; ``lags_seq`` defaults to the lags derived
    from ``freq``.
    """
    super().__init__()

    self.input_size = input_size
    self.target_shape = distr_output.event_shape

    self.num_feat_dynamic_real = num_feat_dynamic_real
    self.num_feat_static_cat = num_feat_static_cat
    self.num_feat_static_real = num_feat_static_real

    if embedding_dimension is None and cardinality is not None:
        self.embedding_dimension = [min(50, (cat + 1) // 2) for cat in cardinality]
    else:
        self.embedding_dimension = embedding_dimension

    self.lags_seq = lags_seq if lags_seq else get_lags_for_frequency(freq_str=freq)
    self.num_parallel_samples = num_parallel_samples
    self.history_length = context_length + max(self.lags_seq)

    self.embedder = FeatureEmbedder(
        cardinalities=cardinality,
        embedding_dims=self.embedding_dimension,
    )
    self.scaler = StdScaler(dim=1, keepdim=True)

    # total feature size
    d_model = self.input_size * len(self.lags_seq) + self._number_of_features

    # Both factor learners share everything except their output width.
    learner_options = dict(
        enc_in=input_size,
        seq_len=context_length,
        hidden_dims=[64, 64],
        hidden_layers=2,
    )
    self.tau_learner = Projector(output_dim=1, **learner_options)
    self.delta_learner = Projector(output_dim=context_length, **learner_options)

    self.context_length = context_length
    self.prediction_length = prediction_length
    self.distr_output = distr_output
    self.param_proj = distr_output.get_args_proj(d_model)

    # transformer enc-decoder and mask initializer
    self.transformer = nn.Transformer(
        d_model=d_model,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        activation=activation,
        batch_first=True,
        norm_first=True,
    )

    # causal decoder tgt mask
    causal_mask = self.transformer.generate_square_subsequent_mask(prediction_length)
    self.register_buffer("tgt_mask", causal_mask)
|
||||
|
||||
@property
def _number_of_features(self) -> int:
    """Width of the feature vector: embeddings, real features and scale stats."""
    embedded = sum(self.embedding_dimension)
    real_valued = self.num_feat_dynamic_real + self.num_feat_static_real
    scale_stats = self.input_size * 2  # the log(scale) and log(loc)
    return embedded + real_valued + scale_stats
|
||||
|
||||
@property
def _past_length(self) -> int:
    """Context length extended by the largest lag."""
    largest_lag = max(self.lags_seq)
    return largest_lag + self.context_length
|
||||
|
||||
def get_lagged_subsequences(
    self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
) -> torch.Tensor:
    """
    Extract one window of ``sequence`` per configured lag.

    Parameters
    ----------
    sequence : Tensor
        the sequence to take the windows from; time is dimension 1.
        Shape: (N, T, C).
    subsequences_length : int
        length S of every extracted window.
    shift: int
        amount subtracted from every lag in ``self.lags_seq``.

    Returns
    --------
    lagged : Tensor
        the windows stacked along a new last dimension, shape
        (N, S, C, I) with I = len(self.lags_seq). Specifically,
        lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
    """
    history = sequence.shape[1]
    lags = [lag - shift for lag in self.lags_seq]
    furthest = max(lags)

    assert furthest + subsequences_length <= history, (
        f"lags cannot go further than history length, found lag {furthest} "
        f"while history length is only {history}"
    )

    # A lag of zero ends at the end of the sequence, hence the open slice.
    windows = [
        sequence[:, -lag - subsequences_length : (-lag if lag > 0 else None), ...]
        for lag in lags
    ]
    return torch.stack(windows, dim=-1)
|
||||
|
||||
def _check_shapes(
    self,
    prior_input: torch.Tensor,
    inputs: torch.Tensor,
    features: Optional[torch.Tensor],
) -> None:
    """Assert that the given tensors agree with the configured sizes.

    ``prior_input`` and ``inputs`` must have the same number of dimensions
    and either be 2-D with ``self.input_size == 1`` or carry
    ``self.input_size`` entries in their target dimension. ``features``,
    when given, must have ``self._number_of_features`` entries in
    dimension 2.

    Raises
    ------
    AssertionError
        if any of the checks fails.
    """
    prior_rank = len(prior_input.shape)
    inputs_rank = len(inputs.shape)
    assert prior_rank == inputs_rank

    prior_ok = (prior_rank == 2 and self.input_size == 1) or (
        prior_input.shape[2] == self.input_size
    )
    assert prior_ok

    inputs_ok = (inputs_rank == 2 and self.input_size == 1) or (
        inputs.shape[-1] == self.input_size
    )
    assert inputs_ok

    if features is not None:
        assert (
            features.shape[2] == self._number_of_features
        ), f"{features.shape[2]}, expected {self._number_of_features}"
|
||||
|
||||
def create_network_inputs(
    self,
    feat_static_cat: torch.Tensor,
    feat_static_real: torch.Tensor,
    past_time_feat: torch.Tensor,
    past_target: torch.Tensor,
    past_observed_values: torch.Tensor,
    future_time_feat: Optional[torch.Tensor] = None,
    future_target: Optional[torch.Tensor] = None,
):
    """Build the transformer input tensor from raw targets and features.

    The context window of ``past_target`` is passed to ``self.scaler`` to
    obtain ``loc`` and ``scale``; the (optionally future-extended) target is
    normalised with them, lagged via ``get_lagged_subsequences`` and
    concatenated with static and time features.

    Parameters
    ----------
    feat_static_cat
        categorical static features, passed to ``self.embedder``.
    feat_static_real
        real-valued static features, concatenated to the embeddings.
    past_time_feat
        time features of the past window; only the last
        ``self.context_length`` steps are used.
    past_target
        past target values.
    past_observed_values
        indicator passed to the scaler together with the context.
    future_time_feat
        time features of the prediction window; used only when
        ``future_target`` is given.
    future_target
        future target values; when given, the inputs cover context plus
        prediction range, otherwise the context only.

    Returns
    -------
    tuple
        ``(transformer_inputs, loc, scale, static_feat, tau, delta)``.
    """
    has_future = future_target is not None
    univariate = self.input_size == 1

    # time feature
    context_time_feat = past_time_feat[
        :, self._past_length - self.context_length :, ...
    ]
    time_feat = (
        torch.cat((context_time_feat, future_time_feat), dim=1)
        if has_future
        else context_time_feat
    )

    # target
    context = past_target[:, -self.context_length :]
    observed_context = past_observed_values[:, -self.context_length :]
    _, loc, scale = self.scaler(context, observed_context)

    learner_context = context.unsqueeze(-1) if univariate else context

    # B x S x E, B x 1 x E -> B x 1, positive scalar
    tau = self.tau_learner(
        learner_context,
        scale.unsqueeze(1) if univariate else scale,
    ).exp()

    # B x S x E, B x 1 x E -> B x S
    delta = self.delta_learner(
        learner_context,
        loc.unsqueeze(1) if univariate else loc,
    )

    target = (
        torch.cat((past_target, future_target), dim=1)
        if has_future
        else past_target
    )
    inputs = (target - loc) / scale

    future_steps = self.prediction_length if has_future else 0
    assert inputs.shape[1] == self._past_length + future_steps
    subsequences_length = self.context_length + future_steps

    # embeddings
    embedded_cat = self.embedder(feat_static_cat)
    # NOTE(review): log1p returns NaN for values below -1; confirm that the
    # scaler never produces such a loc, or switch to a signed transform.
    log_scale = scale.log1p() if univariate else scale.squeeze(1).log1p()
    # The multivariate branch previously read ``loc.scale``, which does not
    # exist on a tensor; it must squeeze ``loc`` itself.
    log_loc = loc.log1p() if univariate else loc.squeeze(1).log1p()

    static_feat = torch.cat(
        (embedded_cat, feat_static_real, log_scale, log_loc),
        dim=1,
    )
    expanded_static_feat = static_feat.unsqueeze(1).expand(
        -1, time_feat.shape[1], -1
    )

    features = torch.cat((expanded_static_feat, time_feat), dim=-1)

    lagged_sequence = self.get_lagged_subsequences(
        sequence=inputs,
        subsequences_length=subsequences_length,
    )

    # Flatten the trailing (target, lag) dimensions into one feature axis.
    lags_shape = lagged_sequence.shape
    reshaped_lagged_sequence = lagged_sequence.reshape(
        lags_shape[0], lags_shape[1], -1
    )

    transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)

    return transformer_inputs, loc, scale, static_feat, tau, delta
|
||||
|
||||
def output_params(self, transformer_inputs):
    """Run the encoder/decoder and project to distribution parameters.

    The first ``self.context_length`` steps of ``transformer_inputs`` go
    through the encoder; the remaining steps are decoded against the encoder
    output using ``self.tgt_mask``. The decoder output is passed through
    ``self.param_proj`` and returned.
    """
    split_at = self.context_length
    encoder_part = transformer_inputs[:, :split_at, ...]
    decoder_part = transformer_inputs[:, split_at:, ...]

    memory = self.transformer.encoder(encoder_part)
    decoded = self.transformer.decoder(
        decoder_part, memory, tgt_mask=self.tgt_mask
    )
    return self.param_proj(decoded)
|
||||
|
||||
@torch.jit.ignore
def output_distribution(
    self, params, loc=None, scale=None, trailing_n=None
) -> torch.distributions.Distribution:
    """Build the output distribution from projected parameters.

    When ``trailing_n`` is given, every parameter tensor is cut down to its
    last ``trailing_n`` entries along dimension 1 first. ``loc`` and
    ``scale`` are forwarded unchanged to ``self.distr_output.distribution``.
    """
    if trailing_n is None:
        selected = params
    else:
        selected = [param[:, -trailing_n:] for param in params]
    return self.distr_output.distribution(selected, loc=loc, scale=scale)
|
||||
|
||||
# for prediction
|
||||
def forward(
    self,
    feat_static_cat: torch.Tensor,
    feat_static_real: torch.Tensor,
    past_time_feat: torch.Tensor,
    past_target: torch.Tensor,
    past_observed_values: torch.Tensor,
    future_time_feat: torch.Tensor,
    num_parallel_samples: Optional[int] = None,
) -> torch.Tensor:
    """Draw sample paths for the prediction range, one step at a time.

    The context is encoded once. Every batch entry is then repeated
    ``num_parallel_samples`` times, and for each of the
    ``self.prediction_length`` steps the decoder is run on the lagged
    (normalised) targets plus features, a value is sampled from the
    resulting distribution, and its normalised form is appended to the
    running target so the next step can use it as a lag.

    Parameters
    ----------
    feat_static_cat, feat_static_real, past_time_feat, past_target,
    past_observed_values
        forwarded to ``create_network_inputs``.
    future_time_feat
        time features of the prediction range.
    num_parallel_samples
        number of sample paths per batch entry; defaults to
        ``self.num_parallel_samples`` when ``None``.

    Returns
    -------
    Tensor
        samples reshaped to
        ``(-1, num_parallel_samples, prediction_length) + target_shape``.
    """
    if num_parallel_samples is None:
        num_parallel_samples = self.num_parallel_samples

    # tau and delta are produced by create_network_inputs but are not
    # consumed anywhere in this sampling path.
    encoder_inputs, loc, scale, static_feat, _tau, _delta = (
        self.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
        )
    )

    enc_out = self.transformer.encoder(encoder_inputs)

    # Use the resolved local value everywhere below so that an explicitly
    # passed num_parallel_samples is actually honoured.
    repeated_loc = loc.repeat_interleave(repeats=num_parallel_samples, dim=0)
    repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)

    repeated_past_target = (
        past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
        - repeated_loc
    ) / repeated_scale

    expanded_static_feat = static_feat.unsqueeze(1).expand(
        -1, future_time_feat.shape[1], -1
    )
    features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
    repeated_features = features.repeat_interleave(
        repeats=num_parallel_samples, dim=0
    )

    repeated_enc_out = enc_out.repeat_interleave(
        repeats=num_parallel_samples, dim=0
    )

    future_samples = []

    # autoregressive decoding: sample step k, then feed it back as a lag
    for k in range(self.prediction_length):
        lagged_sequence = self.get_lagged_subsequences(
            sequence=repeated_past_target,
            subsequences_length=1 + k,
            shift=1,
        )

        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        decoder_input = torch.cat(
            (reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1
        )

        output = self.transformer.decoder(decoder_input, repeated_enc_out)

        # Only the newest position is needed to sample the next value.
        params = self.param_proj(output[:, -1:])
        distr = self.output_distribution(
            params, loc=repeated_loc, scale=repeated_scale
        )
        next_sample = distr.sample()

        # Samples are drawn with loc/scale applied; normalise before
        # appending so the running target stays in the scaled domain.
        repeated_past_target = torch.cat(
            (repeated_past_target, (next_sample - repeated_loc) / repeated_scale),
            dim=1,
        )
        future_samples.append(next_sample)

    concat_future_samples = torch.cat(future_samples, dim=1)
    return concat_future_samples.reshape(
        (-1, num_parallel_samples, self.prediction_length) + self.target_shape,
    )
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user