diff --git a/transformer/transformer.ipynb b/transformer/transformer.ipynb new file mode 100644 index 0000000..0af140d --- /dev/null +++ b/transformer/transformer.ipynb @@ -0,0 +1,1010 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "b19f0e22", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc1a0f32", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List, Optional, Iterable, Dict, Any\n", + "from itertools import islice\n", + "\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "import matplotlib.dates as mdates\n", + "import tqdm.auto as tqdm\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import DataLoader\n", + "\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.loggers import CSVLogger\n", + "\n", + "from gluonts.core.component import validated\n", + "from gluonts.dataset.common import Dataset\n", + "from gluonts.dataset.field_names import FieldName\n", + "from gluonts.itertools import Cyclic, PseudoShuffled, IterableSlice\n", + "from gluonts.time_feature import (\n", + " TimeFeature,\n", + " time_features_from_frequency_str,\n", + ")\n", + "from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood\n", + "from gluonts.transform import (\n", + " Transformation,\n", + " Chain,\n", + " RemoveFields,\n", + " SetField,\n", + " AsNumpyArray,\n", + " AddObservedValuesIndicator,\n", + " AddTimeFeatures,\n", + " AddAgeFeature,\n", + " VstackFeatures,\n", + " InstanceSplitter,\n", + " ValidationSplitSampler,\n", + " TestSplitSampler,\n", + " ExpectedNumInstanceSampler,\n", + " SelectFields,\n", + ")\n", + "from gluonts.torch.util import (\n", + " IterableDataset,\n", + ")\n", + "from gluonts.evaluation import make_evaluation_predictions, Evaluator\n", + "from gluonts.torch.model.estimator import 
class TransformerModel(nn.Module):
    """
    Encoder-decoder transformer for probabilistic time series forecasting.

    The encoder reads the last ``context_length`` steps of the scaled target,
    represented by its lagged values concatenated with static and time
    features. The decoder reads the same kind of input for the
    ``prediction_length`` forecast steps, and ``param_proj`` maps every
    decoder state to the arguments of ``distr_output``.

    Training code calls ``create_network_inputs`` followed by
    ``output_params`` and ``output_distribution``; ``forward`` is the
    prediction path and draws sample paths autoregressively.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        context_length: int,
        prediction_length: int,
        num_feat_dynamic_real: int,
        num_feat_static_real: int,
        num_feat_static_cat: int,
        cardinality: List[int],
        # transformer arguments
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        activation: str = "gelu",
        dropout: float = 0.1,
        # univariate input
        input_size: int = 1,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        lags_seq: Optional[List[int]] = None,
        scaling: bool = True,
        num_parallel_samples: int = 100,
    ) -> None:
        """
        Parameters
        ----------
        freq
            Frequency string; only used to derive default lags when
            ``lags_seq`` is not given.
        context_length
            Number of past steps fed to the encoder.
        prediction_length
            Number of future steps fed to / produced by the decoder.
        num_feat_dynamic_real, num_feat_static_real, num_feat_static_cat
            Feature counts; the real-valued counts enter the model width
            (see ``_number_of_features``).
        cardinality
            Number of categories per static categorical feature.
        nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
        activation, dropout
            Passed through to ``torch.nn.Transformer``.
        input_size
            Size of the target at each time step (1 for univariate).
        embedding_dimension
            Embedding size per categorical feature; derived from
            ``cardinality`` when omitted.
        distr_output
            Output distribution family.
        lags_seq
            Lags of the target used as inputs.
        scaling
            Use ``MeanScaler`` if true, ``NOPScaler`` otherwise.
        num_parallel_samples
            Default number of sample paths drawn by ``forward``.

        Raises
        ------
        ValueError
            If the derived model width is not divisible by ``nhead``.
        """
        super().__init__()

        self.input_size = input_size

        self.target_shape = distr_output.event_shape
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None or cardinality is None
            else [min(50, (cat + 1) // 2) for cat in cardinality]
        )
        self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
        self.num_parallel_samples = num_parallel_samples
        self.history_length = context_length + max(self.lags_seq)
        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=self.embedding_dimension,
        )
        if scaling:
            self.scaler = MeanScaler(dim=1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=1, keepdim=True)

        # The transformer width is not a free hyper-parameter here: it is the
        # number of lagged target values plus the number of features.
        d_model = (
            self.input_size * len(self.lags_seq) + self._number_of_features
        )
        if d_model % nhead != 0:
            # nn.MultiheadAttention would otherwise fail with an assertion
            # that does not reveal how d_model was derived.
            raise ValueError(
                f"model width {d_model} (lags + features) must be divisible "
                f"by nhead={nhead}"
            )

        self.context_length = context_length
        self.prediction_length = prediction_length
        self.distr_output = distr_output
        self.param_proj = distr_output.get_args_proj(d_model)

        # transformer enc-decoder and mask initializer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )

        # Causal decoder mask, registered as a buffer so that it follows the
        # module across devices.
        self.register_buffer(
            "tgt_mask",
            self.transformer.generate_square_subsequent_mask(
                prediction_length
            ),
        )

    @property
    def _number_of_features(self) -> int:
        """Width of the feature vector appended to the lagged targets."""
        return (
            sum(self.embedding_dimension)
            + self.num_feat_dynamic_real
            + self.num_feat_static_real
            + 1  # the log(scale)
        )

    @property
    def _past_length(self) -> int:
        """Number of past steps needed: context window plus largest lag."""
        return self.context_length + max(self.lags_seq)

    def get_lagged_subsequences(
        self,
        sequence: torch.Tensor,
        subsequences_length: int,
        shift: int = 0,
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.

        Parameters
        ----------
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        subsequences_length : int
            length of the subsequences to be extracted.
        shift : int
            shift the lags by this amount back.

        Returns
        -------
        lagged : Tensor
            a tensor of shape (N, S, C, I), where S = subsequences_length and
            I = len(indices), containing lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
        """
        sequence_length = sequence.shape[1]
        indices = [lag - shift for lag in self.lags_seq]

        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            # A lag of 0 means "up to the end"; a literal -0 would yield an
            # empty slice, hence None.
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...])
        return torch.stack(lagged_values, dim=-1)

    def _check_shapes(
        self,
        prior_input: torch.Tensor,
        inputs: torch.Tensor,
        features: Optional[torch.Tensor],
    ) -> None:
        """
        Assert that target tensors and features agree with ``input_size``
        and ``_number_of_features``.

        NOTE(review): no method of this class currently calls this helper.
        """
        assert len(prior_input.shape) == len(inputs.shape)
        assert (
            len(prior_input.shape) == 2 and self.input_size == 1
        ) or prior_input.shape[2] == self.input_size
        assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
            -1
        ] == self.input_size
        assert (
            features is None or features.shape[2] == self._number_of_features
        ), f"{features.shape[2]}, expected {self._number_of_features}"

    def create_network_inputs(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
    ):
        """
        Build the transformer input sequence from raw batch tensors.

        When ``future_target`` is given (training) the returned sequence
        covers ``context_length + prediction_length`` steps; otherwise it
        covers the ``context_length`` encoder steps only, and
        ``future_time_feat`` is not used here.

        Returns
        -------
        transformer_inputs : Tensor
            Lagged scaled targets concatenated with the features along the
            last axis.
        scale : Tensor
            Scale computed by ``self.scaler`` on the context window.
        static_feat : Tensor
            Embedded categorical features, static real features and
            ``log(scale)``, concatenated along axis 1.

        Raises
        ------
        ValueError
            If ``future_target`` is given without ``future_time_feat``.
        """
        if future_target is not None and future_time_feat is None:
            # Previously this combination was accepted and produced lagged
            # targets that were silently misaligned with the time features.
            raise ValueError(
                "future_time_feat must be provided together with future_target"
            )
        has_future = future_target is not None

        # time feature: drop the leading steps that only exist to serve lags
        context_time_feat = past_time_feat[
            :, self._past_length - self.context_length :, ...
        ]
        time_feat = (
            torch.cat((context_time_feat, future_time_feat), dim=1)
            if has_future
            else context_time_feat
        )

        # target: the scale is estimated on the context window only
        context = past_target[:, -self.context_length :]
        observed_context = past_observed_values[:, -self.context_length :]
        _, scale = self.scaler(context, observed_context)

        inputs = (
            torch.cat((past_target, future_target), dim=1) / scale
            if has_future
            else past_target / scale
        )

        future_steps = self.prediction_length if has_future else 0
        assert inputs.shape[1] == self._past_length + future_steps
        subsequences_length = self.context_length + future_steps

        # embeddings
        embedded_cat = self.embedder(feat_static_cat)
        static_feat = torch.cat(
            (embedded_cat, feat_static_real, scale.log()),
            dim=1,
        )
        # repeat the static features at every time step
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, time_feat.shape[1], -1
        )
        features = torch.cat((expanded_static_feat, time_feat), dim=-1)

        lagged_sequence = self.get_lagged_subsequences(
            sequence=inputs,
            subsequences_length=subsequences_length,
        )

        # merge the trailing target/lag axes into one feature axis
        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        transformer_inputs = torch.cat(
            (reshaped_lagged_sequence, features), dim=-1
        )

        return transformer_inputs, scale, static_feat

    def output_params(self, transformer_inputs):
        """
        Run encoder and causally masked decoder over a full training
        sequence and return the projected distribution arguments for the
        decoder (future) positions.
        """
        enc_input = transformer_inputs[:, : self.context_length, ...]
        dec_input = transformer_inputs[:, self.context_length :, ...]

        enc_out = self.transformer.encoder(enc_input)
        dec_output = self.transformer.decoder(
            dec_input, enc_out, tgt_mask=self.tgt_mask
        )

        return self.param_proj(dec_output)

    @torch.jit.ignore
    def output_distribution(
        self, params, scale=None, trailing_n=None
    ) -> torch.distributions.Distribution:
        """
        Turn projected arguments into a distribution object, optionally
        keeping only the last ``trailing_n`` time steps.
        """
        sliced_params = params
        if trailing_n is not None:
            sliced_params = [p[:, -trailing_n:] for p in params]
        return self.distr_output.distribution(sliced_params, scale=scale)

    # for prediction
    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
        num_parallel_samples: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Draw sample paths for the next ``prediction_length`` steps.

        Parameters
        ----------
        num_parallel_samples
            Number of sample paths per series; defaults to
            ``self.num_parallel_samples``.

        Returns
        -------
        Tensor
            Samples in the original (unscaled) units, reshaped to
            ``(-1, num_parallel_samples, prediction_length) + target_shape``.
        """
        if num_parallel_samples is None:
            num_parallel_samples = self.num_parallel_samples

        encoder_inputs, scale, static_feat = self.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
            future_time_feat,
        )

        enc_out = self.transformer.encoder(encoder_inputs)

        # Every series is repeated so that all sample paths are decoded in
        # one batch. The local `num_parallel_samples` is used throughout so
        # that a caller-supplied value is honoured.
        repeated_scale = scale.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )
        repeated_static_feat = static_feat.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        ).unsqueeze(dim=1)
        repeated_past_target = (
            past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
            / repeated_scale
        )
        repeated_time_feat = future_time_feat.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )
        repeated_enc_out = enc_out.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        future_samples = []
        decoder_steps = []

        for k in range(self.prediction_length):
            next_features = torch.cat(
                (repeated_static_feat, repeated_time_feat[:, k : k + 1]),
                dim=-1,
            )

            # shift=1 because the step being predicted is not part of the
            # sequence yet: lag l must read the value l steps before it.
            lagged_sequence = self.get_lagged_subsequences(
                sequence=repeated_past_target,
                subsequences_length=1,
                shift=1,
            )

            lags_shape = lagged_sequence.shape
            reshaped_lagged_sequence = lagged_sequence.reshape(
                lags_shape[0], lags_shape[1], -1
            )

            decoder_steps.append(
                torch.cat((reshaped_lagged_sequence, next_features), dim=-1)
            )

            # During training the decoder attends causally to all earlier
            # forecast positions. Feed it the whole prefix under the same
            # mask so that sampling sees what training saw; only the state
            # of the newest position is projected. This costs a decoder pass
            # over k + 1 positions at step k.
            decoder_input = torch.cat(decoder_steps, dim=1)
            output = self.transformer.decoder(
                decoder_input,
                repeated_enc_out,
                tgt_mask=self.tgt_mask[: k + 1, : k + 1],
            )

            params = self.param_proj(output[:, -1:])
            distr = self.output_distribution(params)
            next_sample = distr.sample()

            repeated_past_target = torch.cat(
                (repeated_past_target, next_sample), dim=1
            )
            future_samples.append(next_sample)

        # samples were drawn in the scaled domain; undo the scaling
        unscaled_future_samples = (
            torch.cat(future_samples, dim=1) * repeated_scale
        )
        return unscaled_future_samples.reshape(
            (-1, num_parallel_samples, self.prediction_length)
            + self.target_shape,
        )
class TransformerLightningModule(pl.LightningModule):
    """
    Lightning wrapper that trains a ``TransformerModel``.

    Calling the module on a batch returns the training loss: ``loss``
    evaluated between the predicted distribution and ``future_target``,
    averaged with the observed-value indicators as weights.
    """

    def __init__(
        self,
        model: TransformerModel,
        loss: DistributionLoss = NegativeLogLikelihood(),
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
    ) -> None:
        """
        Parameters
        ----------
        model
            Network to train.
        loss
            Loss comparing a predicted distribution with the target.
        lr
            Learning rate passed to Adam.
        weight_decay
            Weight decay passed to Adam.
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = model
        self.loss = loss
        self.lr = lr
        self.weight_decay = weight_decay

    def training_step(self, batch, batch_idx: int):
        """Execute training step"""
        train_loss = self(batch)
        self.log(
            "train_loss",
            train_loss,
            on_epoch=True,
            on_step=False,
            prog_bar=True,
        )
        return train_loss

    def validation_step(self, batch, batch_idx: int):
        """Execute validation step"""
        with torch.inference_mode():
            val_loss = self(batch)
        self.log(
            "val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True
        )
        return val_loss

    def configure_optimizers(self):
        """Returns the optimizer to use"""
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

    def forward(self, batch):
        """
        Compute the weighted average loss of the model on ``batch``.

        ``batch`` is a mapping that provides the eight tensors read below.
        """
        feat_static_cat = batch["feat_static_cat"]
        feat_static_real = batch["feat_static_real"]
        past_time_feat = batch["past_time_feat"]
        past_target = batch["past_target"]
        future_time_feat = batch["future_time_feat"]
        future_target = batch["future_target"]
        past_observed_values = batch["past_observed_values"]
        future_observed_values = batch["future_observed_values"]

        transformer_inputs, scale, _ = self.model.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
            future_time_feat,
            future_target,
        )
        params = self.model.output_params(transformer_inputs)
        distr = self.model.output_distribution(params, scale)

        loss_values = self.loss(distr, future_target)

        if len(self.model.target_shape) == 0:
            loss_weights = future_observed_values
        else:
            # Reduce over the target axis so there is one weight per time
            # step (assuming 0/1 indicators, a step then counts only if all
            # of its dimensions were observed). `Tensor.min(dim=...)` returns
            # a (values, indices) pair; only the values are valid weights.
            loss_weights = future_observed_values.min(
                dim=-1, keepdim=False
            ).values

        return weighted_average(loss_values, weights=loss_weights)
