formatting

This commit is contained in:
Dr. Kashif Rasul
2020-01-28 14:01:48 +01:00
parent 7bede84f45
commit ad482cb526
+17 -8
View File
@@ -34,7 +34,10 @@ from pts.feature import (
get_lags_for_frequency,
)
from .transformer_network import TransformerTrainingNetwork, TransformerPredictionNetwork
from .transformer_network import (
TransformerTrainingNetwork,
TransformerPredictionNetwork,
)
class TransformerEstimator(PTSEstimator):
@@ -80,8 +83,7 @@ class TransformerEstimator(PTSEstimator):
self.embedding_dimension = embedding_dimension
self.num_parallel_samples = num_parallel_samples
self.lags_seq = (
lags_seq if lags_seq is not None else get_lags_for_frequency(
freq_str=freq)
lags_seq if lags_seq is not None else get_lags_for_frequency(freq_str=freq)
)
self.time_features = (
time_features
@@ -112,14 +114,16 @@ class TransformerEstimator(PTSEstimator):
[SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
if not self.use_feat_static_cat
else []
) + (
)
+ (
[SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
if not self.use_feat_static_real
else []
)
+ [
AsNumpyArray(field=FieldName.FEAT_STATIC_CAT,
expected_ndim=1, dtype=np.long),
AsNumpyArray(
field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=np.long
),
AsNumpyArray(
field=FieldName.FEAT_STATIC_REAL, expected_ndim=1, dtype=self.dtype,
),
@@ -170,7 +174,9 @@ class TransformerEstimator(PTSEstimator):
]
)
def create_training_network(self, device: torch.device) -> TransformerTrainingNetwork:
def create_training_network(
self, device: torch.device
) -> TransformerTrainingNetwork:
training_network = TransformerTrainingNetwork(
input_size=self.input_size,
@@ -194,7 +200,10 @@ class TransformerEstimator(PTSEstimator):
return training_network
def create_predictor(
self, transformation: Transformation, trained_network: nn.Module, device: torch.device,
self,
transformation: Transformation,
trained_network: nn.Module,
device: torch.device,
) -> Predictor:
prediction_network = TransformerPredictionNetwork(