mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
add estimator args
This commit is contained in:
+10
-4
@@ -54,13 +54,15 @@ class AutoformerEstimator(PyTorchLightningEstimator):
|
||||
freq: str,
|
||||
prediction_length: int,
|
||||
# Autoformer arguments
|
||||
nhead: int,
|
||||
n_heads: int,
|
||||
num_encoder_layers: int,
|
||||
num_decoder_layers: int,
|
||||
dim_feedforward: int,
|
||||
input_size: int = 1,
|
||||
activation: str = "gelu",
|
||||
dropout: float = 0.1,
|
||||
factor: int = 1,
|
||||
moving_avg: int = 25,
|
||||
context_length: Optional[int] = None,
|
||||
num_feat_dynamic_real: int = 0,
|
||||
num_feat_static_cat: int = 0,
|
||||
@@ -94,12 +96,14 @@ class AutoformerEstimator(PyTorchLightningEstimator):
|
||||
self.loss = loss
|
||||
|
||||
self.input_size = input_size
|
||||
self.nhead = nhead
|
||||
self.n_heads = n_heads
|
||||
self.num_encoder_layers = num_encoder_layers
|
||||
self.num_decoder_layers = num_decoder_layers
|
||||
self.activation = activation
|
||||
self.dim_feedforward = dim_feedforward
|
||||
self.dropout = dropout
|
||||
self.factor = factor
|
||||
self.moving_avg = moving_avg
|
||||
|
||||
self.num_feat_dynamic_real = num_feat_dynamic_real
|
||||
self.num_feat_static_cat = num_feat_static_cat
|
||||
@@ -291,13 +295,15 @@ class AutoformerEstimator(PyTorchLightningEstimator):
|
||||
num_feat_static_cat=max(1, self.num_feat_static_cat),
|
||||
cardinality=self.cardinality,
|
||||
embedding_dimension=self.embedding_dimension,
|
||||
# transformer arguments
|
||||
nhead=self.nhead,
|
||||
# autoformer arguments
|
||||
n_heads=self.n_heads,
|
||||
num_encoder_layers=self.num_encoder_layers,
|
||||
num_decoder_layers=self.num_decoder_layers,
|
||||
activation=self.activation,
|
||||
dropout=self.dropout,
|
||||
dim_feedforward=self.dim_feedforward,
|
||||
factor=self.factor,
|
||||
moving_avg=self.moving_avg,
|
||||
# univariate input
|
||||
input_size=self.input_size,
|
||||
distr_output=self.distr_output,
|
||||
|
||||
@@ -57,7 +57,7 @@ class AutoformerLightningModule(pl.LightningModule):
|
||||
past_observed_values = batch["past_observed_values"]
|
||||
future_observed_values = batch["future_observed_values"]
|
||||
|
||||
transformer_inputs, scale, _ = self.model.create_network_inputs(
|
||||
autoformer_inputs, scale, _ = self.model.create_network_inputs(
|
||||
feat_static_cat,
|
||||
feat_static_real,
|
||||
past_time_feat,
|
||||
@@ -66,7 +66,7 @@ class AutoformerLightningModule(pl.LightningModule):
|
||||
future_time_feat,
|
||||
future_target,
|
||||
)
|
||||
params = self.model.output_params(transformer_inputs)
|
||||
params = self.model.output_params(autoformer_inputs)
|
||||
distr = self.model.output_distribution(params, scale)
|
||||
|
||||
loss_values = self.loss(distr, future_target)
|
||||
|
||||
Reference in New Issue
Block a user