mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 18:06:19 +08:00
TFT fixes (#55)
* fixies for running tft * fix test set configs * added AgeFeature back to the pipeline * fixing AgeFeature * revert setups changes Co-authored-by: alex sliz-nagy <alex.sliz.nagy@blackswan.com>
This commit is contained in:
+1
-1
@@ -9,4 +9,4 @@ __path__ = extend_path(__path__, __name__) # type: ignore
|
||||
try:
|
||||
__version__ = get_distribution(__name__).version
|
||||
except DistributionNotFound:
|
||||
__version__ = "0.0.0-unknown"
|
||||
__version__ = "0.0.0-unknown"
|
||||
@@ -1,22 +1,18 @@
|
||||
from typing import List, Optional, Dict
|
||||
from itertools import chain
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.dataset.field_names import FieldName
|
||||
from gluonts.model.forecast_generator import QuantileForecastGenerator
|
||||
from gluonts.model.predictor import Predictor
|
||||
from gluonts.time_feature import (
|
||||
TimeFeature,
|
||||
get_lags_for_frequency,
|
||||
time_features_from_frequency_str,
|
||||
)
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput
|
||||
from gluonts.torch.support.util import copy_parameters
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.model.predictor import Predictor
|
||||
from gluonts.model.forecast_generator import QuantileForecastGenerator
|
||||
from gluonts.torch.support.util import copy_parameters
|
||||
from gluonts.transform import (
|
||||
Transformation,
|
||||
Chain,
|
||||
@@ -32,9 +28,8 @@ from gluonts.transform import (
|
||||
)
|
||||
|
||||
from pts import Trainer
|
||||
from pts.model.utils import get_module_forward_input_names
|
||||
from pts.model import PyTorchEstimator
|
||||
|
||||
from pts.model.utils import get_module_forward_input_names
|
||||
from .tft_network import (
|
||||
TemporalFusionTransformerPredictionNetwork,
|
||||
TemporalFusionTransformerTrainingNetwork,
|
||||
@@ -51,22 +46,22 @@ def _default_feat_args(dims_or_cardinalities: List[int]):
|
||||
class TemporalFusionTransformerEstimator(PyTorchEstimator):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
freq: str,
|
||||
prediction_length: int,
|
||||
context_length: Optional[int] = None,
|
||||
dropout_rate: float = 0.1,
|
||||
embed_dim: int = 32,
|
||||
num_heads: int = 4,
|
||||
num_outputs: int = 3,
|
||||
variable_dim: Optional[int] = None,
|
||||
time_features: List[TimeFeature] = [],
|
||||
static_cardinalities: Dict[str, int] = {},
|
||||
dynamic_cardinalities: Dict[str, int] = {},
|
||||
static_feature_dims: Dict[str, int] = {},
|
||||
dynamic_feature_dims: Dict[str, int] = {},
|
||||
past_dynamic_features: List[str] = [],
|
||||
trainer: Trainer = Trainer(),
|
||||
self,
|
||||
freq: str,
|
||||
prediction_length: int,
|
||||
context_length: Optional[int] = None,
|
||||
dropout_rate: float = 0.1,
|
||||
embed_dim: int = 32,
|
||||
num_heads: int = 4,
|
||||
num_outputs: int = 3,
|
||||
variable_dim: Optional[int] = None,
|
||||
time_features: List[TimeFeature] = [],
|
||||
static_cardinalities: Dict[str, int] = {},
|
||||
dynamic_cardinalities: Dict[str, int] = {},
|
||||
static_feature_dims: Dict[str, int] = {},
|
||||
dynamic_feature_dims: Dict[str, int] = {},
|
||||
past_dynamic_features: List[str] = [],
|
||||
trainer: Trainer = Trainer(),
|
||||
) -> None:
|
||||
super().__init__(trainer=trainer)
|
||||
|
||||
@@ -116,53 +111,56 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
|
||||
|
||||
def create_transformation(self) -> Transformation:
|
||||
transforms = (
|
||||
[AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)]
|
||||
+ (
|
||||
[
|
||||
AsNumpyArray(field=name, expected_ndim=1)
|
||||
for name in self.static_cardinalities.keys()
|
||||
]
|
||||
)
|
||||
+ [
|
||||
AsNumpyArray(field=name, expected_ndim=1)
|
||||
for name in chain(
|
||||
self.static_feature_dims.keys(),
|
||||
self.dynamic_cardinalities.keys(),
|
||||
[AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)]
|
||||
+ (
|
||||
[
|
||||
AsNumpyArray(field=name, expected_ndim=1)
|
||||
for name in self.static_cardinalities.keys()
|
||||
]
|
||||
)
|
||||
]
|
||||
+ [
|
||||
AsNumpyArray(field=name, expected_ndim=2)
|
||||
for name in self.dynamic_feature_dims.keys()
|
||||
]
|
||||
+ [
|
||||
AddObservedValuesIndicator(
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.OBSERVED_VALUES,
|
||||
),
|
||||
AddTimeFeatures(
|
||||
start_field=FieldName.START,
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.FEAT_TIME,
|
||||
time_features=self.time_features,
|
||||
pred_length=self.prediction_length,
|
||||
),
|
||||
AddAgeFeature(
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.FEAT_AGE,
|
||||
pred_length=self.prediction_length,
|
||||
log_scale=True,
|
||||
),
|
||||
]
|
||||
+ [
|
||||
AsNumpyArray(field=name, expected_ndim=1)
|
||||
for name in chain(
|
||||
self.static_feature_dims.keys(),
|
||||
self.dynamic_cardinalities.keys(),
|
||||
)
|
||||
]
|
||||
+ [
|
||||
AsNumpyArray(field=name, expected_ndim=2)
|
||||
for name in self.dynamic_feature_dims.keys()
|
||||
]
|
||||
+ [
|
||||
AddObservedValuesIndicator(
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.OBSERVED_VALUES,
|
||||
),
|
||||
AddTimeFeatures(
|
||||
start_field=FieldName.START,
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.FEAT_TIME,
|
||||
time_features=self.time_features,
|
||||
pred_length=self.prediction_length,
|
||||
),
|
||||
AddAgeFeature(
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.FEAT_AGE,
|
||||
pred_length=self.prediction_length,
|
||||
log_scale=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if self.static_cardinalities:
|
||||
transforms.append(
|
||||
transforms.extend([
|
||||
VstackFeatures(
|
||||
output_field=FieldName.FEAT_STATIC_CAT,
|
||||
input_fields=list(self.static_cardinalities.keys()),
|
||||
h_stack=True,
|
||||
),
|
||||
AsNumpyArray(
|
||||
field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=np.long
|
||||
)
|
||||
)
|
||||
])
|
||||
else:
|
||||
transforms.extend(
|
||||
[
|
||||
@@ -196,12 +194,17 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
|
||||
)
|
||||
|
||||
if self.dynamic_cardinalities:
|
||||
transforms.append(
|
||||
transforms.extend([
|
||||
VstackFeatures(
|
||||
output_field=FieldName.FEAT_DYNAMIC_CAT,
|
||||
input_fields=list(self.dynamic_cardinalities.keys()),
|
||||
),
|
||||
AsNumpyArray(
|
||||
field=FieldName.FEAT_DYNAMIC_CAT,
|
||||
expected_ndim=2,
|
||||
dtype=np.long,
|
||||
)
|
||||
)
|
||||
])
|
||||
else:
|
||||
transforms.extend(
|
||||
[
|
||||
@@ -232,12 +235,17 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
|
||||
)
|
||||
|
||||
if self.past_dynamic_cardinalities:
|
||||
transforms.append(
|
||||
transforms.extend([
|
||||
VstackFeatures(
|
||||
output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat",
|
||||
input_fields=list(self.past_dynamic_cardinalities.keys()),
|
||||
),
|
||||
AsNumpyArray(
|
||||
field=FieldName.PAST_FEAT_DYNAMIC + "_cat",
|
||||
expected_ndim=2,
|
||||
dtype=np.long,
|
||||
)
|
||||
)
|
||||
])
|
||||
else:
|
||||
transforms.extend(
|
||||
[
|
||||
@@ -301,7 +309,7 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
|
||||
)
|
||||
|
||||
def create_training_network(
|
||||
self, device: torch.device
|
||||
self, device: torch.device
|
||||
) -> TemporalFusionTransformerTrainingNetwork:
|
||||
network = TemporalFusionTransformerTrainingNetwork(
|
||||
context_length=self.context_length,
|
||||
@@ -317,8 +325,9 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
|
||||
c_past_feat_dynamic_cat=_default_feat_args(
|
||||
list(self.past_dynamic_cardinalities.values())
|
||||
),
|
||||
# +1 is for Age Feature
|
||||
d_feat_dynamic_real=_default_feat_args(
|
||||
[1] * len(self.time_features) + list(self.dynamic_feature_dims.values())
|
||||
[1] * (len(self.time_features) + 1) + list(self.dynamic_feature_dims.values())
|
||||
),
|
||||
c_feat_dynamic_cat=_default_feat_args(
|
||||
list(self.dynamic_cardinalities.values())
|
||||
@@ -333,10 +342,10 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
|
||||
return network.to(device)
|
||||
|
||||
def create_predictor(
|
||||
self,
|
||||
transformation: Transformation,
|
||||
trained_network: TemporalFusionTransformerTrainingNetwork,
|
||||
device: torch.device,
|
||||
self,
|
||||
transformation: Transformation,
|
||||
trained_network: TemporalFusionTransformerTrainingNetwork,
|
||||
device: torch.device,
|
||||
) -> Predictor:
|
||||
|
||||
prediction_network = TemporalFusionTransformerPredictionNetwork(
|
||||
@@ -353,8 +362,9 @@ class TemporalFusionTransformerEstimator(PyTorchEstimator):
|
||||
c_past_feat_dynamic_cat=_default_feat_args(
|
||||
list(self.past_dynamic_cardinalities.values())
|
||||
),
|
||||
# +1 is for Age Feature
|
||||
d_feat_dynamic_real=_default_feat_args(
|
||||
[1] * len(self.time_features) + list(self.dynamic_feature_dims.values())
|
||||
[1] * (len(self.time_features) + 1) + list(self.dynamic_feature_dims.values())
|
||||
),
|
||||
c_feat_dynamic_cat=_default_feat_args(
|
||||
list(self.dynamic_cardinalities.values())
|
||||
|
||||
@@ -18,11 +18,12 @@ class FeatureProjector(nn.Module):
|
||||
|
||||
self.__num_features = len(feature_dims)
|
||||
if self.__num_features > 1:
|
||||
self.feature_dims = (
|
||||
self.feature_slices = (
|
||||
feature_dims[0:1] + np.cumsum(feature_dims)[:-1].tolist()
|
||||
)
|
||||
else:
|
||||
self.feature_dims = feature_dims
|
||||
self.feature_slices = feature_dims
|
||||
self.feature_dims = feature_dims
|
||||
|
||||
self._projector = nn.ModuleList(
|
||||
[
|
||||
@@ -34,7 +35,7 @@ class FeatureProjector(nn.Module):
|
||||
def forward(self, features: torch.Tensor) -> List[torch.Tensor]:
|
||||
if self.__num_features > 1:
|
||||
real_feature_slices = torch.tensor_split(
|
||||
features, self.feature_dims[1:], dim=-1
|
||||
features, self.feature_slices[1:], dim=-1
|
||||
)
|
||||
else:
|
||||
real_feature_slices = [features]
|
||||
|
||||
Reference in New Issue
Block a user