diff --git a/pts/model/transformer/transformer_estimator.py b/pts/model/transformer/transformer_estimator.py index 2a86d36..91ef17f 100644 --- a/pts/model/transformer/transformer_estimator.py +++ b/pts/model/transformer/transformer_estimator.py @@ -34,7 +34,10 @@ from pts.feature import ( get_lags_for_frequency, ) -from .transformer_network import TransformerTrainingNetwork, TransformerPredictionNetwork +from .transformer_network import ( + TransformerTrainingNetwork, + TransformerPredictionNetwork, +) class TransformerEstimator(PTSEstimator): @@ -80,8 +83,7 @@ class TransformerEstimator(PTSEstimator): self.embedding_dimension = embedding_dimension self.num_parallel_samples = num_parallel_samples self.lags_seq = ( - lags_seq if lags_seq is not None else get_lags_for_frequency( - freq_str=freq) + lags_seq if lags_seq is not None else get_lags_for_frequency(freq_str=freq) ) self.time_features = ( time_features @@ -112,14 +114,16 @@ class TransformerEstimator(PTSEstimator): [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])] if not self.use_feat_static_cat else [] - ) + ( + ) + + ( [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])] if not self.use_feat_static_real else [] ) + [ - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, - expected_ndim=1, dtype=np.long), + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=np.long + ), AsNumpyArray( field=FieldName.FEAT_STATIC_REAL, expected_ndim=1, dtype=self.dtype, ), @@ -170,7 +174,9 @@ class TransformerEstimator(PTSEstimator): ] ) - def create_training_network(self, device: torch.device) -> TransformerTrainingNetwork: + def create_training_network( + self, device: torch.device + ) -> TransformerTrainingNetwork: training_network = TransformerTrainingNetwork( input_size=self.input_size, @@ -194,7 +200,10 @@ class TransformerEstimator(PTSEstimator): return training_network def create_predictor( - self, transformation: Transformation, trained_network: nn.Module, device: torch.device, + self, + transformation: Transformation, + trained_network: nn.Module, + device: torch.device, ) -> Predictor: prediction_network = TransformerPredictionNetwork(