initial perceiver-AR time series model

This commit is contained in:
Kashif Rasul
2022-08-07 19:06:32 -04:00
parent 2a27ac7794
commit 64c8a1c4a8
6 changed files with 1421 additions and 0 deletions
View File
+22
View File
@@ -0,0 +1,22 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from .module import PerceiverARModel
from .lightning_module import PerceiverARLightningModule
from .estimator import PerceiverAREstimator
__all__ = [
"PerceiverARModel",
"PerceiverARLightningModule",
"PerceiverAREstimator",
]
+384
View File
@@ -0,0 +1,384 @@
from typing import List, Optional, Iterable, Dict, Any
import torch
from torch.utils.data import DataLoader
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.itertools import Cyclic, PseudoShuffled, IterableSlice
from gluonts.time_feature import (
TimeFeature,
time_features_from_frequency_str,
)
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.transform import (
Transformation,
Chain,
RemoveFields,
SetField,
AsNumpyArray,
AddObservedValuesIndicator,
AddTimeFeatures,
AddAgeFeature,
VstackFeatures,
InstanceSplitter,
ValidationSplitSampler,
TestSplitSampler,
ExpectedNumInstanceSampler,
SelectFields,
)
from gluonts.torch.util import (
IterableDataset,
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import (
DistributionOutput,
StudentTOutput,
)
from gluonts.transform.sampler import InstanceSampler
from module import PerceiverARModel
from lightning_module import PerceiverARLightningModule
PREDICTION_INPUT_NAMES = [
"feat_static_cat",
"feat_static_real",
"past_time_feat",
"past_target",
"past_observed_values",
"future_time_feat",
]
TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
"future_target",
"future_observed_values",
]
class PerceiverAREstimator(PyTorchLightningEstimator):
"""
Estimator class to train a PerceiverAR model.
This class is uses the model defined in ``PerceiverARModel``, and wraps it
into a ``PerceiverARLightningModule`` for training purposes: training is
performed using PyTorch Lightning's ``pl.Trainer`` class.
Parameters
----------
freq
Frequency of the data to train on and predict.
prediction_length
Length of the prediction horizon.
context_length
Number of steps to unroll the RNN for before computing predictions
(default: None, in which case context_length = prediction_length).
perceive_depth
Number of RNN layers (default: 2).
hidden_size
Number of RNN cells for each layer (default: 40).
dropout_rate
Dropout regularization parameter (default: 0.1).
num_feat_dynamic_real
Number of dynamic real features in the data (default: 0).
num_feat_static_real
Number of static real features in the data (default: 0).
num_feat_static_cat
Number of static categorical features in the data (default: 0).
cardinality
Number of values of each categorical feature.
This must be set if ``num_feat_static_cat > 0`` (default: None).
embedding_dimension
Dimension of the embeddings for categorical features
(default: ``[min(50, (cat+1)//2) for cat in cardinality]``).
distr_output
Distribution to use to evaluate observations and sample predictions
(default: StudentTOutput()).
loss
Loss to be optimized during training
(default: ``NegativeLogLikelihood()``).
scaling
Whether to automatically scale the target values (default: true).
lags_seq
Indices of the lagged target values to use as inputs of the RNN
(default: None, in which case these are automatically determined
based on freq).
time_features
List of time features, from :py:mod:`gluonts.time_feature`, to use as
inputs of the RNN in addition to the provided data (default: None,
in which case these are automatically determined based on freq).
num_parallel_samples
Number of samples per time series to that the resulting predictor
should produce (default: 100).
batch_size
The size of the batches to be used for training (default: 32).
num_batches_per_epoch
Number of batches to be processed in each training epoch
(default: 50).
trainer_kwargs
Additional arguments to provide to ``pl.Trainer`` for construction.
train_sampler
Controls the sampling of windows during training.
validation_sampler
Controls the sampling of windows during validation.
"""
@validated()
def __init__(
self,
freq: str,
prediction_length: int,
depth: int,
context_length: Optional[int] = None,
perceive_depth: int = 1,
heads: int = 2,
hidden_size: int = 32,
dropout_rate: float = 0.1,
cross_attn_dropout: float = 0.1,
perceive_max_heads_process: int = 2,
ff_mult: int = 1,
num_feat_dynamic_real: int = 0,
num_feat_static_cat: int = 0,
num_feat_static_real: int = 0,
cardinality: Optional[List[int]] = None,
embedding_dimension: Optional[List[int]] = None,
distr_output: DistributionOutput = StudentTOutput(),
loss: DistributionLoss = NegativeLogLikelihood(),
scaling: bool = True,
lags_seq: Optional[List[int]] = None,
time_features: Optional[List[TimeFeature]] = None,
num_parallel_samples: int = 100,
batch_size: int = 32,
num_batches_per_epoch: int = 50,
trainer_kwargs: Optional[Dict[str, Any]] = None,
train_sampler: Optional[InstanceSampler] = None,
validation_sampler: Optional[InstanceSampler] = None,
) -> None:
default_trainer_kwargs = {
"max_epochs": 100,
"gradient_clip_val": 10.0,
}
if trainer_kwargs is not None:
default_trainer_kwargs.update(trainer_kwargs)
super().__init__(trainer_kwargs=default_trainer_kwargs)
self.freq = freq
self.context_length = (
context_length if context_length is not None else prediction_length
)
self.prediction_length = prediction_length
self.distr_output = distr_output
self.loss = loss
self.depth = depth
self.perceive_depth = perceive_depth
self.hidden_size = hidden_size
self.dropout_rate = dropout_rate
self.heads = heads
self.perceive_max_heads_process = perceive_max_heads_process
self.ff_mult = ff_mult
self.cross_attn_dropout = cross_attn_dropout
self.num_feat_dynamic_real = num_feat_dynamic_real
self.num_feat_static_cat = num_feat_static_cat
self.num_feat_static_real = num_feat_static_real
self.cardinality = (
cardinality if cardinality and num_feat_static_cat > 0 else [1]
)
self.embedding_dimension = embedding_dimension
self.scaling = scaling
self.lags_seq = lags_seq
self.time_features = (
time_features
if time_features is not None
else time_features_from_frequency_str(self.freq)
)
self.num_parallel_samples = num_parallel_samples
self.batch_size = batch_size
self.num_batches_per_epoch = num_batches_per_epoch
self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
num_instances=1.0, min_future=prediction_length
)
self.validation_sampler = validation_sampler or ValidationSplitSampler(
min_future=prediction_length
)
def create_transformation(self) -> Transformation:
remove_field_names = []
if self.num_feat_static_real == 0:
remove_field_names.append(FieldName.FEAT_STATIC_REAL)
if self.num_feat_dynamic_real == 0:
remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
return Chain(
[RemoveFields(field_names=remove_field_names)]
+ (
[SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
if not self.num_feat_static_cat > 0
else []
)
+ (
[SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
if not self.num_feat_static_real > 0
else []
)
+ [
AsNumpyArray(
field=FieldName.FEAT_STATIC_CAT,
expected_ndim=1,
dtype=int,
),
AsNumpyArray(
field=FieldName.FEAT_STATIC_REAL,
expected_ndim=1,
),
AsNumpyArray(
field=FieldName.TARGET,
# in the following line, we add 1 for the time dimension
expected_ndim=1 + len(self.distr_output.event_shape),
),
AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
),
AddTimeFeatures(
start_field=FieldName.START,
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_TIME,
time_features=self.time_features,
pred_length=self.prediction_length,
),
AddAgeFeature(
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_AGE,
pred_length=self.prediction_length,
log_scale=True,
),
VstackFeatures(
output_field=FieldName.FEAT_TIME,
input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
+ (
[FieldName.FEAT_DYNAMIC_REAL]
if self.num_feat_dynamic_real > 0
else []
),
),
]
)
def _create_instance_splitter(self, module: PerceiverARLightningModule, mode: str):
assert mode in ["training", "validation", "test"]
instance_sampler = {
"training": self.train_sampler,
"validation": self.validation_sampler,
"test": TestSplitSampler(),
}[mode]
return InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
instance_sampler=instance_sampler,
past_length=module.model._past_length,
future_length=self.prediction_length,
time_series_fields=[
FieldName.FEAT_TIME,
FieldName.OBSERVED_VALUES,
],
dummy_value=self.distr_output.value_in_support,
)
def create_training_data_loader(
self,
data: Dataset,
module: PerceiverARLightningModule,
shuffle_buffer_length: Optional[int] = None,
**kwargs,
) -> Iterable:
transformation = self._create_instance_splitter(
module, "training"
) + SelectFields(TRAINING_INPUT_NAMES)
training_instances = transformation.apply(
Cyclic(data)
if shuffle_buffer_length is None
else PseudoShuffled(
Cyclic(data), shuffle_buffer_length=shuffle_buffer_length
)
)
return IterableSlice(
iter(
DataLoader(
IterableDataset(training_instances),
batch_size=self.batch_size,
**kwargs,
)
),
self.num_batches_per_epoch,
)
def create_validation_data_loader(
self,
data: Dataset,
module: PerceiverARLightningModule,
**kwargs,
) -> Iterable:
transformation = self._create_instance_splitter(
module, "validation"
) + SelectFields(TRAINING_INPUT_NAMES)
validation_instances = transformation.apply(data)
return DataLoader(
IterableDataset(validation_instances),
batch_size=self.batch_size,
**kwargs,
)
def create_lightning_module(self) -> PerceiverARLightningModule:
model = PerceiverARModel(
freq=self.freq,
depth=self.depth,
context_length=self.context_length,
prediction_length=self.prediction_length,
num_feat_dynamic_real=(
1 + self.num_feat_dynamic_real + len(self.time_features)
),
num_feat_static_real=max(1, self.num_feat_static_real),
num_feat_static_cat=max(1, self.num_feat_static_cat),
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
perceive_depth=self.perceive_depth,
heads=self.heads,
perceive_max_heads_process=self.perceive_max_heads_process,
ff_mult=self.ff_mult,
cross_attn_dropout=self.cross_attn_dropout,
hidden_size=self.hidden_size,
distr_output=self.distr_output,
dropout_rate=self.dropout_rate,
lags_seq=self.lags_seq,
scaling=self.scaling,
num_parallel_samples=self.num_parallel_samples,
)
return PerceiverARLightningModule(model=model, loss=self.loss)
def create_predictor(
self,
transformation: Transformation,
module: PerceiverARLightningModule,
) -> PyTorchPredictor:
prediction_splitter = self._create_instance_splitter(module, "test")
return PyTorchPredictor(
input_transform=transformation + prediction_splitter,
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module.model,
batch_size=self.batch_size,
prediction_length=self.prediction_length,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
+113
View File
@@ -0,0 +1,113 @@
import pytorch_lightning as pl
import torch
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import weighted_average
from module import PerceiverARModel
class PerceiverARLightningModule(pl.LightningModule):
"""
A ``pl.LightningModule`` class that can be used to train a
``PerceiverARModel`` with PyTorch Lightning.
This is a thin layer around a (wrapped) ``PerceiverARModel`` object,
that exposes the methods to evaluate training and validation loss.
Parameters
----------
model
``PerceiverARModel`` to be trained.
loss
Loss function to be used for training,
default: ``NegativeLogLikelihood()``.
lr
Learning rate, default: ``1e-3``.
weight_decay
Weight decay regularization parameter, default: ``1e-8``.
"""
def __init__(
self,
model: PerceiverARModel,
loss: DistributionLoss = NegativeLogLikelihood(),
lr: float = 1e-3,
weight_decay: float = 1e-8,
) -> None:
super().__init__()
self.save_hyperparameters()
self.model = model
self.loss = loss
self.lr = lr
self.weight_decay = weight_decay
def _compute_loss(self, batch):
feat_static_cat = batch["feat_static_cat"]
feat_static_real = batch["feat_static_real"]
past_time_feat = batch["past_time_feat"]
past_target = batch["past_target"]
future_time_feat = batch["future_time_feat"]
future_target = batch["future_target"]
past_observed_values = batch["past_observed_values"]
future_observed_values = batch["future_observed_values"]
params, scale, _, _, _ = self.model.lagged_perciever(
feat_static_cat,
feat_static_real,
past_time_feat,
past_target,
past_observed_values,
future_time_feat,
future_target,
)
distr = self.model.output_distribution(params, scale)
# context_target = past_target[:, -self.model.context_length + 1 :]
# target = torch.cat(
# (context_target, future_target),
# dim=1,
# )
loss_values = self.loss(distr, future_target)
# context_observed = past_observed_values[:, -self.model.context_length + 1 :]
# observed_values = torch.cat((context_observed, future_observed_values), dim=1)
if len(self.model.target_shape) == 0:
loss_weights = future_observed_values
else:
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
return weighted_average(loss_values, weights=loss_weights)
def training_step(self, batch, batch_idx: int): # type: ignore
"""
Execute training step.
"""
train_loss = self._compute_loss(batch)
self.log(
"train_loss",
train_loss,
on_epoch=True,
on_step=False,
prog_bar=True,
)
return train_loss
def validation_step(self, batch, batch_idx: int): # type: ignore
"""
Execute validation step.
"""
val_loss = self._compute_loss(batch)
self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
return val_loss
def configure_optimizers(self):
"""
Returns the optimizer to use.
"""
return torch.optim.Adam(
self.model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)
+568
View File
@@ -0,0 +1,568 @@
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from torch import einsum
from einops import rearrange, repeat
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.distributions import (
DistributionOutput,
StudentTOutput,
)
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.util import lagged_sequence_values
# helper functions
def exists(val):
return val is not None
# feedforward
def FeedForward(dim, mult=4, dropout=0.0):
hidden_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim, bias=False),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim, bias=False),
)
# attention
class CausalAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8, dropout=0.0):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
inner_dim = heads * dim_head
self.norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x):
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
)
q = q * self.scale
sim = einsum("b h i d, b h j d -> b h i j", q, k)
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), device=x.device, dtype=torch.bool).triu(
j - i + 1
)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim=-1)
attn = self.dropout(attn)
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class CausalPrefixAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_head=64,
heads=8,
max_heads_process=2,
dropout=0.0,
cross_attn_dropout=0.0
):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
self.max_heads_process = max_heads_process
inner_dim = heads * dim_head
self.norm = nn.LayerNorm(dim)
self.context_norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.cross_attn_dropout = cross_attn_dropout # they drop out a percentage of the prefix during training, shown to help prevent overfitting
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim)
def forward(self, x, context, context_mask=None):
batch, context_len, device = x.shape[0], context.shape[-2], x.device
# take care of cross attention dropout
if self.training and self.cross_attn_dropout > 0.0:
rand = torch.zeros((batch, context_len), device=device).uniform_()
keep_context_len = context_len - int(context_len * self.cross_attn_dropout)
keep_indices = rand.topk(keep_context_len, dim=-1).indices
keep_mask = torch.zeros_like(rand).scatter_(1, keep_indices, 1).bool()
context = rearrange(context[keep_mask], "(b n) d -> b n d", b=batch)
if exists(context_mask):
context_mask = rearrange(
context_mask[keep_mask], "(b n) -> b n", b=batch
)
# normalization
x = self.norm(x)
context = self.context_norm(context)
# derive queries, keys, values
q = self.to_q(x)
k_input, v_input = self.to_kv(x).chunk(2, dim=-1)
k_context, v_context = self.to_kv(context).chunk(2, dim=-1)
k = torch.cat((k_context, k_input), dim=1)
v = torch.cat((v_context, v_input), dim=1)
q, k, v = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
)
q = q * self.scale
# take care of masking
i, j = q.shape[-2], k.shape[-2]
mask_value = -torch.finfo(q.dtype).max
if exists(context_mask):
mask_len = context_mask.shape[-1]
context_mask = F.pad(context_mask, (0, max(j - mask_len, 0)), value=True)
context_mask = rearrange(context_mask, "b j -> b 1 1 j")
causal_mask = torch.ones((i, j), device=x.device, dtype=torch.bool).triu(
j - i + 1
)
# process in chunks of heads
out = []
max_heads = self.max_heads_process
for q_chunk, k_chunk, v_chunk in zip(
q.split(max_heads, dim=1),
k.split(max_heads, dim=1),
v.split(max_heads, dim=1),
):
sim = einsum("b h i d, b h j d -> b h i j", q_chunk, k_chunk)
if exists(context_mask):
sim = sim.masked_fill(~context_mask, mask_value)
sim = sim.masked_fill(causal_mask, mask_value)
attn = sim.softmax(dim=-1)
attn = self.dropout(attn)
out_chunk = einsum("b h i j, b h j d -> b h i d", attn, v_chunk)
out.append(out_chunk)
# concat all the heads together
out = torch.cat(out, dim=1)
# merge heads and then combine with linear
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class PerceiverARModel(nn.Module):
"""
Module implementing the PerceiverAR model.
Parameters
----------
freq
String indicating the sampling frequency of the data to be processed.
context_length
Length of the RNN unrolling prior to the forecast date.
prediction_length
Number of time points to predict.
num_feat_dynamic_real
Number of dynamic real features that will be provided to ``forward``.
num_feat_static_real
Number of static real features that will be provided to ``forward``.
num_feat_static_cat
Number of static categorical features that will be provided to
``forward``.
cardinality
List of cardinalities, one for each static categorical feature.
embedding_dimension
Dimension of the embedding space, one for each static categorical
feature.
num_layers
Number of layers in the RNN.
hidden_size
Size of the hidden layers in the RNN.
dropout_rate
Dropout rate to be applied at training time.
distr_output
Type of distribution to be output by the model at each time step
lags_seq
Indices of the lagged observations that the RNN takes as input. For
example, ``[1]`` indicates that the RNN only takes the observation at
time ``t-1`` to produce the output for time ``t``; instead,
``[1, 25]`` indicates that the RNN takes observations at times ``t-1``
and ``t-25`` as input.
scaling
Whether to apply mean scaling to the observations (target).
num_parallel_samples
Number of samples to produce when unrolling the RNN in the prediction
time range.
"""
@validated()
def __init__(
self,
freq: str,
depth: int,
context_length: int,
prediction_length: int,
num_feat_dynamic_real: int,
num_feat_static_real: int,
num_feat_static_cat: int,
cardinality: List[int],
embedding_dimension: Optional[List[int]] = None,
perceive_depth: int = 1,
heads: int = 2,
perceive_max_heads_process: int = 2,
ff_mult: int = 1,
hidden_size: int = 32,
dropout_rate: float = 0.1,
cross_attn_dropout: float = 0.1,
distr_output: DistributionOutput = StudentTOutput(),
lags_seq: Optional[List[int]] = None,
scaling: bool = True,
num_parallel_samples: int = 100,
) -> None:
super().__init__()
self.context_length = context_length
self.prediction_length = prediction_length
self.distr_output = distr_output
self.target_shape = distr_output.event_shape
self.num_feat_dynamic_real = num_feat_dynamic_real
self.num_feat_static_cat = num_feat_static_cat
self.num_feat_static_real = num_feat_static_real
self.embedding_dimension = (
embedding_dimension
if embedding_dimension is not None or cardinality is None
else [min(50, (cat + 1) // 2) for cat in cardinality]
)
self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
self.num_parallel_samples = num_parallel_samples
self.past_length = self.context_length + max(self.lags_seq)
self.embedder = FeatureEmbedder(
cardinalities=cardinality,
embedding_dims=self.embedding_dimension,
)
if scaling:
self.scaler = MeanScaler(dim=1, keepdim=True)
else:
self.scaler = NOPScaler(dim=1, keepdim=True)
dim_head = len(self.lags_seq) + self._number_of_features
self.perceive_layers = nn.ModuleList([])
for _ in range(perceive_depth):
self.perceive_layers.append(
nn.ModuleList(
[
CausalPrefixAttention(
dim=dim_head,
dim_head=hidden_size,
heads=heads,
max_heads_process=perceive_max_heads_process,
dropout=dropout_rate,
cross_attn_dropout=cross_attn_dropout,
),
FeedForward(dim_head, mult=ff_mult, dropout=dropout_rate),
]
)
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
CausalAttention(
dim=dim_head, dim_head=hidden_size, heads=heads
),
FeedForward(dim_head, mult=ff_mult, dropout=dropout_rate),
]
)
)
self.param_proj = distr_output.get_args_proj(dim_head)
@property
def _number_of_features(self) -> int:
return (
sum(self.embedding_dimension)
+ self.num_feat_dynamic_real
+ self.num_feat_static_real
+ 1 # the log(scale)
)
@property
def _past_length(self) -> int:
return self.context_length + max(self.lags_seq)
def lagged_perciever(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_time_feat: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
future_time_feat: Optional[torch.Tensor] = None,
future_target: Optional[torch.Tensor] = None,
) -> Tuple[
Tuple[torch.Tensor, ...],
torch.Tensor,
torch.Tensor,
torch.Tensor,
Tuple[torch.Tensor, torch.Tensor],
]:
"""
Applies the underlying RNN to the provided target data and covariates.
Parameters
----------
feat_static_cat
Tensor of static categorical features,
shape: ``(batch_size, num_feat_static_cat)``.
feat_static_real
Tensor of static real features,
shape: ``(batch_size, num_feat_static_real)``.
past_time_feat
Tensor of dynamic real features in the past,
shape: ``(batch_size, past_length, num_feat_dynamic_real)``.
past_target
Tensor of past target values,
shape: ``(batch_size, past_length, *target_shape)``.
past_observed_values
Tensor of observed values indicators,
shape: ``(batch_size, past_length)``.
future_time_feat
(Optional) tensor of dynamic real features in the past,
shape: ``(batch_size, prediction_length, num_feat_dynamic_real)``.
future_target
(Optional) tensor of future target values,
shape: ``(batch_size, prediction_length, *target_shape)``.
Returns
-------
Tuple
A tuple containing, in this order:
- Parameters of the output distribution
- Scaling factor applied to the target
- Raw output of the RNN
- Static input to the RNN
- Output state from the RNN
"""
context = past_target[:, -self.context_length :]
observed_context = past_observed_values[:, -self.context_length :]
_, scale = self.scaler(context, observed_context)
prior_input = past_target[:, : -self.context_length] / scale
input = (
torch.cat((context, future_target[:, :-1]), dim=1) / scale
if future_target is not None
else context / scale
)
embedded_cat = self.embedder(feat_static_cat)
static_feat = torch.cat(
(embedded_cat, feat_static_real, scale.log()),
dim=1,
)
expanded_static_feat = static_feat.unsqueeze(1).expand(-1, input.shape[1], -1)
time_feat = (
torch.cat(
(
past_time_feat[:, -self.context_length + 1 :, ...],
future_time_feat,
),
dim=1,
)
if future_time_feat is not None
else past_time_feat[:, -self.context_length + 1 :, ...]
)
features = torch.cat((expanded_static_feat, time_feat), dim=-1)
lags = lagged_sequence_values(self.lags_seq, prior_input, input)
perciever_input = torch.cat((lags, features), dim=-1)
prefix, x = (
perciever_input[:, : self.context_length - 1, ...],
perciever_input[:, self.context_length - 1 :, ...],
)
# initial perceiver attention and feedforward (one cross attention)
for cross_attn, ff in self.perceive_layers:
x = cross_attn(x, prefix) + x
x = ff(x) + x
# layers
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
# output
params = self.param_proj(x)
return (
params,
scale,
static_feat,
perciever_input[:, : self.context_length - 1, ...],
perciever_input[:, self.context_length - 1 :, ...],
)
@torch.jit.ignore
def output_distribution(
self, params, scale=None, trailing_n=None
) -> torch.distributions.Distribution:
"""
Instantiate the output distribution
Parameters
----------
params
Tuple of distribution parameters.
scale
(Optional) scale tensor.
trailing_n
If set, the output distribution is created only for the last
``trailing_n`` time points.
Returns
-------
torch.distributions.Distribution
Output distribution from the model.
"""
sliced_params = params
if trailing_n is not None:
sliced_params = [p[:, -trailing_n:] for p in params]
return self.distr_output.distribution(sliced_params, scale=scale)
def forward(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_time_feat: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
future_time_feat: torch.Tensor,
num_parallel_samples: Optional[int] = None,
) -> torch.Tensor:
"""
Invokes the model on input data, and produce outputs future samples.
Parameters
----------
feat_static_cat
Tensor of static categorical features,
shape: ``(batch_size, num_feat_static_cat)``.
feat_static_real
Tensor of static real features,
shape: ``(batch_size, num_feat_static_real)``.
past_time_feat
Tensor of dynamic real features in the past,
shape: ``(batch_size, past_length, num_feat_dynamic_real)``.
past_target
Tensor of past target values,
shape: ``(batch_size, past_length, *target_shape)``.
past_observed_values
Tensor of observed values indicators,
shape: ``(batch_size, past_length)``.
future_time_feat
(Optional) tensor of dynamic real features in the past,
shape: ``(batch_size, prediction_length, num_feat_dynamic_real)``.
num_parallel_samples
How many future samples to produce.
By default, self.num_parallel_samples is used.
"""
if num_parallel_samples is None:
num_parallel_samples = self.num_parallel_samples
params, scale, static_feat, prefix, x = self.lagged_perciever(
feat_static_cat,
feat_static_real,
past_time_feat,
past_target,
past_observed_values,
future_time_feat[:, :1],
)
repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)
repeated_static_feat = static_feat.repeat_interleave(
repeats=num_parallel_samples, dim=0
).unsqueeze(dim=1)
repeated_past_target = (
past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
/ repeated_scale
)
repeated_time_feat = future_time_feat.repeat_interleave(
repeats=num_parallel_samples, dim=0
)
repeated_prefix = prefix.repeat_interleave(repeats=num_parallel_samples, dim=0)
repeated_x = x.repeat_interleave(repeats=num_parallel_samples, dim=0)
repeated_params = [
s.repeat_interleave(repeats=num_parallel_samples, dim=0) for s in params
]
distr = self.output_distribution(
repeated_params, trailing_n=1, scale=repeated_scale
)
next_sample = distr.sample()
future_samples = [next_sample]
# greedy sampling
for k in range(1, self.prediction_length):
scaled_next_sample = next_sample / repeated_scale
next_features = torch.cat(
(repeated_static_feat, repeated_time_feat[:, k : k + 1]),
dim=-1,
)
next_lags = lagged_sequence_values(
self.lags_seq,
repeated_past_target,
scaled_next_sample,
)
perciever_input = torch.cat((next_lags, next_features), dim=-1)
repeated_x = torch.cat((repeated_x, perciever_input), dim=1)
x = repeated_x
for cross_attn, ff in self.perceive_layers:
x = cross_attn(x, repeated_prefix) + x
x = ff(x) + x
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
params = self.param_proj(x[:, -1:])
distr = self.output_distribution(params, scale=repeated_scale)
next_sample = distr.sample()
future_samples.append(next_sample)
future_samples_concat = torch.cat(future_samples, dim=1)
return future_samples_concat.reshape(
(-1, num_parallel_samples, self.prediction_length) + self.target_shape,
)
File diff suppressed because one or more lines are too long