mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:46:28 +08:00
use gluonts dev
This commit is contained in:
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
||||
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import IterableDataset
|
||||
from gluonts.transform import (
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
||||
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import IterableDataset
|
||||
from gluonts.transform import (
|
||||
|
||||
+1
-1
@@ -5,7 +5,7 @@ import torch.nn as nn
|
||||
from etsformer_pytorch import ETSFormer
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
||||
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import IterableDataset
|
||||
from gluonts.transform import (
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.dataset.field_names import FieldName
|
||||
from gluonts.time_feature import TimeFeature
|
||||
from gluonts.torch.distributions import DistributionOutput
|
||||
from gluonts.torch.util import copy_parameters
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.model.predictor import Predictor
|
||||
from gluonts.transform import (
|
||||
Transformation,
|
||||
Chain,
|
||||
InstanceSplitter,
|
||||
InstanceSampler,
|
||||
ValidationSplitSampler,
|
||||
TestSplitSampler,
|
||||
ExpectedNumInstanceSampler,
|
||||
RemoveFields,
|
||||
AddAgeFeature,
|
||||
AsNumpyArray,
|
||||
AddObservedValuesIndicator,
|
||||
AddTimeFeatures,
|
||||
VstackFeatures,
|
||||
SetField,
|
||||
)
|
||||
|
||||
from pts import Trainer
|
||||
from pts.model.utils import get_module_forward_input_names
|
||||
from pts.feature import (
|
||||
fourier_time_features_from_frequency,
|
||||
lags_for_fourier_time_features_from_frequency,
|
||||
)
|
||||
from pts.model import PyTorchEstimator
|
||||
from pts.modules import StudentTOutput
|
||||
|
||||
from .hopfield_network import (
|
||||
HopfieldTrainingNetwork,
|
||||
HopfieldPredictionNetwork,
|
||||
)
|
||||
|
||||
|
||||
class HopfieldEstimator(PyTorchEstimator):
    """Estimator for the Hopfield transformer forecasting model.

    Stores the model hyper-parameters, builds the GluonTS transformation
    chain and instance splitters, and creates the training network and the
    predictor that wrap ``HopfieldTrainingNetwork`` and
    ``HopfieldPredictionNetwork``.

    Parameters
    ----------
    input_size
        Forwarded unchanged to both networks as ``input_size``.
    freq
        Frequency string; used to derive the default lags and time features
        and passed on to the predictor.
    prediction_length
        Number of future time steps to predict.
    context_length
        Number of past time steps used as context; falls back to
        ``prediction_length`` when ``None``.
    trainer
        Trainer handed to the base class; its ``batch_size`` is reused when
        the predictor is built.
    dropout_rate, d_model, dim_feedforward_scale, act_type, num_heads,
    num_encoder_layers, num_decoder_layers, scaling
        Stored and forwarded unchanged to both networks.
    cardinality
        Cardinalities of the static categorical features. Only used when
        ``use_feat_static_cat`` is true; otherwise ``[1]`` is used.
    embedding_dimension
        Embedding sizes forwarded to the networks.
    distr_output
        Distribution output; its ``event_shape`` fixes the expected number
        of target dimensions.
    lags_seq
        Lag indices; defaults to
        ``lags_for_fourier_time_features_from_frequency(freq_str=freq)``.
    time_features
        Time features; defaults to
        ``fourier_time_features_from_frequency(freq)``.
    use_feat_dynamic_real, use_feat_static_cat, use_feat_static_real
        Whether the corresponding dataset fields are kept and used.
    num_parallel_samples
        Number of sample paths drawn by the prediction network.

    Raises
    ------
    ValueError
        If ``use_feat_static_cat`` is true but ``cardinality`` is ``None``.
    """

    @validated()
    def __init__(
        self,
        input_size: int,
        freq: str,
        prediction_length: int,
        context_length: Optional[int] = None,
        trainer: Trainer = Trainer(),
        dropout_rate: float = 0.1,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: List[int] = [20],
        distr_output: DistributionOutput = StudentTOutput(),
        d_model: int = 32,
        dim_feedforward_scale: int = 4,
        act_type: str = "gelu",
        num_heads: int = 8,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        scaling: bool = True,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        use_feat_dynamic_real: bool = False,
        use_feat_static_cat: bool = False,
        use_feat_static_real: bool = False,
        num_parallel_samples: int = 100,
    ) -> None:
        super().__init__(trainer=trainer)

        # Fail early with a clear message: without cardinalities the networks
        # cannot be built once static categorical features are requested.
        if use_feat_static_cat and cardinality is None:
            raise ValueError(
                "cardinality must be provided when use_feat_static_cat=True"
            )

        self.input_size = input_size
        self.freq = freq
        self.prediction_length = prediction_length
        self.context_length = (
            context_length if context_length is not None else prediction_length
        )
        self.distr_output = distr_output
        self.dropout_rate = dropout_rate
        self.use_feat_dynamic_real = use_feat_dynamic_real
        self.use_feat_static_cat = use_feat_static_cat
        self.use_feat_static_real = use_feat_static_real
        # A single dummy category is used when static categoricals are disabled
        # (create_transformation then injects the constant feature [0]).
        self.cardinality = cardinality if use_feat_static_cat else [1]
        self.embedding_dimension = embedding_dimension
        self.num_parallel_samples = num_parallel_samples
        self.lags_seq = (
            lags_seq
            if lags_seq is not None
            else lags_for_fourier_time_features_from_frequency(freq_str=freq)
        )
        self.time_features = (
            time_features
            if time_features is not None
            else fourier_time_features_from_frequency(self.freq)
        )
        # The history must reach back far enough to read the largest lag for
        # every position of the context window.
        self.history_length = self.context_length + max(self.lags_seq)
        self.scaling = scaling

        self.d_model = d_model
        self.num_heads = num_heads
        self.act_type = act_type
        self.dim_feedforward_scale = dim_feedforward_scale
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers

        self.train_sampler = ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = ValidationSplitSampler(min_future=prediction_length)

    def create_transformation(self) -> Transformation:
        """Build the transformation chain applied to every dataset entry.

        Unused feature fields are removed, constant placeholders are set for
        disabled static features, fields are converted to numpy arrays, and
        observed-value, time and age features are added and stacked into
        ``FieldName.FEAT_TIME``.
        """
        remove_field_names = [
            FieldName.FEAT_DYNAMIC_CAT,
        ]
        if not self.use_feat_dynamic_real:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
        if not self.use_feat_static_real:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        return Chain(
            [RemoveFields(field_names=remove_field_names)]
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
                if not self.use_feat_static_cat
                else []
            )
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
                if not self.use_feat_static_real
                else []
            )
            + [
                AsNumpyArray(
                    # ``np.long`` was only an alias of the builtin ``int`` and
                    # is removed in NumPy 1.24; ``int`` is the same dtype.
                    field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=int
                ),
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_REAL,
                    expected_ndim=1,
                    # NOTE(review): ``dtype`` is not set in this class;
                    # presumably provided by PyTorchEstimator - confirm.
                    dtype=self.dtype,
                ),
                AsNumpyArray(
                    field=FieldName.TARGET,
                    # in the following line, we add 1 for the time dimension
                    expected_ndim=1 + len(self.distr_output.event_shape),
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                AddAgeFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_AGE,
                    pred_length=self.prediction_length,
                    log_scale=True,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                    + (
                        [FieldName.FEAT_DYNAMIC_REAL]
                        if self.use_feat_dynamic_real
                        else []
                    ),
                ),
            ]
        )

    def create_instance_splitter(self, mode: str):
        """Return the ``InstanceSplitter`` for ``mode``.

        ``mode`` must be ``"training"``, ``"validation"`` or ``"test"``; it
        selects the instance sampler. Windows span ``history_length`` past
        steps and ``prediction_length`` future steps.
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.history_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
        )

    def create_training_network(self, device: torch.device) -> HopfieldTrainingNetwork:
        """Instantiate the training network from the stored hyper-parameters and move it to ``device``."""
        training_network = HopfieldTrainingNetwork(
            input_size=self.input_size,
            num_heads=self.num_heads,
            act_type=self.act_type,
            dropout_rate=self.dropout_rate,
            d_model=self.d_model,
            dim_feedforward_scale=self.dim_feedforward_scale,
            num_encoder_layers=self.num_encoder_layers,
            num_decoder_layers=self.num_decoder_layers,
            history_length=self.history_length,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            distr_output=self.distr_output,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
        ).to(device)

        return training_network

    def create_predictor(
        self,
        transformation: Transformation,
        trained_network: HopfieldTrainingNetwork,
        device: torch.device,
    ) -> Predictor:
        """Build a ``PyTorchPredictor`` around a prediction network.

        A ``HopfieldPredictionNetwork`` is created with the same
        hyper-parameters, the parameters of ``trained_network`` are copied
        into it, and it is combined with ``transformation`` followed by the
        test-mode instance splitter.
        """
        prediction_network = HopfieldPredictionNetwork(
            input_size=self.input_size,
            num_heads=self.num_heads,
            act_type=self.act_type,
            dropout_rate=self.dropout_rate,
            d_model=self.d_model,
            dim_feedforward_scale=self.dim_feedforward_scale,
            num_encoder_layers=self.num_encoder_layers,
            num_decoder_layers=self.num_decoder_layers,
            history_length=self.history_length,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            distr_output=self.distr_output,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
            num_parallel_samples=self.num_parallel_samples,
        ).to(device)

        copy_parameters(trained_network, prediction_network)
        input_names = get_module_forward_input_names(prediction_network)
        prediction_splitter = self.create_instance_splitter("test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=input_names,
            prediction_net=prediction_network,
            batch_size=self.trainer.batch_size,
            freq=self.freq,
            prediction_length=self.prediction_length,
            device=device,
        )
|
||||
@@ -0,0 +1,477 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.torch.distributions import DistributionOutput
|
||||
from pts.modules import MeanScaler, NOPScaler, FeatureEmbedder
|
||||
|
||||
from hflayers import Hopfield
|
||||
from hflayers.transformer import HopfieldDecoderLayer, HopfieldEncoderLayer
|
||||
|
||||
|
||||
def prod(xs):
    """Return the product of all items of ``xs``; ``1`` when ``xs`` is empty."""
    result = 1
    for factor in xs:
        result *= factor
    return result
|
||||
|
||||
|
||||
class HopfieldNetwork(nn.Module):
    """Shared base of the Hopfield transformer training and prediction networks.

    Builds a ``nn.Transformer`` whose encoder and decoder layers are Hopfield
    layers, the projection to distribution arguments, the static-feature
    embedder and the target scaler, and provides the helpers that assemble
    the network input from a batch.

    Parameters
    ----------
    input_size
        Feature width of the network input; it is used as the model width
        of the Hopfield layers, the transformer and the output projection.
    d_model
        Accepted for interface compatibility; not used by this constructor.
    num_heads, act_type, dropout_rate
        Passed to the Hopfield associations / layers.
    dim_feedforward_scale
        The feed-forward width is ``dim_feedforward_scale * input_size``.
    num_encoder_layers, num_decoder_layers
        Number of stacked encoder / decoder layers.
    history_length, context_length, prediction_length
        Window lengths used when slicing the batch tensors.
    distr_output
        Distribution output providing ``event_shape`` and the args projection.
    cardinality, embedding_dimension
        Configuration of the static categorical ``FeatureEmbedder``.
    lags_seq
        Lag indices; must not contain duplicates. A sorted copy is stored,
        the caller's list is left untouched.
    scaling
        Use ``MeanScaler`` when true, ``NOPScaler`` otherwise.
    """

    @validated()
    def __init__(
        self,
        input_size: int,
        d_model: int,
        num_heads: int,
        act_type: str,
        dropout_rate: float,
        dim_feedforward_scale: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        history_length: int,
        context_length: int,
        prediction_length: int,
        distr_output: DistributionOutput,
        cardinality: List[int],
        embedding_dimension: List[int],
        lags_seq: List[int],
        scaling: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.history_length = history_length
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.scaling = scaling
        self.cardinality = cardinality
        self.embedding_dimension = embedding_dimension
        self.distr_output = distr_output

        assert len(set(lags_seq)) == len(lags_seq), "no duplicated lags allowed!"
        # Sort a copy: sorting in place would silently reorder the list owned
        # by the caller (e.g. the estimator's ``lags_seq``).
        self.lags_seq = sorted(lags_seq)

        self.target_shape = distr_output.event_shape

        # The layers below operate directly on the raw input width
        # (``input_size``); no projection to ``d_model`` is applied.
        encoder_association = Hopfield(input_size=input_size, num_heads=num_heads)
        encoder_layer = HopfieldEncoderLayer(
            encoder_association,
            dim_feedforward=dim_feedforward_scale * input_size,
            dropout=dropout_rate,
            activation=act_type,
        )
        transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_encoder_layers
        )

        decoder_association_self = Hopfield(input_size=input_size, num_heads=num_heads)
        decoder_association_cross = Hopfield(input_size=input_size, num_heads=num_heads)
        decoder_layer = HopfieldDecoderLayer(
            hopfield_association_self=decoder_association_self,
            hopfield_association_cross=decoder_association_cross,
            dim_feedforward=dim_feedforward_scale * input_size,
            dropout=dropout_rate,
            activation=act_type,
        )
        transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_decoder_layers
        )

        self.transformer = nn.Transformer(
            d_model=input_size,
            nhead=num_heads,
            custom_encoder=transformer_encoder,
            custom_decoder=transformer_decoder,
            batch_first=True,
        )

        self.proj_dist_args = distr_output.get_args_proj(input_size)

        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=embedding_dimension,
        )

        if scaling:
            self.scaler = MeanScaler(keepdim=True)
        else:
            self.scaler = NOPScaler(keepdim=True)

        # Causal mask for the decoder over the prediction window, stored as a
        # buffer so it moves with the module across devices.
        self.register_buffer(
            "tgt_mask",
            self.transformer.generate_square_subsequent_mask(prediction_length),
        )

    @staticmethod
    def get_lagged_subsequences(
        sequence: torch.Tensor,
        sequence_length: int,
        indices: List[int],
        subsequences_length: int = 1,
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.

        Parameters
        ----------
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        sequence_length : int
            length of sequence in the T (time) dimension (axis = 1).
        indices : List[int]
            list of lag indices to be used.
        subsequences_length : int
            length of the subsequences to be extracted.

        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I), where S = subsequences_length and
            I = len(indices), containing lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
        """
        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )
        assert all(lag_index >= 0 for lag_index in indices)

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            # A lag of 0 must slice up to the end; ``-0`` would give an empty slice.
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...])
        return torch.stack(lagged_values, dim=-1)

    def create_network_input(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: Optional[torch.Tensor],
        future_target: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Creates inputs for the transformer network.

        All time-indexed tensors are sliced and concatenated along axis 1,
        i.e. they are expected in (batch, time, ...) layout.

        When ``future_time_feat`` or ``future_target`` is ``None`` only the
        context window is built; otherwise the context and prediction
        windows are concatenated.

        Returns
        -------
        inputs : Tensor
            concatenation (last axis) of scaled lags, time features and
            repeated static features, one row per time step of the window.
        scale : Tensor
            scale returned by the scaler for the last ``context_length``
            steps of ``past_target``.
        static_feat : Tensor
            embedded categorical features, real static features and the
            log of the scale, concatenated along axis 1.
        """

        if future_time_feat is None or future_target is None:
            time_feat = past_time_feat[
                :, self.history_length - self.context_length :, ...
            ]
            sequence = past_target
            sequence_length = self.history_length
            subsequences_length = self.context_length
        else:
            time_feat = torch.cat(
                (
                    past_time_feat[:, self.history_length - self.context_length :, ...],
                    future_time_feat,
                ),
                dim=1,
            )
            sequence = torch.cat((past_target, future_target), dim=1)
            sequence_length = self.history_length + self.prediction_length
            subsequences_length = self.context_length + self.prediction_length

        # (batch_size, sub_seq_len, *target_shape, num_lags)
        lags = self.get_lagged_subsequences(
            sequence=sequence,
            sequence_length=sequence_length,
            indices=self.lags_seq,
            subsequences_length=subsequences_length,
        )

        # scale is computed on the context length last units of the past target
        # scale shape is (batch_size, 1, *target_shape)
        _, scale = self.scaler(
            past_target[:, -self.context_length :, ...],
            past_observed_values[:, -self.context_length :, ...],
        )
        embedded_cat = self.embedder(feat_static_cat)

        # in addition to embedding features, use the log scale as it can help prediction too
        # (batch_size, num_features + prod(target_shape))
        static_feat = torch.cat(
            (
                embedded_cat,
                feat_static_real,
                torch.log(scale)
                if len(self.target_shape) == 0
                else torch.log(scale.squeeze(1)),
            ),
            dim=1,
        )

        # Repeat the static features for every time step of the window.
        repeated_static_feat = static_feat.unsqueeze(1).expand(
            -1, subsequences_length, -1
        )

        # (batch_size, sub_seq_len, *target_shape, num_lags)
        lags_scaled = lags / scale.unsqueeze(-1)

        # from (batch_size, sub_seq_len, *target_shape, num_lags)
        # to (batch_size, sub_seq_len, prod(target_shape) * num_lags)
        input_lags = lags_scaled.reshape(
            (-1, subsequences_length, len(self.lags_seq) * prod(self.target_shape))
        )

        # (batch_size, sub_seq_len, input_dim)
        inputs = torch.cat((input_lags, time_feat, repeated_static_feat), dim=-1)

        return inputs, scale, static_feat
|
||||
|
||||
|
||||
class HopfieldTrainingNetwork(HopfieldNetwork):
    """Training variant of the Hopfield transformer: ``forward`` returns the loss."""

    # noinspection PyMethodOverriding,PyPep8Naming
    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
        future_target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the training loss of the Hopfield transformer.

        All arguments are forwarded to ``create_network_input``; tensors
        representing time series carry the time dimension on axis 1.

        Parameters
        ----------
        feat_static_cat : static categorical features
        feat_static_real : static real features
        past_time_feat : time features of the history window
        past_target : target values of the history window
        past_observed_values : observed-value indicator of the history window
        future_time_feat : time features of the prediction window
        future_target : target values of the prediction window

        Returns
        -------
        The negative log-likelihood of ``future_target`` under the predicted
        distribution, averaged over all elements (``loss.mean()``).
        """
        network_input, scale, _ = self.create_network_input(
            feat_static_cat=feat_static_cat,
            feat_static_real=feat_static_real,
            past_time_feat=past_time_feat,
            past_target=past_target,
            past_observed_values=past_observed_values,
            future_time_feat=future_time_feat,
            future_target=future_target,
        )

        # The first ``context_length`` steps feed the encoder, the remaining
        # steps (the prediction window) feed the decoder.
        boundary = self.context_length
        encoder_steps = network_input[:, :boundary, ...]
        decoder_steps = network_input[:, boundary:, ...]

        memory = self.transformer.encoder(encoder_steps)

        # The mask registered in the base class is applied to the decoder.
        decoded = self.transformer.decoder(
            decoder_steps,
            memory,
            tgt_mask=self.tgt_mask,
        )

        projected_args = self.proj_dist_args(decoded)
        distr = self.distr_output.distribution(projected_args, scale=scale)
        negative_log_likelihood = -distr.log_prob(future_target)

        return negative_log_likelihood.mean()
|
||||
|
||||
|
||||
class HopfieldPredictionNetwork(HopfieldNetwork):
    """Prediction variant of the Hopfield transformer: ``forward`` returns sample paths."""

    def __init__(self, num_parallel_samples: int = 100, **kwargs) -> None:
        """Build the base network from ``kwargs`` and store the number of sample paths."""
        super().__init__(**kwargs)
        self.num_parallel_samples = num_parallel_samples

        # for decoding the lags are shifted by one,
        # at the first time-step of the decoder a lag of one corresponds to the last target value
        self.shifted_lags = [lag - 1 for lag in self.lags_seq]

    def sampling_decoder(
        self,
        static_feat: torch.Tensor,
        past_target: torch.Tensor,
        time_feat: torch.Tensor,
        scale: torch.Tensor,
        enc_out: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes sample paths by decoding from the transformer one step at a time.

        Parameters
        ----------
        static_feat : Tensor
            static features, as returned by ``create_network_input``.
        past_target : Tensor
            target history; time is on axis 1.
        time_feat : Tensor
            time features of the prediction window; time is on axis 1.
        scale : Tensor
            scale of each element in the batch, as returned by
            ``create_network_input``.
        enc_out : Tensor
            output of the encoder.

        Returns
        --------
        sample_paths : Tensor
            sampled paths reshaped to
            ``(-1, num_parallel_samples) + target_shape + (prediction_length,)``.
        """

        def tile(tensor: torch.Tensor) -> torch.Tensor:
            # Repeat every batch element ``num_parallel_samples`` times along
            # axis 0 so all sample paths are decoded in parallel.
            return tensor.repeat_interleave(repeats=self.num_parallel_samples, dim=0)

        rolling_target = tile(past_target)
        tiled_time_feat = tile(time_feat)
        tiled_static_feat = tile(static_feat).unsqueeze(1)
        tiled_enc_out = tile(enc_out)
        tiled_scale = tile(scale)

        drawn = []

        # for each future time-unit draw new samples and append them to the
        # target so that later steps can read them as lags
        for step in range(self.prediction_length):
            lags = self.get_lagged_subsequences(
                sequence=rolling_target,
                sequence_length=self.history_length + step,
                indices=self.shifted_lags,
                subsequences_length=1,
            )

            # (batch_size * num_samples, 1, *target_shape, num_lags)
            lags_scaled = lags / tiled_scale.unsqueeze(1)

            # from (batch_size * num_samples, 1, *target_shape, num_lags)
            # to (batch_size * num_samples, 1, prod(target_shape) * num_lags)
            input_lags = lags_scaled.reshape(
                shape=(-1, 1, prod(self.target_shape) * len(self.lags_seq))
            )

            # lags, the time features of the current step and the static
            # features, concatenated on the last axis
            step_time_feat = tiled_time_feat[:, step : step + 1, :]
            dec_input = torch.cat(
                (input_lags, step_time_feat, tiled_static_feat),
                dim=-1,
            )

            dec_output = self.transformer.decoder(
                dec_input,
                tiled_enc_out,
            )

            distr_args = self.proj_dist_args(dec_output)

            # distribution of the next value given the predicted parameters
            distr = self.distr_output.distribution(distr_args, scale=tiled_scale)

            # (batch_size * num_samples, 1, *target_shape)
            new_samples = distr.sample()

            # (batch_size * num_samples, seq_len, *target_shape)
            rolling_target = torch.cat((rolling_target, new_samples), dim=1)
            drawn.append(new_samples)

        # (batch_size * num_samples, prediction_length, *target_shape)
        samples = torch.cat(drawn, dim=1)

        # NOTE(review): this is a plain reshape without a transpose; for an
        # empty ``target_shape`` it yields (batch, num_samples, prediction_length).
        # Behaviour for non-scalar targets should be verified.
        output_shape = (
            (-1, self.num_parallel_samples)
            + self.target_shape
            + (self.prediction_length,)
        )
        return samples.reshape(output_shape)

    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predicts samples; time-indexed tensors carry time on axis 1.

        Parameters
        ----------
        feat_static_cat : static categorical features
        feat_static_real : static real features
        past_time_feat : time features of the history window
        past_target : target values of the history window
        past_observed_values : observed-value indicator of the history window
        future_time_feat : time features of the prediction window

        Returns
        -------
        The sample paths produced by ``sampling_decoder``.
        """
        # Only the context window is encoded: no future inputs are passed.
        encoder_input, scale, static_feat = self.create_network_input(
            feat_static_cat=feat_static_cat,
            feat_static_real=feat_static_real,
            past_time_feat=past_time_feat,
            past_target=past_target,
            past_observed_values=past_observed_values,
            future_time_feat=None,
            future_target=None,
        )

        memory = self.transformer.encoder(encoder_input)

        return self.sampling_decoder(
            past_target=past_target,
            time_feat=future_time_feat,
            static_feat=static_feat,
            scale=scale,
            enc_out=memory,
        )
|
||||
+1
-1
@@ -4,7 +4,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
from hflayers import Hopfield
|
||||
|
||||
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
||||
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import IterableDataset
|
||||
from gluonts.transform import (
|
||||
|
||||
+159
-134
File diff suppressed because one or more lines are too long
+1
-1
@@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
||||
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.util import IterableDataset
|
||||
@@ -79,7 +79,7 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
n_layer: int = 4,
|
||||
enc_in: int = 1, # depends on dataset used
|
||||
CSCM: str = "Bottleneck_Construct", # [Bottleneck_Construct, Conv_Construct, MaxPooling_Construct, AvgPooling_Construct]
|
||||
embed_type: str = "CustomEmbedding", #[DataEmbedding, CustomEmbedding]
|
||||
embed_type: str = "CustomEmbedding", # [DataEmbedding, CustomEmbedding]
|
||||
truncate: bool = False,
|
||||
# loss: DistributionLoss = LossFactory,
|
||||
ignore_zero: bool = True,
|
||||
@@ -339,7 +339,7 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
num_seq=self.num_seq,
|
||||
input_size=self.input_size,
|
||||
dropout=self.dropout,
|
||||
d_model = self.d_model,
|
||||
d_model=self.d_model,
|
||||
d_inner_hid=self.d_inner_hid,
|
||||
d_k=self.d_k,
|
||||
d_v=self.d_v,
|
||||
@@ -392,8 +392,8 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
cardinality=self.cardinality,
|
||||
embedding_dimension=self.embedding_dimension,
|
||||
num_parallel_samples=self.num_parallel_samples,
|
||||
embed_type = self.embed_type,
|
||||
distr_output= self.distr_output,
|
||||
embed_type=self.embed_type,
|
||||
distr_output=self.distr_output,
|
||||
device=device,
|
||||
)
|
||||
return PyraformerLightningModule(model=model, loss=self.loss)
|
||||
|
||||
@@ -3,7 +3,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
|
||||
@@ -151,7 +151,6 @@ class PyraformerSSModel(nn.Module):
|
||||
scaling,
|
||||
num_parallel_samples,
|
||||
device,
|
||||
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
@@ -179,7 +178,7 @@ class PyraformerSSModel(nn.Module):
|
||||
# convert hidden vectors into two scalar
|
||||
self.mean_hidden = Predictor(4 * d_model, 1)
|
||||
self.var_hidden = Predictor(4 * d_model, 1)
|
||||
|
||||
|
||||
self.softplus = nn.Softplus()
|
||||
self.distr_output = distr_output
|
||||
|
||||
@@ -512,7 +511,7 @@ class PyraformerLRModel(nn.Module):
|
||||
num_parallel_samples,
|
||||
embed_type,
|
||||
distr_output,
|
||||
device
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -524,7 +523,7 @@ class PyraformerLRModel(nn.Module):
|
||||
self.distr_output = distr_output
|
||||
self.context_length = context_length
|
||||
self.lags_seq = lags_seq
|
||||
|
||||
|
||||
self.encoder = Encoder(
|
||||
# model,
|
||||
window_size,
|
||||
@@ -593,10 +592,10 @@ class PyraformerLRModel(nn.Module):
|
||||
)
|
||||
|
||||
return pred
|
||||
|
||||
|
||||
@property
|
||||
def _past_length(self) -> int:
|
||||
return self.predict_step #+ max(0,self.lags_seq)
|
||||
return self.predict_step # + max(0,self.lags_seq)
|
||||
|
||||
@property
|
||||
def _number_of_features(self) -> int:
|
||||
|
||||
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
||||
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import IterableDataset
|
||||
from gluonts.transform import (
|
||||
|
||||
+1
-1
@@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
from reformer_pytorch.reformer_pytorch import Reformer
|
||||
|
||||
+288
-237
File diff suppressed because one or more lines are too long
+1
-1
@@ -1,6 +1,6 @@
|
||||
orjson
|
||||
torch
|
||||
gluonts
|
||||
https://github.com/awslabs/gluon-ts
|
||||
pytorch-lightning
|
||||
datasets
|
||||
xformers
|
||||
|
||||
+1706
-1436
File diff suppressed because one or more lines are too long
+31
-59
@@ -51,7 +51,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 18,
|
||||
"id": "e772234f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -59,6 +59,7 @@
|
||||
"estimator = SwitchTransformerEstimator(\n",
|
||||
" freq=dataset.metadata.freq,\n",
|
||||
" prediction_length=dataset.metadata.prediction_length,\n",
|
||||
" context_length=8*dataset.metadata.prediction_length,\n",
|
||||
" num_feat_static_cat=1,\n",
|
||||
" cardinality=[321],\n",
|
||||
" embedding_dimension=[3],\n",
|
||||
@@ -67,20 +68,20 @@
|
||||
" num_encoder_layers=2,\n",
|
||||
" num_decoder_layers=2,\n",
|
||||
" nhead=2,\n",
|
||||
" n_experts = 4,\n",
|
||||
" capacity_factor = 0.2,\n",
|
||||
" n_experts=4,\n",
|
||||
" capacity_factor=1.0,\n",
|
||||
" \n",
|
||||
" activation=\"relu\",\n",
|
||||
"\n",
|
||||
" batch_size=128,\n",
|
||||
" num_batches_per_epoch=100,\n",
|
||||
" trainer_kwargs=dict(max_epochs=50, accelerator='gpu', gpus=1),\n",
|
||||
" trainer_kwargs=dict(max_epochs=20, accelerator='gpu', gpus=1),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 19,
|
||||
"id": "22d804e4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -114,7 +115,7 @@
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "78ce8638ba8b45e59858e078aa08b6ba",
|
||||
"model_id": "ea9100994cbc439b9a6c66b6beab4253",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
@@ -139,56 +140,26 @@
|
||||
" default for now) or if `full_state_update=False` can be used safely.\n",
|
||||
" \n",
|
||||
" warnings.warn(*args, **kwargs)\n",
|
||||
"Epoch 0, global step 100: 'train_loss' reached 6.53932 (best 6.53932), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=0-step=100.ckpt' as top 1\n",
|
||||
"Epoch 1, global step 200: 'train_loss' reached 6.06170 (best 6.06170), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=1-step=200.ckpt' as top 1\n",
|
||||
"Epoch 2, global step 300: 'train_loss' reached 5.85441 (best 5.85441), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=2-step=300.ckpt' as top 1\n",
|
||||
"Epoch 3, global step 400: 'train_loss' reached 5.71498 (best 5.71498), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=3-step=400.ckpt' as top 1\n",
|
||||
"Epoch 4, global step 500: 'train_loss' reached 5.63095 (best 5.63095), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=4-step=500.ckpt' as top 1\n",
|
||||
"Epoch 5, global step 600: 'train_loss' reached 5.59845 (best 5.59845), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=5-step=600.ckpt' as top 1\n",
|
||||
"Epoch 6, global step 700: 'train_loss' reached 5.50794 (best 5.50794), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=6-step=700.ckpt' as top 1\n",
|
||||
"Epoch 7, global step 800: 'train_loss' reached 5.50530 (best 5.50530), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=7-step=800.ckpt' as top 1\n",
|
||||
"Epoch 8, global step 900: 'train_loss' reached 5.46209 (best 5.46209), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=8-step=900.ckpt' as top 1\n",
|
||||
"Epoch 9, global step 1000: 'train_loss' reached 5.44430 (best 5.44430), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=9-step=1000.ckpt' as top 1\n",
|
||||
"Epoch 10, global step 1100: 'train_loss' reached 5.44166 (best 5.44166), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=10-step=1100.ckpt' as top 1\n",
|
||||
"Epoch 11, global step 1200: 'train_loss' reached 5.36817 (best 5.36817), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=11-step=1200.ckpt' as top 1\n",
|
||||
"Epoch 12, global step 1300: 'train_loss' was not in top 1\n",
|
||||
"Epoch 0, global step 100: 'train_loss' reached 6.89574 (best 6.89574), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=0-step=100.ckpt' as top 1\n",
|
||||
"Epoch 1, global step 200: 'train_loss' reached 6.07414 (best 6.07414), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=1-step=200.ckpt' as top 1\n",
|
||||
"Epoch 2, global step 300: 'train_loss' reached 5.79265 (best 5.79265), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=2-step=300.ckpt' as top 1\n",
|
||||
"Epoch 3, global step 400: 'train_loss' reached 5.67099 (best 5.67099), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=3-step=400.ckpt' as top 1\n",
|
||||
"Epoch 4, global step 500: 'train_loss' reached 5.56751 (best 5.56751), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=4-step=500.ckpt' as top 1\n",
|
||||
"Epoch 5, global step 600: 'train_loss' reached 5.53956 (best 5.53956), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=5-step=600.ckpt' as top 1\n",
|
||||
"Epoch 6, global step 700: 'train_loss' reached 5.46131 (best 5.46131), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=6-step=700.ckpt' as top 1\n",
|
||||
"Epoch 7, global step 800: 'train_loss' reached 5.45841 (best 5.45841), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=7-step=800.ckpt' as top 1\n",
|
||||
"Epoch 8, global step 900: 'train_loss' reached 5.42569 (best 5.42569), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=8-step=900.ckpt' as top 1\n",
|
||||
"Epoch 9, global step 1000: 'train_loss' reached 5.38103 (best 5.38103), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=9-step=1000.ckpt' as top 1\n",
|
||||
"Epoch 10, global step 1100: 'train_loss' reached 5.37174 (best 5.37174), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=10-step=1100.ckpt' as top 1\n",
|
||||
"Epoch 11, global step 1200: 'train_loss' reached 5.35499 (best 5.35499), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=11-step=1200.ckpt' as top 1\n",
|
||||
"Epoch 12, global step 1300: 'train_loss' reached 5.33469 (best 5.33469), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=12-step=1300.ckpt' as top 1\n",
|
||||
"Epoch 13, global step 1400: 'train_loss' was not in top 1\n",
|
||||
"Epoch 14, global step 1500: 'train_loss' reached 5.32576 (best 5.32576), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=14-step=1500.ckpt' as top 1\n",
|
||||
"Epoch 15, global step 1600: 'train_loss' reached 5.29555 (best 5.29555), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=15-step=1600.ckpt' as top 1\n",
|
||||
"Epoch 14, global step 1500: 'train_loss' reached 5.28044 (best 5.28044), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=14-step=1500.ckpt' as top 1\n",
|
||||
"Epoch 15, global step 1600: 'train_loss' was not in top 1\n",
|
||||
"Epoch 16, global step 1700: 'train_loss' was not in top 1\n",
|
||||
"Epoch 17, global step 1800: 'train_loss' reached 5.29322 (best 5.29322), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=17-step=1800.ckpt' as top 1\n",
|
||||
"Epoch 17, global step 1800: 'train_loss' reached 5.27935 (best 5.27935), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=17-step=1800.ckpt' as top 1\n",
|
||||
"Epoch 18, global step 1900: 'train_loss' was not in top 1\n",
|
||||
"Epoch 19, global step 2000: 'train_loss' reached 5.26062 (best 5.26062), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=19-step=2000.ckpt' as top 1\n",
|
||||
"Epoch 20, global step 2100: 'train_loss' was not in top 1\n",
|
||||
"Epoch 21, global step 2200: 'train_loss' reached 5.25029 (best 5.25029), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=21-step=2200.ckpt' as top 1\n",
|
||||
"Epoch 22, global step 2300: 'train_loss' reached 5.23372 (best 5.23372), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=22-step=2300.ckpt' as top 1\n",
|
||||
"Epoch 23, global step 2400: 'train_loss' was not in top 1\n",
|
||||
"Epoch 24, global step 2500: 'train_loss' reached 5.21814 (best 5.21814), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=24-step=2500.ckpt' as top 1\n",
|
||||
"Epoch 25, global step 2600: 'train_loss' was not in top 1\n",
|
||||
"Epoch 26, global step 2700: 'train_loss' reached 5.21784 (best 5.21784), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=26-step=2700.ckpt' as top 1\n",
|
||||
"Epoch 27, global step 2800: 'train_loss' reached 5.18084 (best 5.18084), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=27-step=2800.ckpt' as top 1\n",
|
||||
"Epoch 28, global step 2900: 'train_loss' was not in top 1\n",
|
||||
"Epoch 29, global step 3000: 'train_loss' was not in top 1\n",
|
||||
"Epoch 30, global step 3100: 'train_loss' was not in top 1\n",
|
||||
"Epoch 31, global step 3200: 'train_loss' was not in top 1\n",
|
||||
"Epoch 32, global step 3300: 'train_loss' reached 5.16294 (best 5.16294), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=32-step=3300.ckpt' as top 1\n",
|
||||
"Epoch 33, global step 3400: 'train_loss' was not in top 1\n",
|
||||
"Epoch 34, global step 3500: 'train_loss' was not in top 1\n",
|
||||
"Epoch 35, global step 3600: 'train_loss' was not in top 1\n",
|
||||
"Epoch 36, global step 3700: 'train_loss' was not in top 1\n",
|
||||
"Epoch 37, global step 3800: 'train_loss' was not in top 1\n",
|
||||
"Epoch 38, global step 3900: 'train_loss' reached 5.16191 (best 5.16191), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=38-step=3900.ckpt' as top 1\n",
|
||||
"Epoch 39, global step 4000: 'train_loss' was not in top 1\n",
|
||||
"Epoch 40, global step 4100: 'train_loss' was not in top 1\n",
|
||||
"Epoch 41, global step 4200: 'train_loss' reached 5.12905 (best 5.12905), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=41-step=4200.ckpt' as top 1\n",
|
||||
"Epoch 42, global step 4300: 'train_loss' was not in top 1\n",
|
||||
"Epoch 43, global step 4400: 'train_loss' reached 5.11128 (best 5.11128), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=43-step=4400.ckpt' as top 1\n",
|
||||
"Epoch 44, global step 4500: 'train_loss' was not in top 1\n",
|
||||
"Epoch 45, global step 4600: 'train_loss' was not in top 1\n",
|
||||
"Epoch 46, global step 4700: 'train_loss' was not in top 1\n",
|
||||
"Epoch 47, global step 4800: 'train_loss' was not in top 1\n",
|
||||
"Epoch 48, global step 4900: 'train_loss' was not in top 1\n",
|
||||
"Epoch 49, global step 5000: 'train_loss' was not in top 1\n"
|
||||
"Epoch 19, global step 2000: 'train_loss' reached 5.24220 (best 5.24220), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=19-step=2000.ckpt' as top 1\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -202,7 +173,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 20,
|
||||
"id": "11a47d5a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -215,7 +186,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 21,
|
||||
"id": "1492f7fb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -235,7 +206,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 22,
|
||||
"id": "e00601c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -245,7 +216,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 23,
|
||||
"id": "9ed4c523",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -253,7 +224,8 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Running evaluation: 2247it [00:00, 5481.81it/s]\n",
|
||||
"\n",
|
||||
"Running evaluation: 2247it [00:00, 4999.89it/s]\n",
|
||||
"/home/kashif/.env/pytorch/lib/python3.8/site-packages/pandas/core/construction.py:781: UserWarning: Warning: converting a masked element to nan.\n",
|
||||
" subarr = np.array(arr, dtype=dtype, copy=copy)\n"
|
||||
]
|
||||
@@ -266,7 +238,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "2cb4abe2",
|
||||
"id": "d3efc8ff",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
||||
+14
-21
@@ -38,7 +38,7 @@ from gluonts.torch.util import (
|
||||
)
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import (
|
||||
from gluonts.torch.distributions import (
|
||||
DistributionOutput,
|
||||
StudentTOutput,
|
||||
)
|
||||
@@ -60,11 +60,11 @@ class TransformerModel(nn.Module):
|
||||
cardinality: List[int],
|
||||
embedding_dimension: Optional[List[int]] = None,
|
||||
# Added transformer arguments
|
||||
encoder = None,
|
||||
decoder = None,
|
||||
embeding = None,
|
||||
target_embed = None ,
|
||||
generator = None,
|
||||
encoder=None,
|
||||
decoder=None,
|
||||
embeding=None,
|
||||
target_embed=None,
|
||||
generator=None,
|
||||
#############################
|
||||
dropout_rate: float = 0.1,
|
||||
distr_output: DistributionOutput = StudentTOutput(),
|
||||
@@ -73,7 +73,7 @@ class TransformerModel(nn.Module):
|
||||
num_parallel_samples: int = 100,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.context_length = context_length
|
||||
self.prediction_length = prediction_length
|
||||
self.distr_output = distr_output
|
||||
@@ -98,7 +98,7 @@ class TransformerModel(nn.Module):
|
||||
self.scaler = MeanScaler(dim=1, keepdim=True)
|
||||
else:
|
||||
self.scaler = NOPScaler(dim=1, keepdim=True)
|
||||
|
||||
|
||||
# Added transformer enc-decoder and mask initializer
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
@@ -106,16 +106,14 @@ class TransformerModel(nn.Module):
|
||||
self.target_embed = target_embed
|
||||
self.generator = generator
|
||||
########################
|
||||
|
||||
|
||||
|
||||
# TODO
|
||||
# add method that does the forward for training
|
||||
|
||||
|
||||
"""
|
||||
A build-in Encoder-Decoder architecture for TransformerModel class
|
||||
"""
|
||||
|
||||
|
||||
def forward(self, src, tgt, mask_source, mask_target):
|
||||
"Take in and process masked sourcerc and target sequences."
|
||||
memory = self.encoder(self.embeding(src), mask_source)
|
||||
@@ -124,12 +122,10 @@ class TransformerModel(nn.Module):
|
||||
|
||||
def encode(self, src, mask_source):
|
||||
return self.encoder(self.src_embed(src), mask_source)
|
||||
|
||||
|
||||
def decode(self, memory, mask_source, tgt, mask_target):
|
||||
return self.decoder(self.tgt_embed(tgt), memory, mask_source, mask_target)
|
||||
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def _number_of_features(self) -> int:
|
||||
return (
|
||||
@@ -142,7 +138,7 @@ class TransformerModel(nn.Module):
|
||||
@property
|
||||
def _past_length(self) -> int:
|
||||
return self.context_length + max(self.lags_seq)
|
||||
|
||||
|
||||
# for prediction
|
||||
def forward(
|
||||
self,
|
||||
@@ -156,8 +152,5 @@ class TransformerModel(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
if num_parallel_samples is None:
|
||||
num_parallel_samples = self.num_parallel_samples
|
||||
|
||||
# TODO
|
||||
|
||||
|
||||
|
||||
# TODO
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
"id": "329d23e6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# PyTorch Transformer Time Series Template\n",
|
||||
"# PyTorch Transformer for Time Series Implementation\n",
|
||||
"\n",
|
||||
"The estimator consits of the:\n",
|
||||
"\n",
|
||||
@@ -106,7 +106,7 @@
|
||||
")\n",
|
||||
"from gluonts.torch.model.estimator import PyTorchLightningEstimator\n",
|
||||
"from gluonts.torch.model.predictor import PyTorchPredictor\n",
|
||||
"from gluonts.torch.modules.distribution_output import (\n",
|
||||
"from gluonts.torch.distributions import (\n",
|
||||
" DistributionOutput,\n",
|
||||
" StudentTOutput,\n",
|
||||
")\n",
|
||||
@@ -597,7 +597,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.12"
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
+1
-1
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
||||
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import IterableDataset
|
||||
from gluonts.transform import (
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder as BaseFeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
|
||||
|
||||
+275
-261
File diff suppressed because one or more lines are too long
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
||||
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import IterableDataset
|
||||
from gluonts.transform import (
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
|
||||
|
||||
+1577
-1731
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user