mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:46:28 +08:00
Gluonts dataloader
This commit is contained in:
+14
-14
@@ -51,13 +51,13 @@ class TransformerEstimator(PyTorchLightningEstimator):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
freq: str,
|
||||
prediction_length: int,
|
||||
# Transformer arguments
|
||||
nhead: int,
|
||||
num_encoder_layers: int,
|
||||
num_decoder_layers: int,
|
||||
dim_feedforward: int,
|
||||
freq: Optional[str] = None,
|
||||
input_size: int = 1,
|
||||
activation: str = "gelu",
|
||||
dropout: float = 0.1,
|
||||
@@ -165,19 +165,19 @@ class TransformerEstimator(PyTorchLightningEstimator):
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.OBSERVED_VALUES,
|
||||
),
|
||||
AddTimeFeatures(
|
||||
start_field=FieldName.START,
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.FEAT_TIME,
|
||||
time_features=self.time_features,
|
||||
pred_length=self.prediction_length,
|
||||
),
|
||||
AddAgeFeature(
|
||||
target_field=FieldName.TARGET,
|
||||
output_field=FieldName.FEAT_AGE,
|
||||
pred_length=self.prediction_length,
|
||||
log_scale=True,
|
||||
),
|
||||
# AddTimeFeatures(
|
||||
# start_field=FieldName.START,
|
||||
# target_field=FieldName.TARGET,
|
||||
# output_field=FieldName.FEAT_TIME,
|
||||
# time_features=self.time_features,
|
||||
# pred_length=self.prediction_length,
|
||||
# ),
|
||||
# AddAgeFeature(
|
||||
# target_field=FieldName.TARGET,
|
||||
# output_field=FieldName.FEAT_AGE,
|
||||
# pred_length=self.prediction_length,
|
||||
# log_scale=True,
|
||||
# ),
|
||||
VstackFeatures(
|
||||
output_field=FieldName.FEAT_TIME,
|
||||
input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
|
||||
|
||||
@@ -13,7 +13,6 @@ class TransformerModel(nn.Module):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
freq: str,
|
||||
context_length: int,
|
||||
prediction_length: int,
|
||||
num_feat_dynamic_real: int,
|
||||
@@ -32,6 +31,7 @@ class TransformerModel(nn.Module):
|
||||
embedding_dimension: Optional[List[int]] = None,
|
||||
distr_output: DistributionOutput = StudentTOutput(),
|
||||
lags_seq: Optional[List[int]] = None,
|
||||
freq: Optional[str] = None,
|
||||
scaling: bool = True,
|
||||
num_parallel_samples: int = 100,
|
||||
) -> None:
|
||||
@@ -78,6 +78,7 @@ class TransformerModel(nn.Module):
|
||||
dropout=dropout,
|
||||
activation=activation,
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
|
||||
# causal decoder tgt mask
|
||||
|
||||
+635
-168
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user