class TransformerEstimator(PyTorchLightningEstimator):
    """
    Estimator that wires a ``TransformerModel`` into the GluonTS training
    workflow: it builds the data transformation, the instance splitters, the
    data loaders, the Lightning module and finally the predictor.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        # Transformer arguments
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        input_size: int = 1,
        activation: str = "gelu",
        dropout: float = 0.1,
        context_length: Optional[int] = None,
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        loss: DistributionLoss = NegativeLogLikelihood(),
        scaling: bool = True,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        num_parallel_samples: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Parameters
        ----------
        freq
            Frequency string of the data.
        prediction_length
            Forecast horizon.
        nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
        input_size, activation, dropout
            Forwarded to ``TransformerModel``.
        context_length
            Encoder window; defaults to ``prediction_length``.
        num_feat_dynamic_real, num_feat_static_cat, num_feat_static_real
            Number of features of each kind present in the dataset.
        cardinality
            Categories per static categorical feature; replaced by ``[1]``
            when no categorical features are used.
        embedding_dimension, distr_output, scaling, lags_seq,
        num_parallel_samples
            Forwarded to ``TransformerModel``.
        loss
            Forwarded to ``TransformerLightningModule``.
        time_features
            Time features to compute; derived from ``freq`` when omitted.
        batch_size
            Batch size of all data loaders and of the predictor.
        num_batches_per_epoch
            Number of training batches that make up one epoch.
        trainer_kwargs
            Extra arguments for the Lightning trainer; they override the
            default ``max_epochs=100``. ``None`` means no overrides.
        """
        # `None` (the default) and an empty dict are equivalent. Using `None`
        # avoids a mutable default argument and the crash that `**None`
        # caused when a caller passed `trainer_kwargs=None` explicitly.
        trainer_kwargs = {
            "max_epochs": 100,
            **(trainer_kwargs or {}),
        }
        super().__init__(trainer_kwargs=trainer_kwargs)

        self.freq = freq
        self.context_length = (
            context_length if context_length is not None else prediction_length
        )
        self.prediction_length = prediction_length
        self.distr_output = distr_output
        self.loss = loss

        self.input_size = input_size
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.activation = activation
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout

        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.cardinality = (
            cardinality if cardinality and num_feat_static_cat > 0 else [1]
        )
        self.embedding_dimension = embedding_dimension
        self.scaling = scaling
        self.lags_seq = lags_seq
        self.time_features = (
            time_features
            if time_features is not None
            else time_features_from_frequency_str(self.freq)
        )

        self.num_parallel_samples = num_parallel_samples
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.train_sampler = ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = ValidationSplitSampler(
            min_future=prediction_length
        )

    def create_transformation(self) -> Transformation:
        """
        Build the per-series transformation chain.

        Unused feature fields are removed, missing static features are
        replaced by a single dummy value, and time, age and (if configured)
        dynamic real features are stacked into ``FieldName.FEAT_TIME``.
        """
        remove_field_names = []
        if self.num_feat_static_real == 0:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if self.num_feat_dynamic_real == 0:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

        return Chain(
            [RemoveFields(field_names=remove_field_names)]
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
                if not self.num_feat_static_cat > 0
                else []
            )
            + (
                [
                    SetField(
                        output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]
                    )
                ]
                if not self.num_feat_static_real > 0
                else []
            )
            + [
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_CAT,
                    expected_ndim=1,
                    # `np.long` does not exist in NumPy 1.24-1.26 and is
                    # platform-dependent elsewhere; use a fixed-width integer
                    # for the embedding indices.
                    dtype=np.int64,
                ),
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_REAL,
                    expected_ndim=1,
                ),
                AsNumpyArray(
                    field=FieldName.TARGET,
                    # in the following line, we add 1 for the time dimension
                    expected_ndim=1 + len(self.distr_output.event_shape),
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                AddAgeFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_AGE,
                    pred_length=self.prediction_length,
                    log_scale=True,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                    + (
                        [FieldName.FEAT_DYNAMIC_REAL]
                        if self.num_feat_dynamic_real > 0
                        else []
                    ),
                ),
            ]
        )

    def _create_instance_splitter(
        self, module: TransformerLightningModule, mode: str
    ):
        """
        Return the ``InstanceSplitter`` for ``mode``, which must be one of
        ``"training"``, ``"validation"`` or ``"test"``.

        The past window length is taken from the model, so that it includes
        the extra history needed by the largest lag.
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=module.model._past_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: TransformerLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        """
        Return an iterable yielding ``num_batches_per_epoch`` training
        batches per pass.

        ``kwargs`` are forwarded to ``torch.utils.data.DataLoader``. When
        ``shuffle_buffer_length`` is given, the cycled dataset is wrapped in
        ``PseudoShuffled`` with that buffer length.
        """
        transformation = self._create_instance_splitter(
            module, "training"
        ) + SelectFields(TRAINING_INPUT_NAMES)

        training_instances = transformation.apply(
            Cyclic(data)
            if shuffle_buffer_length is None
            else PseudoShuffled(
                Cyclic(data), shuffle_buffer_length=shuffle_buffer_length
            )
        )

        # The loader iterator is created once and sliced per epoch.
        # NOTE(review): presumably each pass then continues where the
        # previous one stopped - confirm against IterableSlice.
        return IterableSlice(
            iter(
                DataLoader(
                    IterableDataset(training_instances),
                    batch_size=self.batch_size,
                    **kwargs,
                )
            ),
            self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: TransformerLightningModule,
        **kwargs,
    ) -> Iterable:
        """
        Return a ``DataLoader`` over the validation instances of ``data``;
        ``kwargs`` are forwarded to the ``DataLoader``.
        """
        transformation = self._create_instance_splitter(
            module, "validation"
        ) + SelectFields(TRAINING_INPUT_NAMES)

        validation_instances = transformation.apply(data)

        return DataLoader(
            IterableDataset(validation_instances),
            batch_size=self.batch_size,
            **kwargs,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module: TransformerLightningModule,
    ) -> PyTorchPredictor:
        """
        Wrap the trained network in a ``PyTorchPredictor`` that runs on CUDA
        when available and on the CPU otherwise.
        """
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module.model,
            batch_size=self.batch_size,
            freq=self.freq,
            prediction_length=self.prediction_length,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )

    def create_lightning_module(self) -> TransformerLightningModule:
        """Instantiate the network and wrap it in its Lightning module."""
        model = TransformerModel(
            freq=self.freq,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            # FEAT_TIME stacks the time features, one age feature and the
            # dataset's own dynamic real features (see create_transformation).
            num_feat_dynamic_real=1
            + self.num_feat_dynamic_real
            + len(self.time_features),
            # at least one of each: create_transformation inserts a dummy
            # static feature when the dataset has none
            num_feat_static_real=max(1, self.num_feat_static_real),
            num_feat_static_cat=max(1, self.num_feat_static_cat),
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            # transformer arguments
            nhead=self.nhead,
            num_encoder_layers=self.num_encoder_layers,
            num_decoder_layers=self.num_decoder_layers,
            activation=self.activation,
            dropout=self.dropout,
            dim_feedforward=self.dim_feedforward,
            # univariate input
            input_size=self.input_size,
            distr_output=self.distr_output,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
            num_parallel_samples=self.num_parallel_samples,
        )

        return TransformerLightningModule(model=model, loss=self.loss)
def plot_prob_forecasts(ts_entry, forecast_entry):
    """
    Show the last 70 observations of ``ts_entry`` together with the forecast
    in ``forecast_entry``, drawn in green with its 50% and 90% prediction
    intervals.
    """
    intervals = (50.0, 90.0)
    history_steps = 70

    # legend entries list the widest interval first
    labels = ["observations", "median prediction"]
    labels.extend(
        f"{level}% prediction interval" for level in reversed(intervals)
    )

    _, axis = plt.subplots(1, 1, figsize=(10, 7))
    recent_history = ts_entry[-history_steps:]
    recent_history.plot(ax=axis)
    forecast_entry.plot(prediction_intervals=intervals, color='g')

    plt.grid(which="both")
    plt.legend(labels, loc="best")
    plt.show()
"execution_count": null, + "id": "5256fde1", + "metadata": {}, + "outputs": [], + "source": [ + "index = 123\n", + "plot_prob_forecasts(tss[index], forecasts[index])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84e15650", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}