Fixed TFT transform

This commit is contained in:
Kashif Rasul
2021-07-07 11:22:20 +02:00
parent 1a94965d59
commit cc7dea9f2f
3 changed files with 28 additions and 19 deletions
+7 -2
View File
@@ -8,11 +8,13 @@ import torch.nn as nn
from torch.utils import data
from torch.utils.data import DataLoader
from gluonts.env import env
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.estimator import Estimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.transform import SelectFields, Transformation
from gluonts.support.util import maybe_len
from pts import Trainer
from pts.model import get_module_forward_input_names
@@ -101,7 +103,9 @@ class PyTorchEstimator(Estimator):
trained_net = self.create_training_network(self.trainer.device)
input_names = get_module_forward_input_names(trained_net)
training_instance_splitter = self.create_instance_splitter("training")
with env._let(max_idle_transforms=maybe_len(training_data) or 0):
training_instance_splitter = self.create_instance_splitter("training")
training_iter_dataset = TransformedIterableDataset(
dataset=training_data,
transform=transformation
@@ -124,7 +128,8 @@ class PyTorchEstimator(Estimator):
validation_data_loader = None
if validation_data is not None:
validation_instance_splitter = self.create_instance_splitter("validation")
with env._let(max_idle_transforms=maybe_len(validation_data) or 0):
validation_instance_splitter = self.create_instance_splitter("validation")
validation_iter_dataset = TransformedIterableDataset(
dataset=validation_data,
transform=transformation
+2
View File
@@ -3,6 +3,7 @@ from typing import List, Optional, Dict
import numpy as np
import torch
from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.model.forecast_generator import QuantileForecastGenerator
@@ -30,6 +31,7 @@ from gluonts.transform import (
from pts import Trainer
from pts.model import PyTorchEstimator
from pts.model.utils import get_module_forward_input_names
from .tft_network import (
TemporalFusionTransformerPredictionNetwork,
TemporalFusionTransformerTrainingNetwork,
+19 -17
View File
@@ -25,6 +25,7 @@ from gluonts.transform import (
shift_timestamp,
target_transformation_length,
)
from gluonts.transform.sampler import InstanceSampler
class BroadcastTo(MapTransformation):
@@ -54,7 +55,7 @@ class TFTInstanceSplitter(InstanceSplitter):
@validated()
def __init__(
self,
instance_sampler,
instance_sampler: InstanceSampler,
past_length: int,
future_length: int,
target_field: str = FieldName.TARGET,
@@ -64,29 +65,30 @@ class TFTInstanceSplitter(InstanceSplitter):
observed_value_field: str = FieldName.OBSERVED_VALUES,
lead_time: int = 0,
output_NTC: bool = True,
time_series_fields: Optional[List[str]] = None,
past_time_series_fields: Optional[List[str]] = None,
time_series_fields: List[str] = [],
past_time_series_fields: List[str] = [],
dummy_value: float = 0.0,
) -> None:
super().__init__(
target_field=target_field,
is_pad_field=is_pad_field,
start_field=start_field,
forecast_start_field=forecast_start_field,
instance_sampler=instance_sampler,
past_length=past_length,
future_length=future_length,
lead_time=lead_time,
output_NTC=output_NTC,
time_series_fields=time_series_fields,
dummy_value=dummy_value,
)
assert past_length > 0, "The value of `past_length` should be > 0"
assert future_length > 0, "The value of `future_length` should be > 0"
self.instance_sampler = instance_sampler
self.past_length = past_length
self.future_length = future_length
self.lead_time = lead_time
self.output_NTC = output_NTC
self.dummy_value = dummy_value
self.target_field = target_field
self.is_pad_field = is_pad_field
self.start_field = start_field
self.forecast_start_field = forecast_start_field
self.observed_value_field = observed_value_field
self.ts_fields = time_series_fields or []
self.past_ts_fields = past_time_series_fields or []
self.past_ts_fields = past_time_series_fields
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
pl = self.future_length