added ETSFormer

This commit is contained in:
Kashif Rasul
2022-04-07 13:29:50 +02:00
parent bc331ba3a7
commit 38882af8aa
7 changed files with 1177 additions and 0 deletions
View File
+9
View File
@@ -0,0 +1,9 @@
from .estimator import ETSformerEstimator
from .lightning_module import ETSformerLightningModule
from .module import ETSformerModel
__all__ = [
"ETSformerModel",
"ETSformerLightningModule",
"ETSformerEstimator",
]
+307
View File
@@ -0,0 +1,307 @@
from typing import Any, Dict, Iterable, List, Optional
import torch
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import IterableDataset
from gluonts.transform import (
AddAgeFeature,
AddObservedValuesIndicator,
AddTimeFeatures,
AsNumpyArray,
Chain,
ExpectedNumInstanceSampler,
InstanceSplitter,
RemoveFields,
SelectFields,
SetField,
TestSplitSampler,
Transformation,
ValidationSplitSampler,
VstackFeatures,
)
from gluonts.transform.sampler import InstanceSampler
from torch.utils.data import DataLoader
from lightning_module import ETSformerLightningModule
from module import ETSformerModel
PREDICTION_INPUT_NAMES = [
"feat_static_cat",
"feat_static_real",
"past_time_feat",
"past_target",
"past_observed_values",
"future_time_feat",
]
TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
"future_target",
"future_observed_values",
]
class ETSformerEstimator(PyTorchLightningEstimator):
@validated()
def __init__(
self,
freq: str,
prediction_length: int,
# ETSformer arguments
nhead: int,
num_layers: int = 2,
k_largest_amplitudes: int = 4,
embed_kernel_size: int = 3,
input_size: int = 1,
dropout: float = 0.1,
context_length: Optional[int] = None,
num_feat_dynamic_real: int = 0,
num_feat_static_cat: int = 0,
num_feat_static_real: int = 0,
cardinality: Optional[List[int]] = None,
embedding_dimension: Optional[List[int]] = None,
distr_output: DistributionOutput = StudentTOutput(),
loss: DistributionLoss = NegativeLogLikelihood(),
scaling: bool = True,
lags_seq: Optional[List[int]] = None,
time_features: Optional[List[TimeFeature]] = None,
num_parallel_samples: int = 100,
batch_size: int = 32,
num_batches_per_epoch: int = 50,
trainer_kwargs: Optional[Dict[str, Any]] = dict(),
train_sampler: Optional[InstanceSampler] = None,
validation_sampler: Optional[InstanceSampler] = None,
) -> None:
trainer_kwargs = {
"max_epochs": 100,
**trainer_kwargs,
}
super().__init__(trainer_kwargs=trainer_kwargs)
self.freq = freq
self.context_length = (
context_length if context_length is not None else prediction_length
)
self.prediction_length = prediction_length
self.distr_output = distr_output
self.loss = loss
self.input_size = input_size
self.nhead = nhead
self.num_layers = num_layers
self.k_largest_amplitudes = k_largest_amplitudes
self.dropout = dropout
self.embed_kernel_size = embed_kernel_size
self.num_feat_dynamic_real = num_feat_dynamic_real
self.num_feat_static_cat = num_feat_static_cat
self.num_feat_static_real = num_feat_static_real
self.cardinality = (
cardinality if cardinality and num_feat_static_cat > 0 else [1]
)
self.embedding_dimension = embedding_dimension
self.scaling = scaling
self.lags_seq = lags_seq
self.time_features = (
time_features
if time_features is not None
else time_features_from_frequency_str(self.freq)
)
self.num_parallel_samples = num_parallel_samples
self.batch_size = batch_size
self.num_batches_per_epoch = num_batches_per_epoch
self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
num_instances=1.0, min_future=prediction_length
)
self.validation_sampler = validation_sampler or ValidationSplitSampler(
min_future=prediction_length
)
def create_transformation(self) -> Transformation:
remove_field_names = []
if self.num_feat_static_real == 0:
remove_field_names.append(FieldName.FEAT_STATIC_REAL)
if self.num_feat_dynamic_real == 0:
remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
return Chain(
[RemoveFields(field_names=remove_field_names)]
+ (
[SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
if not self.num_feat_static_cat > 0
else []
)
+ (
[SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
if not self.num_feat_static_real > 0
else []
)
+ [
AsNumpyArray(
field=FieldName.FEAT_STATIC_CAT,
expected_ndim=1,
dtype=int,
),
AsNumpyArray(
field=FieldName.FEAT_STATIC_REAL,
expected_ndim=1,
),
AsNumpyArray(
field=FieldName.TARGET,
# in the following line, we add 1 for the time dimension
expected_ndim=1 + len(self.distr_output.event_shape),
),
AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
),
AddTimeFeatures(
start_field=FieldName.START,
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_TIME,
time_features=self.time_features,
pred_length=self.prediction_length,
),
AddAgeFeature(
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_AGE,
pred_length=self.prediction_length,
log_scale=True,
),
VstackFeatures(
output_field=FieldName.FEAT_TIME,
input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
+ (
[FieldName.FEAT_DYNAMIC_REAL]
if self.num_feat_dynamic_real > 0
else []
),
),
]
)
def _create_instance_splitter(self, module: ETSformerLightningModule, mode: str):
assert mode in ["training", "validation", "test"]
instance_sampler = {
"training": self.train_sampler,
"validation": self.validation_sampler,
"test": TestSplitSampler(),
}[mode]
return InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
instance_sampler=instance_sampler,
past_length=module.model._past_length,
future_length=self.prediction_length,
time_series_fields=[
FieldName.FEAT_TIME,
FieldName.OBSERVED_VALUES,
],
dummy_value=self.distr_output.value_in_support,
)
def create_training_data_loader(
self,
data: Dataset,
module: ETSformerLightningModule,
shuffle_buffer_length: Optional[int] = None,
**kwargs,
) -> Iterable:
transformation = self._create_instance_splitter(
module, "training"
) + SelectFields(TRAINING_INPUT_NAMES)
training_instances = transformation.apply(
Cyclic(data)
if shuffle_buffer_length is None
else PseudoShuffled(
Cyclic(data), shuffle_buffer_length=shuffle_buffer_length
)
)
return IterableSlice(
iter(
DataLoader(
IterableDataset(training_instances),
batch_size=self.batch_size,
**kwargs,
)
),
self.num_batches_per_epoch,
)
def create_validation_data_loader(
self,
data: Dataset,
module: ETSformerLightningModule,
**kwargs,
) -> Iterable:
transformation = self._create_instance_splitter(
module, "validation"
) + SelectFields(TRAINING_INPUT_NAMES)
validation_instances = transformation.apply(data)
return DataLoader(
IterableDataset(validation_instances),
batch_size=self.batch_size,
**kwargs,
)
def create_predictor(
self,
transformation: Transformation,
module: ETSformerLightningModule,
) -> PyTorchPredictor:
prediction_splitter = self._create_instance_splitter(module, "test")
return PyTorchPredictor(
input_transform=transformation + prediction_splitter,
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module.model,
batch_size=self.batch_size,
freq=self.freq,
prediction_length=self.prediction_length,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
def create_lightning_module(self) -> ETSformerLightningModule:
model = ETSformerModel(
freq=self.freq,
context_length=self.context_length,
prediction_length=self.prediction_length,
num_feat_dynamic_real=1
+ self.num_feat_dynamic_real
+ len(self.time_features),
num_feat_static_real=max(1, self.num_feat_static_real),
num_feat_static_cat=max(1, self.num_feat_static_cat),
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
# ETSformer arguments
nhead=self.nhead,
num_layers=self.num_layers,
dropout=self.dropout,
k_largest_amplitudes=self.k_largest_amplitudes,
embed_kernel_size=self.embed_kernel_size,
# univariate input
input_size=self.input_size,
distr_output=self.distr_output,
lags_seq=self.lags_seq,
scaling=self.scaling,
num_parallel_samples=self.num_parallel_samples,
)
return ETSformerLightningModule(model=model, loss=self.loss)
File diff suppressed because one or more lines are too long
+80
View File
@@ -0,0 +1,80 @@
import pytorch_lightning as pl
import torch
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import weighted_average
from module import ETSformerModel
class ETSformerLightningModule(pl.LightningModule):
def __init__(
self,
model: ETSformerModel,
loss: DistributionLoss = NegativeLogLikelihood(),
lr: float = 1e-3,
weight_decay: float = 1e-8,
) -> None:
super().__init__()
self.save_hyperparameters()
self.model = model
self.loss = loss
self.lr = lr
self.weight_decay = weight_decay
def training_step(self, batch, batch_idx: int):
"""Execute training step"""
train_loss = self(batch)
self.log(
"train_loss",
train_loss,
on_epoch=True,
on_step=False,
prog_bar=True,
)
return train_loss
def validation_step(self, batch, batch_idx: int):
"""Execute validation step"""
with torch.inference_mode():
val_loss = self(batch)
self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
return val_loss
def configure_optimizers(self):
"""Returns the optimizer to use"""
return torch.optim.Adam(
self.model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)
def forward(self, batch):
feat_static_cat = batch["feat_static_cat"]
feat_static_real = batch["feat_static_real"]
past_time_feat = batch["past_time_feat"]
past_target = batch["past_target"]
future_time_feat = batch["future_time_feat"]
future_target = batch["future_target"]
past_observed_values = batch["past_observed_values"]
future_observed_values = batch["future_observed_values"]
etsformer_inputs, scale, _ = self.model.create_network_inputs(
feat_static_cat,
feat_static_real,
past_time_feat,
past_target,
past_observed_values,
future_time_feat,
future_target,
)
params = self.model.output_params(etsformer_inputs)
distr = self.model.output_distribution(params, scale)
loss_values = self.loss(distr, future_target)
if len(self.model.target_shape) == 0:
loss_weights = future_observed_values
else:
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
return weighted_average(loss_values, weights=loss_weights)
+286
View File
@@ -0,0 +1,286 @@
from typing import List, Optional
import torch
import torch.nn as nn
from etsformer_pytorch import ETSFormer
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
class ETSformerModel(nn.Module):
@validated()
def __init__(
self,
freq: str,
context_length: int,
prediction_length: int,
num_feat_dynamic_real: int,
num_feat_static_real: int,
num_feat_static_cat: int,
cardinality: List[int],
# ETSformer arguments
k_largest_amplitudes: int,
embed_kernel_size: int,
nhead: int,
num_layers: int,
dropout: float = 0.1,
# univariate input
input_size: int = 1,
embedding_dimension: Optional[List[int]] = None,
distr_output: DistributionOutput = StudentTOutput(),
lags_seq: Optional[List[int]] = None,
scaling: bool = True,
num_parallel_samples: int = 100,
) -> None:
super().__init__()
self.input_size = input_size
self.target_shape = distr_output.event_shape
self.num_feat_dynamic_real = num_feat_dynamic_real
self.num_feat_static_cat = num_feat_static_cat
self.num_feat_static_real = num_feat_static_real
self.embedding_dimension = (
embedding_dimension
if embedding_dimension is not None or cardinality is None
else [min(50, (cat + 1) // 2) for cat in cardinality]
)
self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
self.num_parallel_samples = num_parallel_samples
self.history_length = context_length + max(self.lags_seq)
self.embedder = FeatureEmbedder(
cardinalities=cardinality,
embedding_dims=self.embedding_dimension,
)
if scaling:
self.scaler = MeanScaler(dim=1, keepdim=True)
else:
self.scaler = NOPScaler(dim=1, keepdim=True)
# total feature size
d_model = self.input_size * len(self.lags_seq) + self._number_of_features
self.context_length = context_length
self.prediction_length = prediction_length
self.distr_output = distr_output
self.param_proj = distr_output.get_args_proj(d_model)
# ETSformer enc-decoder
self.etsformer = ETSFormer(
time_features=d_model,
model_dim=d_model,
embed_kernel_size=embed_kernel_size,
K=k_largest_amplitudes,
layers=num_layers,
heads=nhead,
dropout=dropout,
)
@property
def _number_of_features(self) -> int:
return (
sum(self.embedding_dimension)
+ self.num_feat_dynamic_real
+ self.num_feat_static_real
+ 1 # the log(scale)
)
@property
def _past_length(self) -> int:
return self.context_length + max(self.lags_seq)
def get_lagged_subsequences(
self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
) -> torch.Tensor:
"""
Returns lagged subsequences of a given sequence.
Parameters
----------
sequence : Tensor
the sequence from which lagged subsequences should be extracted.
Shape: (N, T, C).
subsequences_length : int
length of the subsequences to be extracted.
shift: int
shift the lags by this amount back.
Returns
--------
lagged : Tensor
a tensor of shape (N, S, C, I), where S = subsequences_length and
I = len(indices), containing lagged subsequences. Specifically,
lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
"""
sequence_length = sequence.shape[1]
indices = [lag - shift for lag in self.lags_seq]
assert max(indices) + subsequences_length <= sequence_length, (
f"lags cannot go further than history length, found lag {max(indices)} "
f"while history length is only {sequence_length}"
)
lagged_values = []
for lag_index in indices:
begin_index = -lag_index - subsequences_length
end_index = -lag_index if lag_index > 0 else None
lagged_values.append(sequence[:, begin_index:end_index, ...])
return torch.stack(lagged_values, dim=-1)
def _check_shapes(
self,
prior_input: torch.Tensor,
inputs: torch.Tensor,
features: Optional[torch.Tensor],
) -> None:
assert len(prior_input.shape) == len(inputs.shape)
assert (
len(prior_input.shape) == 2 and self.input_size == 1
) or prior_input.shape[2] == self.input_size
assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
-1
] == self.input_size
assert (
features is None or features.shape[2] == self._number_of_features
), f"{features.shape[2]}, expected {self._number_of_features}"
def create_network_inputs(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_time_feat: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
future_time_feat: Optional[torch.Tensor] = None,
future_target: Optional[torch.Tensor] = None,
):
# time feature
time_feat = (
torch.cat(
(
past_time_feat[:, self._past_length - self.context_length :, ...],
future_time_feat,
),
dim=1,
)
if future_target is not None
else past_time_feat[:, self._past_length - self.context_length :, ...]
)
# target
context = past_target[:, -self.context_length :]
observed_context = past_observed_values[:, -self.context_length :]
_, scale = self.scaler(context, observed_context)
inputs = (
torch.cat((past_target, future_target), dim=1) / scale
if future_target is not None
else past_target / scale
)
inputs_length = (
self._past_length + self.prediction_length
if future_target is not None
else self._past_length
)
assert inputs.shape[1] == inputs_length
subsequences_length = (
self.context_length + self.prediction_length
if future_target is not None
else self.context_length
)
# embeddings
embedded_cat = self.embedder(feat_static_cat)
static_feat = torch.cat(
(embedded_cat, feat_static_real, scale.log()),
dim=1,
)
expanded_static_feat = static_feat.unsqueeze(1).expand(
-1, time_feat.shape[1], -1
)
features = torch.cat((expanded_static_feat, time_feat), dim=-1)
# self._check_shapes(prior_input, inputs, features)
# sequence = torch.cat((prior_input, inputs), dim=1)
lagged_sequence = self.get_lagged_subsequences(
sequence=inputs,
subsequences_length=subsequences_length,
)
lags_shape = lagged_sequence.shape
reshaped_lagged_sequence = lagged_sequence.reshape(
lags_shape[0], lags_shape[1], -1
)
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
return transformer_inputs, scale, static_feat
def output_params(self, transformer_inputs):
enc_input = transformer_inputs[:, : self.context_length, ...]
dec_output = self.etsformer(
enc_input, num_steps_forecast=self.prediction_length
)
return self.param_proj(dec_output)
@torch.jit.ignore
def output_distribution(
self, params, scale=None, trailing_n=None
) -> torch.distributions.Distribution:
sliced_params = params
if trailing_n is not None:
sliced_params = [p[:, -trailing_n:] for p in params]
return self.distr_output.distribution(sliced_params, scale=scale)
# for prediction
def forward(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_time_feat: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
future_time_feat: torch.Tensor,
num_parallel_samples: Optional[int] = None,
) -> torch.Tensor:
if num_parallel_samples is None:
num_parallel_samples = self.num_parallel_samples
encoder_inputs, scale, _ = self.create_network_inputs(
feat_static_cat,
feat_static_real,
past_time_feat,
past_target,
past_observed_values,
)
dec_out = self.etsformer(
encoder_inputs, num_steps_forecast=self.prediction_length
)
params = self.param_proj(dec_out)
repeated_params = [
s.repeat_interleave(repeats=self.num_parallel_samples, dim=0)
for s in params
]
repeated_scale = scale.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
distr = self.output_distribution(repeated_params, scale=repeated_scale)
# Future samples
samples = distr.sample()
return samples.reshape(
(-1, self.num_parallel_samples, self.prediction_length) + self.target_shape,
)
+1
View File
@@ -4,3 +4,4 @@ pytorch-lightning
datasets
xformers
https://github.com/ml-jku/hopfield-layers
etsformer-pytorch