Gluonts dataloader

This commit is contained in:
Kashif Rasul
2022-08-14 01:04:11 +02:00
parent d8e8030994
commit 6f77e3839f
3 changed files with 651 additions and 183 deletions
+14 -14
View File
@@ -51,13 +51,13 @@ class TransformerEstimator(PyTorchLightningEstimator):
@validated()
def __init__(
self,
freq: str,
prediction_length: int,
# Transformer arguments
nhead: int,
num_encoder_layers: int,
num_decoder_layers: int,
dim_feedforward: int,
freq: Optional[str] = None,
input_size: int = 1,
activation: str = "gelu",
dropout: float = 0.1,
@@ -165,19 +165,19 @@ class TransformerEstimator(PyTorchLightningEstimator):
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
),
AddTimeFeatures(
start_field=FieldName.START,
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_TIME,
time_features=self.time_features,
pred_length=self.prediction_length,
),
AddAgeFeature(
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_AGE,
pred_length=self.prediction_length,
log_scale=True,
),
# AddTimeFeatures(
# start_field=FieldName.START,
# target_field=FieldName.TARGET,
# output_field=FieldName.FEAT_TIME,
# time_features=self.time_features,
# pred_length=self.prediction_length,
# ),
# AddAgeFeature(
# target_field=FieldName.TARGET,
# output_field=FieldName.FEAT_AGE,
# pred_length=self.prediction_length,
# log_scale=True,
# ),
VstackFeatures(
output_field=FieldName.FEAT_TIME,
input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
+2 -1
View File
@@ -13,7 +13,6 @@ class TransformerModel(nn.Module):
@validated()
def __init__(
self,
freq: str,
context_length: int,
prediction_length: int,
num_feat_dynamic_real: int,
@@ -32,6 +31,7 @@ class TransformerModel(nn.Module):
embedding_dimension: Optional[List[int]] = None,
distr_output: DistributionOutput = StudentTOutput(),
lags_seq: Optional[List[int]] = None,
freq: Optional[str] = None,
scaling: bool = True,
num_parallel_samples: int = 100,
) -> None:
@@ -78,6 +78,7 @@ class TransformerModel(nn.Module):
dropout=dropout,
activation=activation,
batch_first=True,
norm_first=True,
)
# causal decoder tgt mask
+635 -168
View File
File diff suppressed because one or more lines are too long