mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-07-05 03:20:55 +08:00
formatting
This commit is contained in:
@@ -34,7 +34,10 @@ from pts.feature import (
|
||||
get_lags_for_frequency,
|
||||
)
|
||||
|
||||
from .transformer_network import TransformerTrainingNetwork, TransformerPredictionNetwork
|
||||
from .transformer_network import (
|
||||
TransformerTrainingNetwork,
|
||||
TransformerPredictionNetwork,
|
||||
)
|
||||
|
||||
|
||||
class TransformerEstimator(PTSEstimator):
|
||||
@@ -80,8 +83,7 @@ class TransformerEstimator(PTSEstimator):
|
||||
self.embedding_dimension = embedding_dimension
|
||||
self.num_parallel_samples = num_parallel_samples
|
||||
self.lags_seq = (
|
||||
lags_seq if lags_seq is not None else get_lags_for_frequency(
|
||||
freq_str=freq)
|
||||
lags_seq if lags_seq is not None else get_lags_for_frequency(freq_str=freq)
|
||||
)
|
||||
self.time_features = (
|
||||
time_features
|
||||
@@ -112,14 +114,16 @@ class TransformerEstimator(PTSEstimator):
|
||||
[SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
|
||||
if not self.use_feat_static_cat
|
||||
else []
|
||||
) + (
|
||||
)
|
||||
+ (
|
||||
[SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
|
||||
if not self.use_feat_static_real
|
||||
else []
|
||||
)
|
||||
+ [
|
||||
AsNumpyArray(field=FieldName.FEAT_STATIC_CAT,
|
||||
expected_ndim=1, dtype=np.long),
|
||||
AsNumpyArray(
|
||||
field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=np.long
|
||||
),
|
||||
AsNumpyArray(
|
||||
field=FieldName.FEAT_STATIC_REAL, expected_ndim=1, dtype=self.dtype,
|
||||
),
|
||||
@@ -170,7 +174,9 @@ class TransformerEstimator(PTSEstimator):
|
||||
]
|
||||
)
|
||||
|
||||
def create_training_network(self, device: torch.device) -> TransformerTrainingNetwork:
|
||||
def create_training_network(
|
||||
self, device: torch.device
|
||||
) -> TransformerTrainingNetwork:
|
||||
|
||||
training_network = TransformerTrainingNetwork(
|
||||
input_size=self.input_size,
|
||||
@@ -194,7 +200,10 @@ class TransformerEstimator(PTSEstimator):
|
||||
return training_network
|
||||
|
||||
def create_predictor(
|
||||
self, transformation: Transformation, trained_network: nn.Module, device: torch.device,
|
||||
self,
|
||||
transformation: Transformation,
|
||||
trained_network: nn.Module,
|
||||
device: torch.device,
|
||||
) -> Predictor:
|
||||
|
||||
prediction_network = TransformerPredictionNetwork(
|
||||
|
||||
Reference in New Issue
Block a user