use gluonts dev

This commit is contained in:
Kashif Rasul
2022-06-15 11:40:28 +02:00
parent 712ab88599
commit 296f9c360a
28 changed files with 4818 additions and 3909 deletions
View File
+1 -1
View File
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import IterableDataset
from gluonts.transform import (
+1 -1
View File
@@ -6,7 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
+1 -1
View File
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import IterableDataset
from gluonts.transform import (
+1 -1
View File
@@ -5,7 +5,7 @@ import torch.nn as nn
from etsformer_pytorch import ETSFormer
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
+1 -1
View File
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import IterableDataset
from gluonts.transform import (
+262
View File
@@ -0,0 +1,262 @@
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.time_feature import TimeFeature
from gluonts.torch.distributions import DistributionOutput
from gluonts.torch.util import copy_parameters
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.model.predictor import Predictor
from gluonts.transform import (
Transformation,
Chain,
InstanceSplitter,
InstanceSampler,
ValidationSplitSampler,
TestSplitSampler,
ExpectedNumInstanceSampler,
RemoveFields,
AddAgeFeature,
AsNumpyArray,
AddObservedValuesIndicator,
AddTimeFeatures,
VstackFeatures,
SetField,
)
from pts import Trainer
from pts.model.utils import get_module_forward_input_names
from pts.feature import (
fourier_time_features_from_frequency,
lags_for_fourier_time_features_from_frequency,
)
from pts.model import PyTorchEstimator
from pts.modules import StudentTOutput
from .hopfield_network import (
HopfieldTrainingNetwork,
HopfieldPredictionNetwork,
)
class HopfieldEstimator(PyTorchEstimator):
@validated()
def __init__(
self,
input_size: int,
freq: str,
prediction_length: int,
context_length: Optional[int] = None,
trainer: Trainer = Trainer(),
dropout_rate: float = 0.1,
cardinality: Optional[List[int]] = None,
embedding_dimension: List[int] = [20],
distr_output: DistributionOutput = StudentTOutput(),
d_model: int = 32,
dim_feedforward_scale: int = 4,
act_type: str = "gelu",
num_heads: int = 8,
num_encoder_layers: int = 3,
num_decoder_layers: int = 3,
scaling: bool = True,
lags_seq: Optional[List[int]] = None,
time_features: Optional[List[TimeFeature]] = None,
use_feat_dynamic_real: bool = False,
use_feat_static_cat: bool = False,
use_feat_static_real: bool = False,
num_parallel_samples: int = 100,
) -> None:
super().__init__(trainer=trainer)
self.input_size = input_size
self.freq = freq
self.prediction_length = prediction_length
self.context_length = (
context_length if context_length is not None else prediction_length
)
self.distr_output = distr_output
self.dropout_rate = dropout_rate
self.use_feat_dynamic_real = use_feat_dynamic_real
self.use_feat_static_cat = use_feat_static_cat
self.use_feat_static_real = use_feat_static_real
self.cardinality = cardinality if use_feat_static_cat else [1]
self.embedding_dimension = embedding_dimension
self.num_parallel_samples = num_parallel_samples
self.lags_seq = (
lags_seq
if lags_seq is not None
else lags_for_fourier_time_features_from_frequency(freq_str=freq)
)
self.time_features = (
time_features
if time_features is not None
else fourier_time_features_from_frequency(self.freq)
)
self.history_length = self.context_length + max(self.lags_seq)
self.scaling = scaling
self.d_model = d_model
self.num_heads = num_heads
self.act_type = act_type
self.dim_feedforward_scale = dim_feedforward_scale
self.num_encoder_layers = num_encoder_layers
self.num_decoder_layers = num_decoder_layers
self.train_sampler = ExpectedNumInstanceSampler(
num_instances=1.0, min_future=prediction_length
)
self.validation_sampler = ValidationSplitSampler(min_future=prediction_length)
def create_transformation(self) -> Transformation:
remove_field_names = [
FieldName.FEAT_DYNAMIC_CAT,
]
if not self.use_feat_dynamic_real:
remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
if not self.use_feat_static_real:
remove_field_names.append(FieldName.FEAT_STATIC_REAL)
return Chain(
[RemoveFields(field_names=remove_field_names)]
+ (
[SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
if not self.use_feat_static_cat
else []
)
+ (
[SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
if not self.use_feat_static_real
else []
)
+ [
AsNumpyArray(
field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=np.long
),
AsNumpyArray(
field=FieldName.FEAT_STATIC_REAL,
expected_ndim=1,
dtype=self.dtype,
),
AsNumpyArray(
field=FieldName.TARGET,
# in the following line, we add 1 for the time dimension
expected_ndim=1 + len(self.distr_output.event_shape),
),
AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
),
AddTimeFeatures(
start_field=FieldName.START,
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_TIME,
time_features=self.time_features,
pred_length=self.prediction_length,
),
AddAgeFeature(
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_AGE,
pred_length=self.prediction_length,
log_scale=True,
),
VstackFeatures(
output_field=FieldName.FEAT_TIME,
input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
+ (
[FieldName.FEAT_DYNAMIC_REAL]
if self.use_feat_dynamic_real
else []
),
),
]
)
def create_instance_splitter(self, mode: str):
assert mode in ["training", "validation", "test"]
instance_sampler = {
"training": self.train_sampler,
"validation": self.validation_sampler,
"test": TestSplitSampler(),
}[mode]
return InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
instance_sampler=instance_sampler,
past_length=self.history_length,
future_length=self.prediction_length,
time_series_fields=[
FieldName.FEAT_TIME,
FieldName.OBSERVED_VALUES,
],
)
def create_training_network(self, device: torch.device) -> HopfieldTrainingNetwork:
training_network = HopfieldTrainingNetwork(
input_size=self.input_size,
num_heads=self.num_heads,
act_type=self.act_type,
dropout_rate=self.dropout_rate,
d_model=self.d_model,
dim_feedforward_scale=self.dim_feedforward_scale,
num_encoder_layers=self.num_encoder_layers,
num_decoder_layers=self.num_decoder_layers,
history_length=self.history_length,
context_length=self.context_length,
prediction_length=self.prediction_length,
distr_output=self.distr_output,
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
lags_seq=self.lags_seq,
scaling=self.scaling,
).to(device)
return training_network
def create_predictor(
self,
transformation: Transformation,
trained_network: HopfieldTrainingNetwork,
device: torch.device,
) -> Predictor:
prediction_network = HopfieldPredictionNetwork(
input_size=self.input_size,
num_heads=self.num_heads,
act_type=self.act_type,
dropout_rate=self.dropout_rate,
d_model=self.d_model,
dim_feedforward_scale=self.dim_feedforward_scale,
num_encoder_layers=self.num_encoder_layers,
num_decoder_layers=self.num_decoder_layers,
history_length=self.history_length,
context_length=self.context_length,
prediction_length=self.prediction_length,
distr_output=self.distr_output,
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
lags_seq=self.lags_seq,
scaling=self.scaling,
num_parallel_samples=self.num_parallel_samples,
).to(device)
copy_parameters(trained_network, prediction_network)
input_names = get_module_forward_input_names(prediction_network)
prediction_splitter = self.create_instance_splitter("test")
return PyTorchPredictor(
input_transform=transformation + prediction_splitter,
input_names=input_names,
prediction_net=prediction_network,
batch_size=self.trainer.batch_size,
freq=self.freq,
prediction_length=self.prediction_length,
device=device,
)
+477
View File
@@ -0,0 +1,477 @@
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from gluonts.core.component import validated
from gluonts.torch.distributions import DistributionOutput
from pts.modules import MeanScaler, NOPScaler, FeatureEmbedder
from hflayers import Hopfield
from hflayers.transformer import HopfieldDecoderLayer, HopfieldEncoderLayer
def prod(xs):
p = 1
for x in xs:
p *= x
return p
class HopfieldNetwork(nn.Module):
@validated()
def __init__(
self,
input_size: int,
d_model: int,
num_heads: int,
act_type: str,
dropout_rate: float,
dim_feedforward_scale: int,
num_encoder_layers: int,
num_decoder_layers: int,
history_length: int,
context_length: int,
prediction_length: int,
distr_output: DistributionOutput,
cardinality: List[int],
embedding_dimension: List[int],
lags_seq: List[int],
scaling: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.history_length = history_length
self.context_length = context_length
self.prediction_length = prediction_length
self.scaling = scaling
self.cardinality = cardinality
self.embedding_dimension = embedding_dimension
self.distr_output = distr_output
assert len(set(lags_seq)) == len(lags_seq), "no duplicated lags allowed!"
lags_seq.sort()
self.lags_seq = lags_seq
self.target_shape = distr_output.event_shape
# [B, T, input_size] -> [B, T, d_model]
# self.encoder_input = nn.Linear(input_size, d_model)
# self.decoder_input = nn.Linear(input_size, d_model)
# [B, T, d_model] where d_model / num_heads is int
encoder_association = Hopfield(input_size=input_size, num_heads=num_heads)
encoder_layer = HopfieldEncoderLayer(
encoder_association,
dim_feedforward=dim_feedforward_scale * input_size,
dropout=dropout_rate,
activation=act_type,
)
transformer_encoder = nn.TransformerEncoder(
encoder_layer, num_layers=num_encoder_layers
)
decoder_association_self = Hopfield(input_size=input_size, num_heads=num_heads)
decoder_association_cross = Hopfield(input_size=input_size, num_heads=num_heads)
decoder_layer = HopfieldDecoderLayer(
hopfield_association_self=decoder_association_self,
hopfield_association_cross=decoder_association_cross,
dim_feedforward=dim_feedforward_scale * input_size,
dropout=dropout_rate,
activation=act_type,
)
transformer_decoder = nn.TransformerDecoder(
decoder_layer, num_layers=num_decoder_layers
)
self.transformer = nn.Transformer(
d_model=input_size,
nhead=num_heads,
custom_encoder=transformer_encoder,
custom_decoder=transformer_decoder,
batch_first=True,
)
self.proj_dist_args = distr_output.get_args_proj(input_size)
self.embedder = FeatureEmbedder(
cardinalities=cardinality,
embedding_dims=embedding_dimension,
)
if scaling:
self.scaler = MeanScaler(keepdim=True)
else:
self.scaler = NOPScaler(keepdim=True)
# mask
self.register_buffer(
"tgt_mask",
self.transformer.generate_square_subsequent_mask(prediction_length),
)
@staticmethod
def get_lagged_subsequences(
sequence: torch.Tensor,
sequence_length: int,
indices: List[int],
subsequences_length: int = 1,
) -> torch.Tensor:
"""
Returns lagged subsequences of a given sequence.
Parameters
----------
sequence : Tensor
the sequence from which lagged subsequences should be extracted.
Shape: (N, T, C).
sequence_length : int
length of sequence in the T (time) dimension (axis = 1).
indices : List[int]
list of lag indices to be used.
subsequences_length : int
length of the subsequences to be extracted.
Returns
--------
lagged : Tensor
a tensor of shape (N, S, C, I), where S = subsequences_length and
I = len(indices), containing lagged subsequences. Specifically,
lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
"""
assert max(indices) + subsequences_length <= sequence_length, (
f"lags cannot go further than history length, found lag {max(indices)} "
f"while history length is only {sequence_length}"
)
assert all(lag_index >= 0 for lag_index in indices)
lagged_values = []
for lag_index in indices:
begin_index = -lag_index - subsequences_length
end_index = -lag_index if lag_index > 0 else None
lagged_values.append(sequence[:, begin_index:end_index, ...])
return torch.stack(lagged_values, dim=-1)
def create_network_input(
self,
feat_static_cat: torch.Tensor, # (batch_size, num_features)
feat_static_real: torch.Tensor,
# (batch_size, num_features, history_length)
past_time_feat: torch.Tensor,
past_target: torch.Tensor, # (batch_size, history_length, 1)
past_observed_values: torch.Tensor, # (batch_size, history_length)
future_time_feat: Optional[
torch.Tensor
], # (batch_size, num_features, prediction_length)
# (batch_size, prediction_length)
future_target: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Creates inputs for the transformer network.
All tensor arguments should have NTC layout.
"""
if future_time_feat is None or future_target is None:
time_feat = past_time_feat[
:, self.history_length - self.context_length :, ...
]
sequence = past_target
sequence_length = self.history_length
subsequences_length = self.context_length
else:
time_feat = torch.cat(
(
past_time_feat[:, self.history_length - self.context_length :, ...],
future_time_feat,
),
dim=1,
)
sequence = torch.cat((past_target, future_target), dim=1)
sequence_length = self.history_length + self.prediction_length
subsequences_length = self.context_length + self.prediction_length
# (batch_size, sub_seq_len, *target_shape, num_lags)
lags = self.get_lagged_subsequences(
sequence=sequence,
sequence_length=sequence_length,
indices=self.lags_seq,
subsequences_length=subsequences_length,
)
# scale is computed on the context length last units of the past target
# scale shape is (batch_size, 1, *target_shape)
_, scale = self.scaler(
past_target[:, -self.context_length :, ...],
past_observed_values[:, -self.context_length :, ...],
)
embedded_cat = self.embedder(feat_static_cat)
# in addition to embedding features, use the log scale as it can help prediction too
# (batch_size, num_features + prod(target_shape))
static_feat = torch.cat(
(
embedded_cat,
feat_static_real,
torch.log(scale)
if len(self.target_shape) == 0
else torch.log(scale.squeeze(1)),
),
dim=1,
)
repeated_static_feat = static_feat.unsqueeze(1).expand(
-1, subsequences_length, -1
)
# (batch_size, sub_seq_len, *target_shape, num_lags)
lags_scaled = lags / scale.unsqueeze(-1)
# from (batch_size, sub_seq_len, *target_shape, num_lags)
# to (batch_size, sub_seq_len, prod(target_shape) * num_lags)
input_lags = lags_scaled.reshape(
(-1, subsequences_length, len(self.lags_seq) * prod(self.target_shape))
)
# (batch_size, sub_seq_len, input_dim)
inputs = torch.cat((input_lags, time_feat, repeated_static_feat), dim=-1)
return inputs, scale, static_feat
class HopfieldTrainingNetwork(HopfieldNetwork):
# noinspection PyMethodOverriding,PyPep8Naming
def forward(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_time_feat: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
future_time_feat: torch.Tensor,
future_target: torch.Tensor,
) -> torch.Tensor:
"""
Computes the loss for training Hopfield Transformer, all inputs tensors representing time series have NTC layout.
Parameters
----------
feat_static_cat : (batch_size, num_features)
feat_static_real: torch.Tensor, # (batch_size, num_features)
past_time_feat : (batch_size, history_length, num_features)
past_target : (batch_size, history_length, *target_shape)
past_observed_values : (batch_size, history_length, *target_shape, seq_len)
future_time_feat : (batch_size, prediction_length, num_features)
future_target : (batch_size, prediction_length, *target_shape)
Returns
-------
Loss with shape (batch_size, context + prediction_length, 1)
"""
# create the inputs for the encoder
inputs, scale, _ = self.create_network_input(
feat_static_cat=feat_static_cat,
feat_static_real=feat_static_real,
past_time_feat=past_time_feat,
past_target=past_target,
past_observed_values=past_observed_values,
future_time_feat=future_time_feat,
future_target=future_target,
)
enc_input = inputs[:, : self.context_length, ...] # F.slice_axis(
# inputs, axis=1, begin=0, end=self.context_length
# )
dec_input = inputs[:, self.context_length :, ...] # F.slice_axis(
# inputs, axis=1, begin=self.context_length, end=None
# )
# pass through encoder [T, B, b_model]
enc_out = self.transformer.encoder(
# self.encoder_input(enc_input).permute(1, 0, 2)
enc_input
)
# input to decoder
dec_output = self.transformer.decoder(
# self.decoder_input(dec_input).permute(1, 0, 2),
dec_input,
enc_out, # memory
tgt_mask=self.tgt_mask,
)
# compute loss
# distr_args = self.proj_dist_args(dec_output.permute(1, 0, 2))
distr_args = self.proj_dist_args(dec_output)
distr = self.distr_output.distribution(distr_args, scale=scale)
loss = -distr.log_prob(future_target)
return loss.mean()
class HopfieldPredictionNetwork(HopfieldNetwork):
def __init__(self, num_parallel_samples: int = 100, **kwargs) -> None:
super().__init__(**kwargs)
self.num_parallel_samples = num_parallel_samples
# for decoding the lags are shifted by one,
# at the first time-step of the decoder a lag of one corresponds to the last target value
self.shifted_lags = [l - 1 for l in self.lags_seq]
def sampling_decoder(
self,
static_feat: torch.Tensor,
past_target: torch.Tensor,
time_feat: torch.Tensor,
scale: torch.Tensor,
enc_out: torch.Tensor,
) -> torch.Tensor:
"""
Computes sample paths by decoding from the transformer.
Parameters
----------
static_feat : Tensor
static features. Shape: (batch_size, num_static_features).
past_target : Tensor
target history. Shape: (batch_size, history_length, 1).
time_feat : Tensor
time features. Shape: (batch_size, prediction_length, num_time_features).
scale : Tensor
tensor containing the scale of each element in the batch. Shape: (batch_size, ).
enc_out: Tensor
output of the encoder. Shape: (batch_size, num_cells)
Returns
--------
sample_paths : Tensor
a tensor containing sampled paths. Shape: (batch_size, num_sample_paths, prediction_length).
"""
# blows-up the dimension of each tensor to batch_size * self.num_parallel_samples for increasing parallelism
repeated_past_target = past_target.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
repeated_time_feat = time_feat.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
repeated_static_feat = static_feat.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
).unsqueeze(1)
repeated_enc_out = enc_out.repeat_interleave(
repeats=self.num_parallel_samples, dim=0 # 1
)
repeated_scale = scale.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
future_samples = []
# for each future time-units we draw new samples for this time-unit and update the state
for k in range(self.prediction_length):
lags = self.get_lagged_subsequences(
sequence=repeated_past_target,
sequence_length=self.history_length + k,
indices=self.shifted_lags,
subsequences_length=1,
)
# (batch_size * num_samples, 1, *target_shape, num_lags)
lags_scaled = lags / repeated_scale.unsqueeze(1)
# lags_scaled = F.broadcast_div(
# lags, repeated_scale.expand_dims(axis=-1)
# )
# from (batch_size * num_samples, 1, *target_shape, num_lags)
# to (batch_size * num_samples, 1, prod(target_shape) * num_lags)
input_lags = lags_scaled.reshape(
shape=(-1, 1, prod(self.target_shape) * len(self.lags_seq))
)
# (batch_size * num_samples, 1, prod(target_shape) * num_lags + num_time_features + num_static_features)
dec_input = torch.cat(
(input_lags, repeated_time_feat[:, k : k + 1, :], repeated_static_feat),
dim=-1,
)
dec_output = self.transformer.decoder(
# self.decoder_input(dec_input).permute(1, 0, 2),
dec_input,
repeated_enc_out,
)
# distr_args = self.proj_dist_args(dec_output.permute(1, 0, 2))
distr_args = self.proj_dist_args(dec_output)
# compute likelihood of target given the predicted parameters
distr = self.distr_output.distribution(distr_args, scale=repeated_scale)
# (batch_size * num_samples, 1, *target_shape)
new_samples = distr.sample()
# (batch_size * num_samples, seq_len, *target_shape)
repeated_past_target = torch.cat((repeated_past_target, new_samples), dim=1)
future_samples.append(new_samples)
# reset cache of the decoder
# self.transformer.decoder.cache_reset()
# (batch_size * num_samples, prediction_length, *target_shape)
samples = torch.cat(future_samples, dim=1)
# (batch_size, num_samples, *target_shape, prediction_length)
return samples.reshape(
(
(-1, self.num_parallel_samples)
+ self.target_shape
+ (self.prediction_length,)
)
)
def forward(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_time_feat: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
future_time_feat: torch.Tensor,
) -> torch.Tensor:
"""
Predicts samples, all tensors should have NTC layout.
Parameters
----------
feat_static_cat : (batch_size, num_features)
feat_static_real : (batch_size, num_features)
past_time_feat : (batch_size, history_length, num_features)
past_target : (batch_size, history_length, *target_shape)
past_observed_values : (batch_size, history_length, *target_shape)
future_time_feat : (batch_size, prediction_length, num_features)
Returns predicted samples
-------
"""
# create the inputs for the encoder
inputs, scale, static_feat = self.create_network_input(
feat_static_cat=feat_static_cat,
feat_static_real=feat_static_real,
past_time_feat=past_time_feat,
past_target=past_target,
past_observed_values=past_observed_values,
future_time_feat=None,
future_target=None,
)
# pass through encoder
enc_out = self.transformer.encoder(
# self.encoder_input(inputs).permute(1, 0, 2)
inputs
)
return self.sampling_decoder(
past_target=past_target,
time_feat=future_time_feat,
static_feat=static_feat,
scale=scale,
enc_out=enc_out,
)
+1 -1
View File
@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
from hflayers import Hopfield
+1 -1
View File
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import IterableDataset
from gluonts.transform import (
+159 -134
View File
File diff suppressed because one or more lines are too long
+1 -1
View File
@@ -7,7 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
+5 -5
View File
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.util import IterableDataset
@@ -79,7 +79,7 @@ class PyraformerEstimator(PyTorchLightningEstimator):
n_layer: int = 4,
enc_in: int = 1, # depends on dataset used
CSCM: str = "Bottleneck_Construct", # [Bottleneck_Construct, Conv_Construct, MaxPooling_Construct, AvgPooling_Construct]
embed_type: str = "CustomEmbedding", #[DataEmbedding, CustomEmbedding]
embed_type: str = "CustomEmbedding", # [DataEmbedding, CustomEmbedding]
truncate: bool = False,
# loss: DistributionLoss = LossFactory,
ignore_zero: bool = True,
@@ -339,7 +339,7 @@ class PyraformerEstimator(PyTorchLightningEstimator):
num_seq=self.num_seq,
input_size=self.input_size,
dropout=self.dropout,
d_model = self.d_model,
d_model=self.d_model,
d_inner_hid=self.d_inner_hid,
d_k=self.d_k,
d_v=self.d_v,
@@ -392,8 +392,8 @@ class PyraformerEstimator(PyTorchLightningEstimator):
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
num_parallel_samples=self.num_parallel_samples,
embed_type = self.embed_type,
distr_output= self.distr_output,
embed_type=self.embed_type,
distr_output=self.distr_output,
device=device,
)
return PyraformerLightningModule(model=model, loss=self.loss)
+6 -7
View File
@@ -3,7 +3,7 @@ import torch
import torch.nn as nn
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
@@ -151,7 +151,6 @@ class PyraformerSSModel(nn.Module):
scaling,
num_parallel_samples,
device,
):
super().__init__()
@@ -179,7 +178,7 @@ class PyraformerSSModel(nn.Module):
# convert hidden vectors into two scalar
self.mean_hidden = Predictor(4 * d_model, 1)
self.var_hidden = Predictor(4 * d_model, 1)
self.softplus = nn.Softplus()
self.distr_output = distr_output
@@ -512,7 +511,7 @@ class PyraformerLRModel(nn.Module):
num_parallel_samples,
embed_type,
distr_output,
device
device,
):
super().__init__()
@@ -524,7 +523,7 @@ class PyraformerLRModel(nn.Module):
self.distr_output = distr_output
self.context_length = context_length
self.lags_seq = lags_seq
self.encoder = Encoder(
# model,
window_size,
@@ -593,10 +592,10 @@ class PyraformerLRModel(nn.Module):
)
return pred
@property
def _past_length(self) -> int:
return self.predict_step #+ max(0,self.lags_seq)
return self.predict_step # + max(0,self.lags_seq)
@property
def _number_of_features(self) -> int:
+1 -1
View File
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import IterableDataset
from gluonts.transform import (
+1 -1
View File
@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
from reformer_pytorch.reformer_pytorch import Reformer
+288 -237
View File
File diff suppressed because one or more lines are too long
+1 -1
View File
@@ -1,6 +1,6 @@
orjson
torch
gluonts
https://github.com/awslabs/gluon-ts
pytorch-lightning
datasets
xformers
+1706 -1436
View File
File diff suppressed because one or more lines are too long
+31 -59
View File
@@ -51,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 18,
"id": "e772234f",
"metadata": {},
"outputs": [],
@@ -59,6 +59,7 @@
"estimator = SwitchTransformerEstimator(\n",
" freq=dataset.metadata.freq,\n",
" prediction_length=dataset.metadata.prediction_length,\n",
" context_length=8*dataset.metadata.prediction_length,\n",
" num_feat_static_cat=1,\n",
" cardinality=[321],\n",
" embedding_dimension=[3],\n",
@@ -67,20 +68,20 @@
" num_encoder_layers=2,\n",
" num_decoder_layers=2,\n",
" nhead=2,\n",
" n_experts = 4,\n",
" capacity_factor = 0.2,\n",
" n_experts=4,\n",
" capacity_factor=1.0,\n",
" \n",
" activation=\"relu\",\n",
"\n",
" batch_size=128,\n",
" num_batches_per_epoch=100,\n",
" trainer_kwargs=dict(max_epochs=50, accelerator='gpu', gpus=1),\n",
" trainer_kwargs=dict(max_epochs=20, accelerator='gpu', gpus=1),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 19,
"id": "22d804e4",
"metadata": {},
"outputs": [
@@ -114,7 +115,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "78ce8638ba8b45e59858e078aa08b6ba",
"model_id": "ea9100994cbc439b9a6c66b6beab4253",
"version_major": 2,
"version_minor": 0
},
@@ -139,56 +140,26 @@
" default for now) or if `full_state_update=False` can be used safely.\n",
" \n",
" warnings.warn(*args, **kwargs)\n",
"Epoch 0, global step 100: 'train_loss' reached 6.53932 (best 6.53932), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=0-step=100.ckpt' as top 1\n",
"Epoch 1, global step 200: 'train_loss' reached 6.06170 (best 6.06170), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=1-step=200.ckpt' as top 1\n",
"Epoch 2, global step 300: 'train_loss' reached 5.85441 (best 5.85441), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=2-step=300.ckpt' as top 1\n",
"Epoch 3, global step 400: 'train_loss' reached 5.71498 (best 5.71498), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=3-step=400.ckpt' as top 1\n",
"Epoch 4, global step 500: 'train_loss' reached 5.63095 (best 5.63095), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=4-step=500.ckpt' as top 1\n",
"Epoch 5, global step 600: 'train_loss' reached 5.59845 (best 5.59845), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=5-step=600.ckpt' as top 1\n",
"Epoch 6, global step 700: 'train_loss' reached 5.50794 (best 5.50794), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=6-step=700.ckpt' as top 1\n",
"Epoch 7, global step 800: 'train_loss' reached 5.50530 (best 5.50530), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=7-step=800.ckpt' as top 1\n",
"Epoch 8, global step 900: 'train_loss' reached 5.46209 (best 5.46209), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=8-step=900.ckpt' as top 1\n",
"Epoch 9, global step 1000: 'train_loss' reached 5.44430 (best 5.44430), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=9-step=1000.ckpt' as top 1\n",
"Epoch 10, global step 1100: 'train_loss' reached 5.44166 (best 5.44166), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=10-step=1100.ckpt' as top 1\n",
"Epoch 11, global step 1200: 'train_loss' reached 5.36817 (best 5.36817), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=11-step=1200.ckpt' as top 1\n",
"Epoch 12, global step 1300: 'train_loss' was not in top 1\n",
"Epoch 0, global step 100: 'train_loss' reached 6.89574 (best 6.89574), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=0-step=100.ckpt' as top 1\n",
"Epoch 1, global step 200: 'train_loss' reached 6.07414 (best 6.07414), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=1-step=200.ckpt' as top 1\n",
"Epoch 2, global step 300: 'train_loss' reached 5.79265 (best 5.79265), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=2-step=300.ckpt' as top 1\n",
"Epoch 3, global step 400: 'train_loss' reached 5.67099 (best 5.67099), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=3-step=400.ckpt' as top 1\n",
"Epoch 4, global step 500: 'train_loss' reached 5.56751 (best 5.56751), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=4-step=500.ckpt' as top 1\n",
"Epoch 5, global step 600: 'train_loss' reached 5.53956 (best 5.53956), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=5-step=600.ckpt' as top 1\n",
"Epoch 6, global step 700: 'train_loss' reached 5.46131 (best 5.46131), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=6-step=700.ckpt' as top 1\n",
"Epoch 7, global step 800: 'train_loss' reached 5.45841 (best 5.45841), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=7-step=800.ckpt' as top 1\n",
"Epoch 8, global step 900: 'train_loss' reached 5.42569 (best 5.42569), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=8-step=900.ckpt' as top 1\n",
"Epoch 9, global step 1000: 'train_loss' reached 5.38103 (best 5.38103), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=9-step=1000.ckpt' as top 1\n",
"Epoch 10, global step 1100: 'train_loss' reached 5.37174 (best 5.37174), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=10-step=1100.ckpt' as top 1\n",
"Epoch 11, global step 1200: 'train_loss' reached 5.35499 (best 5.35499), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=11-step=1200.ckpt' as top 1\n",
"Epoch 12, global step 1300: 'train_loss' reached 5.33469 (best 5.33469), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=12-step=1300.ckpt' as top 1\n",
"Epoch 13, global step 1400: 'train_loss' was not in top 1\n",
"Epoch 14, global step 1500: 'train_loss' reached 5.32576 (best 5.32576), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=14-step=1500.ckpt' as top 1\n",
"Epoch 15, global step 1600: 'train_loss' reached 5.29555 (best 5.29555), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=15-step=1600.ckpt' as top 1\n",
"Epoch 14, global step 1500: 'train_loss' reached 5.28044 (best 5.28044), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=14-step=1500.ckpt' as top 1\n",
"Epoch 15, global step 1600: 'train_loss' was not in top 1\n",
"Epoch 16, global step 1700: 'train_loss' was not in top 1\n",
"Epoch 17, global step 1800: 'train_loss' reached 5.29322 (best 5.29322), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=17-step=1800.ckpt' as top 1\n",
"Epoch 17, global step 1800: 'train_loss' reached 5.27935 (best 5.27935), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=17-step=1800.ckpt' as top 1\n",
"Epoch 18, global step 1900: 'train_loss' was not in top 1\n",
"Epoch 19, global step 2000: 'train_loss' reached 5.26062 (best 5.26062), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=19-step=2000.ckpt' as top 1\n",
"Epoch 20, global step 2100: 'train_loss' was not in top 1\n",
"Epoch 21, global step 2200: 'train_loss' reached 5.25029 (best 5.25029), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=21-step=2200.ckpt' as top 1\n",
"Epoch 22, global step 2300: 'train_loss' reached 5.23372 (best 5.23372), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=22-step=2300.ckpt' as top 1\n",
"Epoch 23, global step 2400: 'train_loss' was not in top 1\n",
"Epoch 24, global step 2500: 'train_loss' reached 5.21814 (best 5.21814), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=24-step=2500.ckpt' as top 1\n",
"Epoch 25, global step 2600: 'train_loss' was not in top 1\n",
"Epoch 26, global step 2700: 'train_loss' reached 5.21784 (best 5.21784), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=26-step=2700.ckpt' as top 1\n",
"Epoch 27, global step 2800: 'train_loss' reached 5.18084 (best 5.18084), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=27-step=2800.ckpt' as top 1\n",
"Epoch 28, global step 2900: 'train_loss' was not in top 1\n",
"Epoch 29, global step 3000: 'train_loss' was not in top 1\n",
"Epoch 30, global step 3100: 'train_loss' was not in top 1\n",
"Epoch 31, global step 3200: 'train_loss' was not in top 1\n",
"Epoch 32, global step 3300: 'train_loss' reached 5.16294 (best 5.16294), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=32-step=3300.ckpt' as top 1\n",
"Epoch 33, global step 3400: 'train_loss' was not in top 1\n",
"Epoch 34, global step 3500: 'train_loss' was not in top 1\n",
"Epoch 35, global step 3600: 'train_loss' was not in top 1\n",
"Epoch 36, global step 3700: 'train_loss' was not in top 1\n",
"Epoch 37, global step 3800: 'train_loss' was not in top 1\n",
"Epoch 38, global step 3900: 'train_loss' reached 5.16191 (best 5.16191), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=38-step=3900.ckpt' as top 1\n",
"Epoch 39, global step 4000: 'train_loss' was not in top 1\n",
"Epoch 40, global step 4100: 'train_loss' was not in top 1\n",
"Epoch 41, global step 4200: 'train_loss' reached 5.12905 (best 5.12905), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=41-step=4200.ckpt' as top 1\n",
"Epoch 42, global step 4300: 'train_loss' was not in top 1\n",
"Epoch 43, global step 4400: 'train_loss' reached 5.11128 (best 5.11128), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_1/checkpoints/epoch=43-step=4400.ckpt' as top 1\n",
"Epoch 44, global step 4500: 'train_loss' was not in top 1\n",
"Epoch 45, global step 4600: 'train_loss' was not in top 1\n",
"Epoch 46, global step 4700: 'train_loss' was not in top 1\n",
"Epoch 47, global step 4800: 'train_loss' was not in top 1\n",
"Epoch 48, global step 4900: 'train_loss' was not in top 1\n",
"Epoch 49, global step 5000: 'train_loss' was not in top 1\n"
"Epoch 19, global step 2000: 'train_loss' reached 5.24220 (best 5.24220), saving model to '/mnt/scratch/kashif/pytorch-transformer-ts/switch/lightning_logs/version_3/checkpoints/epoch=19-step=2000.ckpt' as top 1\n"
]
}
],
@@ -202,7 +173,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 20,
"id": "11a47d5a",
"metadata": {},
"outputs": [],
@@ -215,7 +186,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 21,
"id": "1492f7fb",
"metadata": {},
"outputs": [],
@@ -235,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 22,
"id": "e00601c4",
"metadata": {},
"outputs": [],
@@ -245,7 +216,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 23,
"id": "9ed4c523",
"metadata": {},
"outputs": [
@@ -253,7 +224,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Running evaluation: 2247it [00:00, 5481.81it/s]\n",
"\n",
"Running evaluation: 2247it [00:00, 4999.89it/s]\n",
"/home/kashif/.env/pytorch/lib/python3.8/site-packages/pandas/core/construction.py:781: UserWarning: Warning: converting a masked element to nan.\n",
" subarr = np.array(arr, dtype=dtype, copy=copy)\n"
]
@@ -266,7 +238,7 @@
{
"cell_type": "code",
"execution_count": 11,
"id": "2cb4abe2",
"id": "d3efc8ff",
"metadata": {},
"outputs": [
{
+14 -21
View File
@@ -38,7 +38,7 @@ from gluonts.torch.util import (
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import (
from gluonts.torch.distributions import (
DistributionOutput,
StudentTOutput,
)
@@ -60,11 +60,11 @@ class TransformerModel(nn.Module):
cardinality: List[int],
embedding_dimension: Optional[List[int]] = None,
# Added transformer arguments
encoder = None,
decoder = None,
embeding = None,
target_embed = None ,
generator = None,
encoder=None,
decoder=None,
embeding=None,
target_embed=None,
generator=None,
#############################
dropout_rate: float = 0.1,
distr_output: DistributionOutput = StudentTOutput(),
@@ -73,7 +73,7 @@ class TransformerModel(nn.Module):
num_parallel_samples: int = 100,
) -> None:
super().__init__()
self.context_length = context_length
self.prediction_length = prediction_length
self.distr_output = distr_output
@@ -98,7 +98,7 @@ class TransformerModel(nn.Module):
self.scaler = MeanScaler(dim=1, keepdim=True)
else:
self.scaler = NOPScaler(dim=1, keepdim=True)
# Added transformer enc-decoder and mask initializer
self.encoder = encoder
self.decoder = decoder
@@ -106,16 +106,14 @@ class TransformerModel(nn.Module):
self.target_embed = target_embed
self.generator = generator
########################
# TODO
# add method that does the forward for training
"""
A build-in Encoder-Decoder architecture for TransformerModel class
"""
def forward(self, src, tgt, mask_source, mask_target):
"Take in and process masked sourcerc and target sequences."
memory = self.encoder(self.embeding(src), mask_source)
@@ -124,12 +122,10 @@ class TransformerModel(nn.Module):
def encode(self, src, mask_source):
return self.encoder(self.src_embed(src), mask_source)
def decode(self, memory, mask_source, tgt, mask_target):
return self.decoder(self.tgt_embed(tgt), memory, mask_source, mask_target)
@property
def _number_of_features(self) -> int:
return (
@@ -142,7 +138,7 @@ class TransformerModel(nn.Module):
@property
def _past_length(self) -> int:
return self.context_length + max(self.lags_seq)
# for prediction
def forward(
self,
@@ -156,8 +152,5 @@ class TransformerModel(nn.Module):
) -> torch.Tensor:
if num_parallel_samples is None:
num_parallel_samples = self.num_parallel_samples
# TODO
# TODO
+3 -3
View File
@@ -5,7 +5,7 @@
"id": "329d23e6",
"metadata": {},
"source": [
"# PyTorch Transformer Time Series Template\n",
"# PyTorch Transformer for Time Series Implementation\n",
"\n",
"The estimator consits of the:\n",
"\n",
@@ -106,7 +106,7 @@
")\n",
"from gluonts.torch.model.estimator import PyTorchLightningEstimator\n",
"from gluonts.torch.model.predictor import PyTorchPredictor\n",
"from gluonts.torch.modules.distribution_output import (\n",
"from gluonts.torch.distributions import (\n",
" DistributionOutput,\n",
" StudentTOutput,\n",
")\n",
@@ -597,7 +597,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.8.10"
}
},
"nbformat": 4,
+1 -1
View File
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import IterableDataset
from gluonts.transform import (
+1 -1
View File
@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder as BaseFeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
+275 -261
View File
File diff suppressed because one or more lines are too long
+1 -1
View File
@@ -8,7 +8,7 @@ from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import IterableDataset
from gluonts.transform import (
+1 -1
View File
@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
+1577 -1731
View File
File diff suppressed because one or more lines are too long