TFT fixes (#55)

* fixes for running TFT

* fix test set configs

* added AgeFeature back to the pipeline

* fixing AgeFeature

* revert setup changes

Co-authored-by: alex sliz-nagy <alex.sliz.nagy@blackswan.com>
This commit is contained in:
aslinagy
2021-06-08 13:23:22 +02:00
committed by GitHub
parent 7d82355470
commit 5039dcf6ea
3 changed files with 90 additions and 79 deletions
+1 -1
View File
@@ -9,4 +9,4 @@ __path__ = extend_path(__path__, __name__) # type: ignore
try:
__version__ = get_distribution(__name__).version
except DistributionNotFound:
__version__ = "0.0.0-unknown"
__version__ = "0.0.0-unknown"
+85 -75
View File
@@ -1,22 +1,18 @@
from typing import List, Optional, Dict
from itertools import chain
from typing import List, Optional, Dict
import numpy as np
import torch
import torch.nn as nn
from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.model.forecast_generator import QuantileForecastGenerator
from gluonts.model.predictor import Predictor
from gluonts.time_feature import (
TimeFeature,
get_lags_for_frequency,
time_features_from_frequency_str,
)
from gluonts.torch.modules.distribution_output import DistributionOutput
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.model.predictor import Predictor
from gluonts.model.forecast_generator import QuantileForecastGenerator
from gluonts.torch.support.util import copy_parameters
from gluonts.transform import (
Transformation,
Chain,
@@ -32,9 +28,8 @@ from gluonts.transform import (
)
from pts import Trainer
from pts.model.utils import get_module_forward_input_names
from pts.model import PyTorchEstimator
from pts.model.utils import get_module_forward_input_names
from .tft_network import (
TemporalFusionTransformerPredictionNetwork,
TemporalFusionTransformerTrainingNetwork,
@@ -51,22 +46,22 @@ def _default_feat_args(dims_or_cardinalities: List[int]):
class TemporalFusionTransformerEstimator(PyTorchEstimator):
@validated()
def __init__(
self,
freq: str,
prediction_length: int,
context_length: Optional[int] = None,
dropout_rate: float = 0.1,
embed_dim: int = 32,
num_heads: int = 4,
num_outputs: int = 3,
variable_dim: Optional[int] = None,
time_features: List[TimeFeature] = [],
static_cardinalities: Dict[str, int] = {},
dynamic_cardinalities: Dict[str, int] = {},
static_feature_dims: Dict[str, int] = {},
dynamic_feature_dims: Dict[str, int] = {},
past_dynamic_features: List[str] = [],
trainer: Trainer = Trainer(),
self,
freq: str,
prediction_length: int,
context_length: Optional[int] = None,
dropout_rate: float = 0.1,
embed_dim: int = 32,
num_heads: int = 4,
num_outputs: int = 3,
variable_dim: Optional[int] = None,
time_features: List[TimeFeature] = [],
static_cardinalities: Dict[str, int] = {},
dynamic_cardinalities: Dict[str, int] = {},
static_feature_dims: Dict[str, int] = {},
dynamic_feature_dims: Dict[str, int] = {},
past_dynamic_features: List[str] = [],
trainer: Trainer = Trainer(),
) -> None:
super().__init__(trainer=trainer)
@@ -116,53 +111,56 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
def create_transformation(self) -> Transformation:
transforms = (
[AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)]
+ (
[
AsNumpyArray(field=name, expected_ndim=1)
for name in self.static_cardinalities.keys()
]
)
+ [
AsNumpyArray(field=name, expected_ndim=1)
for name in chain(
self.static_feature_dims.keys(),
self.dynamic_cardinalities.keys(),
[AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)]
+ (
[
AsNumpyArray(field=name, expected_ndim=1)
for name in self.static_cardinalities.keys()
]
)
]
+ [
AsNumpyArray(field=name, expected_ndim=2)
for name in self.dynamic_feature_dims.keys()
]
+ [
AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
),
AddTimeFeatures(
start_field=FieldName.START,
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_TIME,
time_features=self.time_features,
pred_length=self.prediction_length,
),
AddAgeFeature(
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_AGE,
pred_length=self.prediction_length,
log_scale=True,
),
]
+ [
AsNumpyArray(field=name, expected_ndim=1)
for name in chain(
self.static_feature_dims.keys(),
self.dynamic_cardinalities.keys(),
)
]
+ [
AsNumpyArray(field=name, expected_ndim=2)
for name in self.dynamic_feature_dims.keys()
]
+ [
AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
),
AddTimeFeatures(
start_field=FieldName.START,
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_TIME,
time_features=self.time_features,
pred_length=self.prediction_length,
),
AddAgeFeature(
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_AGE,
pred_length=self.prediction_length,
log_scale=True,
),
]
)
if self.static_cardinalities:
transforms.append(
transforms.extend([
VstackFeatures(
output_field=FieldName.FEAT_STATIC_CAT,
input_fields=list(self.static_cardinalities.keys()),
h_stack=True,
),
AsNumpyArray(
field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=np.long
)
)
])
else:
transforms.extend(
[
@@ -196,12 +194,17 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
)
if self.dynamic_cardinalities:
transforms.append(
transforms.extend([
VstackFeatures(
output_field=FieldName.FEAT_DYNAMIC_CAT,
input_fields=list(self.dynamic_cardinalities.keys()),
),
AsNumpyArray(
field=FieldName.FEAT_DYNAMIC_CAT,
expected_ndim=2,
dtype=np.long,
)
)
])
else:
transforms.extend(
[
@@ -232,12 +235,17 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
)
if self.past_dynamic_cardinalities:
transforms.append(
transforms.extend([
VstackFeatures(
output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat",
input_fields=list(self.past_dynamic_cardinalities.keys()),
),
AsNumpyArray(
field=FieldName.PAST_FEAT_DYNAMIC + "_cat",
expected_ndim=2,
dtype=np.long,
)
)
])
else:
transforms.extend(
[
@@ -301,7 +309,7 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
)
def create_training_network(
self, device: torch.device
self, device: torch.device
) -> TemporalFusionTransformerTrainingNetwork:
network = TemporalFusionTransformerTrainingNetwork(
context_length=self.context_length,
@@ -317,8 +325,9 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
c_past_feat_dynamic_cat=_default_feat_args(
list(self.past_dynamic_cardinalities.values())
),
# +1 is for Age Feature
d_feat_dynamic_real=_default_feat_args(
[1] * len(self.time_features) + list(self.dynamic_feature_dims.values())
[1] * (len(self.time_features) + 1) + list(self.dynamic_feature_dims.values())
),
c_feat_dynamic_cat=_default_feat_args(
list(self.dynamic_cardinalities.values())
@@ -333,10 +342,10 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
return network.to(device)
def create_predictor(
self,
transformation: Transformation,
trained_network: TemporalFusionTransformerTrainingNetwork,
device: torch.device,
self,
transformation: Transformation,
trained_network: TemporalFusionTransformerTrainingNetwork,
device: torch.device,
) -> Predictor:
prediction_network = TemporalFusionTransformerPredictionNetwork(
@@ -353,8 +362,9 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
c_past_feat_dynamic_cat=_default_feat_args(
list(self.past_dynamic_cardinalities.values())
),
# +1 is for Age Feature
d_feat_dynamic_real=_default_feat_args(
[1] * len(self.time_features) + list(self.dynamic_feature_dims.values())
[1] * (len(self.time_features) + 1) + list(self.dynamic_feature_dims.values())
),
c_feat_dynamic_cat=_default_feat_args(
list(self.dynamic_cardinalities.values())
+4 -3
View File
@@ -18,11 +18,12 @@ class FeatureProjector(nn.Module):
self.__num_features = len(feature_dims)
if self.__num_features > 1:
self.feature_dims = (
self.feature_slices = (
feature_dims[0:1] + np.cumsum(feature_dims)[:-1].tolist()
)
else:
self.feature_dims = feature_dims
self.feature_slices = feature_dims
self.feature_dims = feature_dims
self._projector = nn.ModuleList(
[
@@ -34,7 +35,7 @@ class FeatureProjector(nn.Module):
def forward(self, features: torch.Tensor) -> List[torch.Tensor]:
if self.__num_features > 1:
real_feature_slices = torch.tensor_split(
features, self.feature_dims[1:], dim=-1
features, self.feature_slices[1:], dim=-1
)
else:
real_feature_slices = [features]