diff --git a/pts/model/estimator.py b/pts/model/estimator.py index 5044893..0359683 100644 --- a/pts/model/estimator.py +++ b/pts/model/estimator.py @@ -8,11 +8,13 @@ import torch.nn as nn from torch.utils import data from torch.utils.data import DataLoader +from gluonts.env import env from gluonts.core.component import validated from gluonts.dataset.common import Dataset from gluonts.model.estimator import Estimator from gluonts.torch.model.predictor import PyTorchPredictor from gluonts.transform import SelectFields, Transformation +from gluonts.support.util import maybe_len from pts import Trainer from pts.model import get_module_forward_input_names @@ -101,7 +103,9 @@ class PyTorchEstimator(Estimator): trained_net = self.create_training_network(self.trainer.device) input_names = get_module_forward_input_names(trained_net) - training_instance_splitter = self.create_instance_splitter("training") + + with env._let(max_idle_transforms=maybe_len(training_data) or 0): + training_instance_splitter = self.create_instance_splitter("training") training_iter_dataset = TransformedIterableDataset( dataset=training_data, transform=transformation @@ -124,7 +128,8 @@ class PyTorchEstimator(Estimator): validation_data_loader = None if validation_data is not None: - validation_instance_splitter = self.create_instance_splitter("validation") + with env._let(max_idle_transforms=maybe_len(validation_data) or 0): + validation_instance_splitter = self.create_instance_splitter("validation") validation_iter_dataset = TransformedIterableDataset( dataset=validation_data, transform=transformation diff --git a/pts/model/tft/tft_estimator.py b/pts/model/tft/tft_estimator.py index c8ca91d..53c53af 100644 --- a/pts/model/tft/tft_estimator.py +++ b/pts/model/tft/tft_estimator.py @@ -3,6 +3,7 @@ from typing import List, Optional, Dict import numpy as np import torch + from gluonts.core.component import validated from gluonts.dataset.field_names import FieldName from gluonts.model.forecast_generator import QuantileForecastGenerator @@ -30,6 +31,7 @@ from gluonts.transform import ( from pts import Trainer from pts.model import PyTorchEstimator from pts.model.utils import get_module_forward_input_names + from .tft_network import ( TemporalFusionTransformerPredictionNetwork, TemporalFusionTransformerTrainingNetwork, diff --git a/pts/model/tft/tft_transform.py b/pts/model/tft/tft_transform.py index 60b5f87..3832409 100644 --- a/pts/model/tft/tft_transform.py +++ b/pts/model/tft/tft_transform.py @@ -25,6 +25,7 @@ from gluonts.transform import ( shift_timestamp, target_transformation_length, ) +from gluonts.transform.sampler import InstanceSampler class BroadcastTo(MapTransformation): @@ -54,7 +55,7 @@ class TFTInstanceSplitter(InstanceSplitter): @validated() def __init__( self, - instance_sampler, + instance_sampler: InstanceSampler, past_length: int, future_length: int, target_field: str = FieldName.TARGET, @@ -64,29 +65,30 @@ class TFTInstanceSplitter(InstanceSplitter): observed_value_field: str = FieldName.OBSERVED_VALUES, lead_time: int = 0, output_NTC: bool = True, - time_series_fields: Optional[List[str]] = None, - past_time_series_fields: Optional[List[str]] = None, + time_series_fields: List[str] = [], + past_time_series_fields: List[str] = [], dummy_value: float = 0.0, ) -> None: + super().__init__( + target_field=target_field, + is_pad_field=is_pad_field, + start_field=start_field, + forecast_start_field=forecast_start_field, + instance_sampler=instance_sampler, + past_length=past_length, + future_length=future_length, + lead_time=lead_time, + output_NTC=output_NTC, + time_series_fields=time_series_fields, + dummy_value=dummy_value, + ) + assert past_length > 0, "The value of `past_length` should be > 0" assert future_length > 0, "The value of `future_length` should be > 0" - self.instance_sampler = instance_sampler - self.past_length = past_length - self.future_length = future_length - self.lead_time = lead_time - self.output_NTC = output_NTC - self.dummy_value = dummy_value - - self.target_field = target_field - self.is_pad_field = is_pad_field - self.start_field = start_field - self.forecast_start_field = forecast_start_field self.observed_value_field = observed_value_field - - self.ts_fields = time_series_fields or [] - self.past_ts_fields = past_time_series_fields or [] + self.past_ts_fields = past_time_series_fields def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]: pl = self.future_length