add estimator args

This commit is contained in:
Kashif Rasul
2022-04-04 22:21:04 +02:00
parent bd0165f720
commit 68d33b7fd3
2 changed files with 12 additions and 6 deletions
+10 -4
View File
@@ -54,13 +54,15 @@ class AutoformerEstimator(PyTorchLightningEstimator):
freq: str,
prediction_length: int,
# Autoformer arguments
nhead: int,
n_heads: int,
num_encoder_layers: int,
num_decoder_layers: int,
dim_feedforward: int,
input_size: int = 1,
activation: str = "gelu",
dropout: float = 0.1,
factor: int = 1,
moving_avg: int = 25,
context_length: Optional[int] = None,
num_feat_dynamic_real: int = 0,
num_feat_static_cat: int = 0,
@@ -94,12 +96,14 @@ class AutoformerEstimator(PyTorchLightningEstimator):
self.loss = loss
self.input_size = input_size
self.nhead = nhead
self.n_heads = n_heads
self.num_encoder_layers = num_encoder_layers
self.num_decoder_layers = num_decoder_layers
self.activation = activation
self.dim_feedforward = dim_feedforward
self.dropout = dropout
self.factor = factor
self.moving_avg = moving_avg
self.num_feat_dynamic_real = num_feat_dynamic_real
self.num_feat_static_cat = num_feat_static_cat
@@ -291,13 +295,15 @@ class AutoformerEstimator(PyTorchLightningEstimator):
num_feat_static_cat=max(1, self.num_feat_static_cat),
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
# transformer arguments
nhead=self.nhead,
# autoformer arguments
n_heads=self.n_heads,
num_encoder_layers=self.num_encoder_layers,
num_decoder_layers=self.num_decoder_layers,
activation=self.activation,
dropout=self.dropout,
dim_feedforward=self.dim_feedforward,
factor=self.factor,
moving_avg=self.moving_avg,
# univariate input
input_size=self.input_size,
distr_output=self.distr_output,
+2 -2
View File
@@ -57,7 +57,7 @@ class AutoformerLightningModule(pl.LightningModule):
past_observed_values = batch["past_observed_values"]
future_observed_values = batch["future_observed_values"]
transformer_inputs, scale, _ = self.model.create_network_inputs(
autoformer_inputs, scale, _ = self.model.create_network_inputs(
feat_static_cat,
feat_static_real,
past_time_feat,
@@ -66,7 +66,7 @@ class AutoformerLightningModule(pl.LightningModule):
future_time_feat,
future_target,
)
params = self.model.output_params(transformer_inputs)
params = self.model.output_params(autoformer_inputs)
distr = self.model.output_distribution(params, scale)
loss_values = self.loss(distr, future_target)