removed embedding_dimension

This commit is contained in:
Kashif Rasul
2022-03-30 16:51:44 +02:00
parent 7e25ade272
commit b805896fcf
-3
View File
@@ -64,7 +64,6 @@ class TFTEstimator(PyTorchLightningEstimator):
num_feat_static_cat: int = 0,
num_feat_static_real: int = 0,
cardinality: Optional[List[int]] = None,
embedding_dimension: Optional[List[int]] = None,
distr_output: DistributionOutput = StudentTOutput(),
loss: DistributionLoss = NegativeLogLikelihood(),
scaling: bool = True,
@@ -104,7 +103,6 @@ class TFTEstimator(PyTorchLightningEstimator):
self.cardinality = (
cardinality if cardinality and num_feat_static_cat > 0 else [1]
)
self.embedding_dimension = embedding_dimension
self.scaling = scaling
self.lags_seq = lags_seq
self.time_features = (
@@ -287,7 +285,6 @@ class TFTEstimator(PyTorchLightningEstimator):
num_feat_static_real=max(1, self.num_feat_static_real),
num_feat_static_cat=max(1, self.num_feat_static_cat),
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
# transformer arguments
num_heads=self.num_heads,
dropout=self.dropout,