diff --git a/pts/__init__.py b/pts/__init__.py index ce92987..c26d56c 100644 --- a/pts/__init__.py +++ b/pts/__init__.py @@ -9,4 +9,4 @@ __path__ = extend_path(__path__, __name__) # type: ignore try: __version__ = get_distribution(__name__).version except DistributionNotFound: - __version__ = "0.0.0-unknown" + __version__ = "0.0.0-unknown" \ No newline at end of file diff --git a/pts/model/tft/tft_estimator.py b/pts/model/tft/tft_estimator.py index 7ecb9bf..66d1c52 100644 --- a/pts/model/tft/tft_estimator.py +++ b/pts/model/tft/tft_estimator.py @@ -1,22 +1,18 @@ -from typing import List, Optional, Dict from itertools import chain +from typing import List, Optional, Dict import numpy as np import torch -import torch.nn as nn - from gluonts.core.component import validated from gluonts.dataset.field_names import FieldName +from gluonts.model.forecast_generator import QuantileForecastGenerator +from gluonts.model.predictor import Predictor from gluonts.time_feature import ( TimeFeature, - get_lags_for_frequency, time_features_from_frequency_str, ) -from gluonts.torch.modules.distribution_output import DistributionOutput -from gluonts.torch.support.util import copy_parameters from gluonts.torch.model.predictor import PyTorchPredictor -from gluonts.model.predictor import Predictor -from gluonts.model.forecast_generator import QuantileForecastGenerator +from gluonts.torch.support.util import copy_parameters from gluonts.transform import ( Transformation, Chain, @@ -32,9 +28,8 @@ from gluonts.transform import ( ) from pts import Trainer -from pts.model.utils import get_module_forward_input_names from pts.model import PyTorchEstimator - +from pts.model.utils import get_module_forward_input_names from .tft_network import ( TemporalFusionTransformerPredictionNetwork, TemporalFusionTransformerTrainingNetwork, @@ -51,22 +46,22 @@ def _default_feat_args(dims_or_cardinalities: List[int]): class TemporalFusionTransformerEstimator(PyTorchEstimator): @validated() def __init__( - self, - freq: str, - prediction_length: int, - context_length: Optional[int] = None, - dropout_rate: float = 0.1, - embed_dim: int = 32, - num_heads: int = 4, - num_outputs: int = 3, - variable_dim: Optional[int] = None, - time_features: List[TimeFeature] = [], - static_cardinalities: Dict[str, int] = {}, - dynamic_cardinalities: Dict[str, int] = {}, - static_feature_dims: Dict[str, int] = {}, - dynamic_feature_dims: Dict[str, int] = {}, - past_dynamic_features: List[str] = [], - trainer: Trainer = Trainer(), + self, + freq: str, + prediction_length: int, + context_length: Optional[int] = None, + dropout_rate: float = 0.1, + embed_dim: int = 32, + num_heads: int = 4, + num_outputs: int = 3, + variable_dim: Optional[int] = None, + time_features: List[TimeFeature] = [], + static_cardinalities: Dict[str, int] = {}, + dynamic_cardinalities: Dict[str, int] = {}, + static_feature_dims: Dict[str, int] = {}, + dynamic_feature_dims: Dict[str, int] = {}, + past_dynamic_features: List[str] = [], + trainer: Trainer = Trainer(), ) -> None: super().__init__(trainer=trainer) @@ -116,53 +111,56 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator): def create_transformation(self) -> Transformation: transforms = ( - [AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)] - + ( - [ - AsNumpyArray(field=name, expected_ndim=1) - for name in self.static_cardinalities.keys() - ] - ) - + [ - AsNumpyArray(field=name, expected_ndim=1) - for name in chain( - self.static_feature_dims.keys(), - self.dynamic_cardinalities.keys(), + [AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)] + + ( + [ + AsNumpyArray(field=name, expected_ndim=1) + for name in self.static_cardinalities.keys() + ] ) - ] - + [ - AsNumpyArray(field=name, expected_ndim=2) - for name in self.dynamic_feature_dims.keys() - ] - + [ - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - AddAgeFeature( - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_AGE, - pred_length=self.prediction_length, - log_scale=True, - ), - ] + + [ + AsNumpyArray(field=name, expected_ndim=1) + for name in chain( + self.static_feature_dims.keys(), + self.dynamic_cardinalities.keys(), + ) + ] + + [ + AsNumpyArray(field=name, expected_ndim=2) + for name in self.dynamic_feature_dims.keys() + ] + + [ + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + AddAgeFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_AGE, + pred_length=self.prediction_length, + log_scale=True, + ), + ] ) if self.static_cardinalities: - transforms.append( + transforms.extend([ VstackFeatures( output_field=FieldName.FEAT_STATIC_CAT, input_fields=list(self.static_cardinalities.keys()), h_stack=True, + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=np.long ) - ) + ]) else: transforms.extend( [ @@ -196,12 +194,17 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator): ) if self.dynamic_cardinalities: - transforms.append( + transforms.extend([ VstackFeatures( output_field=FieldName.FEAT_DYNAMIC_CAT, input_fields=list(self.dynamic_cardinalities.keys()), + ), + AsNumpyArray( + field=FieldName.FEAT_DYNAMIC_CAT, + expected_ndim=2, + dtype=np.long, ) - ) + ]) else: transforms.extend( [ @@ -232,12 +235,17 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator): ) if self.past_dynamic_cardinalities: - transforms.append( + transforms.extend([ VstackFeatures( output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", input_fields=list(self.past_dynamic_cardinalities.keys()), + ), + AsNumpyArray( + field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + expected_ndim=2, + dtype=np.long, ) - ) + ]) else: transforms.extend( [ @@ -301,7 +309,7 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator): ) def create_training_network( - self, device: torch.device + self, device: torch.device ) -> TemporalFusionTransformerTrainingNetwork: network = TemporalFusionTransformerTrainingNetwork( context_length=self.context_length, @@ -317,8 +325,9 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator): c_past_feat_dynamic_cat=_default_feat_args( list(self.past_dynamic_cardinalities.values()) ), + # +1 is for Age Feature d_feat_dynamic_real=_default_feat_args( - [1] * len(self.time_features) + list(self.dynamic_feature_dims.values()) + [1] * (len(self.time_features) + 1) + list(self.dynamic_feature_dims.values()) ), c_feat_dynamic_cat=_default_feat_args( list(self.dynamic_cardinalities.values()) @@ -333,10 +342,10 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator): return network.to(device) def create_predictor( - self, - transformation: Transformation, - trained_network: TemporalFusionTransformerTrainingNetwork, - device: torch.device, + self, + transformation: Transformation, + trained_network: TemporalFusionTransformerTrainingNetwork, + device: torch.device, ) -> Predictor: prediction_network = TemporalFusionTransformerPredictionNetwork( @@ -353,8 +362,9 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator): c_past_feat_dynamic_cat=_default_feat_args( list(self.past_dynamic_cardinalities.values()) ), + # +1 is for Age Feature d_feat_dynamic_real=_default_feat_args( - [1] * len(self.time_features) + list(self.dynamic_feature_dims.values()) + [1] * (len(self.time_features) + 1) + list(self.dynamic_feature_dims.values()) ), c_feat_dynamic_cat=_default_feat_args( list(self.dynamic_cardinalities.values()) diff --git a/pts/model/tft/tft_modules.py b/pts/model/tft/tft_modules.py index abec6d7..9caf759 100644 --- a/pts/model/tft/tft_modules.py +++ b/pts/model/tft/tft_modules.py @@ -18,11 +18,12 @@ class FeatureProjector(nn.Module): self.__num_features = len(feature_dims) if self.__num_features > 1: - self.feature_dims = ( + self.feature_slices = ( feature_dims[0:1] + np.cumsum(feature_dims)[:-1].tolist() ) else: - self.feature_dims = feature_dims + self.feature_slices = feature_dims + self.feature_dims = feature_dims self._projector = nn.ModuleList( [ @@ -34,7 +35,7 @@ class FeatureProjector(nn.Module): def forward(self, features: torch.Tensor) -> List[torch.Tensor]: if self.__num_features > 1: real_feature_slices = torch.tensor_split( - features, self.feature_dims[1:], dim=-1 + features, self.feature_slices[1:], dim=-1 ) else: real_feature_slices = [features]