mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 18:06:19 +08:00
Fixed TFT transform
This commit is contained in:
@@ -8,11 +8,13 @@ import torch.nn as nn
|
||||
from torch.utils import data
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from gluonts.env import env
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.dataset.common import Dataset
|
||||
from gluonts.model.estimator import Estimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.transform import SelectFields, Transformation
|
||||
from gluonts.support.util import maybe_len
|
||||
|
||||
from pts import Trainer
|
||||
from pts.model import get_module_forward_input_names
|
||||
@@ -101,7 +103,9 @@ class PyTorchEstimator(Estimator):
|
||||
trained_net = self.create_training_network(self.trainer.device)
|
||||
|
||||
input_names = get_module_forward_input_names(trained_net)
|
||||
training_instance_splitter = self.create_instance_splitter("training")
|
||||
|
||||
with env._let(max_idle_transforms=maybe_len(training_data) or 0):
|
||||
training_instance_splitter = self.create_instance_splitter("training")
|
||||
training_iter_dataset = TransformedIterableDataset(
|
||||
dataset=training_data,
|
||||
transform=transformation
|
||||
@@ -124,7 +128,8 @@ class PyTorchEstimator(Estimator):
|
||||
|
||||
validation_data_loader = None
|
||||
if validation_data is not None:
|
||||
validation_instance_splitter = self.create_instance_splitter("validation")
|
||||
with env._let(max_idle_transforms=maybe_len(validation_data) or 0):
|
||||
validation_instance_splitter = self.create_instance_splitter("validation")
|
||||
validation_iter_dataset = TransformedIterableDataset(
|
||||
dataset=validation_data,
|
||||
transform=transformation
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Optional, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.dataset.field_names import FieldName
|
||||
from gluonts.model.forecast_generator import QuantileForecastGenerator
|
||||
@@ -30,6 +31,7 @@ from gluonts.transform import (
|
||||
from pts import Trainer
|
||||
from pts.model import PyTorchEstimator
|
||||
from pts.model.utils import get_module_forward_input_names
|
||||
|
||||
from .tft_network import (
|
||||
TemporalFusionTransformerPredictionNetwork,
|
||||
TemporalFusionTransformerTrainingNetwork,
|
||||
|
||||
@@ -25,6 +25,7 @@ from gluonts.transform import (
|
||||
shift_timestamp,
|
||||
target_transformation_length,
|
||||
)
|
||||
from gluonts.transform.sampler import InstanceSampler
|
||||
|
||||
|
||||
class BroadcastTo(MapTransformation):
|
||||
@@ -54,7 +55,7 @@ class TFTInstanceSplitter(InstanceSplitter):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
instance_sampler,
|
||||
instance_sampler: InstanceSampler,
|
||||
past_length: int,
|
||||
future_length: int,
|
||||
target_field: str = FieldName.TARGET,
|
||||
@@ -64,29 +65,30 @@ class TFTInstanceSplitter(InstanceSplitter):
|
||||
observed_value_field: str = FieldName.OBSERVED_VALUES,
|
||||
lead_time: int = 0,
|
||||
output_NTC: bool = True,
|
||||
time_series_fields: Optional[List[str]] = None,
|
||||
past_time_series_fields: Optional[List[str]] = None,
|
||||
time_series_fields: List[str] = [],
|
||||
past_time_series_fields: List[str] = [],
|
||||
dummy_value: float = 0.0,
|
||||
) -> None:
|
||||
|
||||
super().__init__(
|
||||
target_field=target_field,
|
||||
is_pad_field=is_pad_field,
|
||||
start_field=start_field,
|
||||
forecast_start_field=forecast_start_field,
|
||||
instance_sampler=instance_sampler,
|
||||
past_length=past_length,
|
||||
future_length=future_length,
|
||||
lead_time=lead_time,
|
||||
output_NTC=output_NTC,
|
||||
time_series_fields=time_series_fields,
|
||||
dummy_value=dummy_value,
|
||||
)
|
||||
|
||||
assert past_length > 0, "The value of `past_length` should be > 0"
|
||||
assert future_length > 0, "The value of `future_length` should be > 0"
|
||||
|
||||
self.instance_sampler = instance_sampler
|
||||
self.past_length = past_length
|
||||
self.future_length = future_length
|
||||
self.lead_time = lead_time
|
||||
self.output_NTC = output_NTC
|
||||
self.dummy_value = dummy_value
|
||||
|
||||
self.target_field = target_field
|
||||
self.is_pad_field = is_pad_field
|
||||
self.start_field = start_field
|
||||
self.forecast_start_field = forecast_start_field
|
||||
self.observed_value_field = observed_value_field
|
||||
|
||||
self.ts_fields = time_series_fields or []
|
||||
self.past_ts_fields = past_time_series_fields or []
|
||||
self.past_ts_fields = past_time_series_fields
|
||||
|
||||
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
|
||||
pl = self.future_length
|
||||
|
||||
Reference in New Issue
Block a user