mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:31:19 +08:00
157 lines
5.1 KiB
Python
157 lines
5.1 KiB
Python
from typing import List, Optional, Iterable, Dict, Any
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
from gluonts.core.component import validated
|
|
from gluonts.dataset.common import Dataset
|
|
from gluonts.dataset.field_names import FieldName
|
|
from gluonts.itertools import Cyclic, PseudoShuffled, IterableSlice
|
|
from gluonts.time_feature import (
|
|
TimeFeature,
|
|
time_features_from_frequency_str,
|
|
)
|
|
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
|
from gluonts.transform import (
|
|
Transformation,
|
|
Chain,
|
|
RemoveFields,
|
|
SetField,
|
|
AsNumpyArray,
|
|
AddObservedValuesIndicator,
|
|
AddTimeFeatures,
|
|
AddAgeFeature,
|
|
VstackFeatures,
|
|
InstanceSplitter,
|
|
ValidationSplitSampler,
|
|
TestSplitSampler,
|
|
ExpectedNumInstanceSampler,
|
|
SelectFields,
|
|
InstanceSampler,
|
|
)
|
|
from gluonts.torch.util import (
|
|
IterableDataset,
|
|
)
|
|
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
|
from gluonts.torch.model.predictor import PyTorchPredictor
|
|
from gluonts.torch.distributions import (
|
|
DistributionOutput,
|
|
StudentTOutput,
|
|
)
|
|
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
|
from gluonts.torch.modules.feature import FeatureEmbedder
|
|
from gluonts.time_feature import get_lags_for_frequency
|
|
|
|
|
|
class TransformerModel(nn.Module):
|
|
@validated()
def __init__(
    self,
    freq: str,
    context_length: int,
    prediction_length: int,
    num_feat_dynamic_real: int,
    num_feat_static_real: int,
    num_feat_static_cat: int,
    cardinality: List[int],
    embedding_dimension: Optional[List[int]] = None,
    # Added transformer arguments
    encoder=None,
    decoder=None,
    embeding=None,
    target_embed=None,
    generator=None,
    #############################
    dropout_rate: float = 0.1,
    distr_output: DistributionOutput = StudentTOutput(),
    lags_seq: Optional[List[int]] = None,
    scaling: bool = True,
    num_parallel_samples: int = 100,
    hidden_size: Optional[int] = None,
) -> None:
    """Store the model configuration and the injected transformer parts.

    Parameters
    ----------
    freq
        Frequency string; used to derive default lags when ``lags_seq``
        is not given.
    context_length
        Stored as ``self.context_length``; also contributes to
        ``self.history_length``.
    prediction_length
        Stored as ``self.prediction_length``.
    num_feat_dynamic_real, num_feat_static_real, num_feat_static_cat
        Feature counts, stored unchanged on the instance.
    cardinality
        Passed to ``FeatureEmbedder`` as ``cardinalities``; also used to
        derive default embedding sizes.
    embedding_dimension
        Embedding sizes. When ``None`` (and ``cardinality`` is not
        ``None``) each size defaults to ``min(50, (cat + 1) // 2)``.
    encoder, decoder, embeding, target_embed, generator
        Externally built transformer components, stored as attributes of
        the same name.
    dropout_rate
        Accepted for configuration; not read by this constructor.
    distr_output
        Distribution head; supplies ``event_shape`` and the parameter
        projection.
    lags_seq
        Lag indices; falls back to ``get_lags_for_frequency(freq)`` when
        empty or ``None``.
    scaling
        ``True`` selects ``MeanScaler``, ``False`` selects ``NOPScaler``.
    num_parallel_samples
        Stored as ``self.num_parallel_samples``.
    hidden_size
        Input width handed to ``distr_output.get_args_proj``. When
        ``None`` no projection is built and ``self.param_proj`` is
        ``None``.
    """
    super().__init__()

    self.context_length = context_length
    self.prediction_length = prediction_length
    self.distr_output = distr_output
    # The original referenced an undefined name ``hidden_size`` here, which
    # raised NameError on every construction. The width is now an explicit
    # optional argument; without it the projection is left unset.
    # NOTE(review): presumably this should equal the decoder output width
    # -- confirm against the decoder that is injected.
    self.param_proj = (
        distr_output.get_args_proj(hidden_size)
        if hidden_size is not None
        else None
    )
    self.target_shape = distr_output.event_shape
    self.num_feat_dynamic_real = num_feat_dynamic_real
    self.num_feat_static_cat = num_feat_static_cat
    self.num_feat_static_real = num_feat_static_real
    self.embedding_dimension = (
        embedding_dimension
        if embedding_dimension is not None or cardinality is None
        else [min(50, (cat + 1) // 2) for cat in cardinality]
    )
    self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
    self.num_parallel_samples = num_parallel_samples
    # Enough history to cover the context window plus the largest lag.
    self.history_length = self.context_length + max(self.lags_seq)
    self.embedder = FeatureEmbedder(
        cardinalities=cardinality,
        embedding_dims=self.embedding_dimension,
    )
    if scaling:
        self.scaler = MeanScaler(dim=1, keepdim=True)
    else:
        self.scaler = NOPScaler(dim=1, keepdim=True)

    # Added transformer enc-decoder and mask initializer
    self.encoder = encoder
    self.decoder = decoder
    self.embeding = embeding
    self.target_embed = target_embed
    self.generator = generator
    ########################
|
|
|
|
# TODO
|
|
# add method that does the forward for training
|
|
|
|
"""
|
|
A built-in Encoder-Decoder architecture for the TransformerModel class
|
|
"""
|
|
|
|
def forward(self, src, tgt, mask_source, mask_target):
    """Take in and process masked source and target sequences.

    The source is embedded and encoded; the target is embedded and decoded
    against the encoder output. Returns whatever ``self.decoder`` returns.

    NOTE(review): this class defines another ``forward`` further down;
    the later definition replaces this one on the class -- confirm which
    of the two is meant to be the public entry point.
    """
    embedded_source = self.embeding(src)
    encoder_memory = self.encoder(embedded_source, mask_source)
    embedded_target = self.target_embed(tgt)
    return self.decoder(
        embedded_target, encoder_memory, mask_source, mask_target
    )
|
|
|
|
def encode(self, src, mask_source):
    """Embed ``src`` and run it through the encoder.

    Returns the result of ``self.encoder(embedded_src, mask_source)``.
    """
    # The constructor stores the source embedding as ``self.embeding``;
    # the previously used ``self.src_embed`` is never assigned.
    return self.encoder(self.embeding(src), mask_source)
|
|
|
|
def decode(self, memory, mask_source, tgt, mask_target):
    """Embed ``tgt`` and decode it against the encoder output ``memory``.

    Returns the result of
    ``self.decoder(embedded_tgt, memory, mask_source, mask_target)``.
    """
    # The constructor stores the target embedding as ``self.target_embed``;
    # the previously used ``self.tgt_embed`` is never assigned.
    return self.decoder(
        self.target_embed(tgt), memory, mask_source, mask_target
    )
|
|
|
|
@property
def _number_of_features(self) -> int:
    """Total feature count: embeddings, real-valued features, and scale."""
    embedded_width = sum(self.embedding_dimension)
    real_valued_width = self.num_feat_dynamic_real + self.num_feat_static_real
    # The extra 1 accounts for the log(scale) feature.
    return embedded_width + real_valued_width + 1
|
|
|
|
@property
def _past_length(self) -> int:
    """Context length plus the largest lag in ``self.lags_seq``."""
    largest_lag = max(self.lags_seq)
    return self.context_length + largest_lag
|
|
|
|
# for prediction
|
|
def forward(
|
|
self,
|
|
feat_static_cat: torch.Tensor,
|
|
feat_static_real: torch.Tensor,
|
|
past_time_feat: torch.Tensor,
|
|
past_target: torch.Tensor,
|
|
past_observed_values: torch.Tensor,
|
|
future_time_feat: torch.Tensor,
|
|
num_parallel_samples: Optional[int] = None,
|
|
) -> torch.Tensor:
|
|
if num_parallel_samples is None:
|
|
num_parallel_samples = self.num_parallel_samples
|
|
|
|
# TODO
|