mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
removed embedding_dimension
This commit is contained in:
@@ -64,7 +64,6 @@ class TFTEstimator(PyTorchLightningEstimator):
|
||||
num_feat_static_cat: int = 0,
|
||||
num_feat_static_real: int = 0,
|
||||
cardinality: Optional[List[int]] = None,
|
||||
embedding_dimension: Optional[List[int]] = None,
|
||||
distr_output: DistributionOutput = StudentTOutput(),
|
||||
loss: DistributionLoss = NegativeLogLikelihood(),
|
||||
scaling: bool = True,
|
||||
@@ -104,7 +103,6 @@ class TFTEstimator(PyTorchLightningEstimator):
|
||||
self.cardinality = (
|
||||
cardinality if cardinality and num_feat_static_cat > 0 else [1]
|
||||
)
|
||||
self.embedding_dimension = embedding_dimension
|
||||
self.scaling = scaling
|
||||
self.lags_seq = lags_seq
|
||||
self.time_features = (
|
||||
@@ -287,7 +285,6 @@ class TFTEstimator(PyTorchLightningEstimator):
|
||||
num_feat_static_real=max(1, self.num_feat_static_real),
|
||||
num_feat_static_cat=max(1, self.num_feat_static_cat),
|
||||
cardinality=self.cardinality,
|
||||
embedding_dimension=self.embedding_dimension,
|
||||
# transformer arguments
|
||||
num_heads=self.num_heads,
|
||||
dropout=self.dropout,
|
||||
|
||||
Reference in New Issue
Block a user