Clean up (still not working)

This commit is contained in:
Hstellar
2022-04-19 16:29:30 -04:00
parent d72782f689
commit c7afec7ac5
15 changed files with 502 additions and 1663 deletions
BIN
View File
Binary file not shown.
+6 -7
View File
@@ -1,10 +1,9 @@
from .estimator import PyraformerEstimator
from .lightning_module import PyraformerLightningModule
from .module import PyraformerLRModel, PyraformerSSModel
from .estimator import TransformerEstimator
from .lightning_module import TransformerLightningModule
from .module import TransformerModel
__all__ = [
"PyraformerSSModel",
"PyraformerLRModel",
"PyraformerLightningModule",
"PyraformerEstimator",
"TransformerModel",
"TransformerLightningModule",
"TransformerEstimator",
]
+32 -53
View File
@@ -5,15 +5,12 @@ from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.time_feature import (
TimeFeature,
get_lags_for_frequency,
time_features_from_frequency_str,
)
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.util import IterableDataset
from gluonts.transform import (
AddAgeFeature,
@@ -32,13 +29,12 @@ from gluonts.transform import (
VstackFeatures,
)
from gluonts.transform.sampler import InstanceSampler
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
from lightning_module import PyraformerLightningModule
from module import PyraformerLRModel, PyraformerSSModel
from module import PyraformerSSModel
from module import PyraformerLRModel
from torch.utils.data import DataLoader
from tools import SingleStepLoss as LossFactory
from torch.utils.data.sampler import RandomSampler
PREDICTION_INPUT_NAMES = [
"feat_static_cat",
"feat_static_real",
@@ -60,28 +56,30 @@ class PyraformerEstimator(PyTorchLightningEstimator):
self,
freq: str,
prediction_length: int,
# Train parameters
#Train parameters
inner_batch: int = 8,
lr: float = 1e-5,
visualize_fre: int = 2000,
pretrain: bool = True,
hard_sample_mining: bool = True,
hard_sample_mining:bool=True,
covariate_size: int = 3,
# Model parameters
num_seq: int = 370, #
decoder: str = "FC", # selection: [FC, attention]
# Model parameters
num_seq: int = 370,#
decoder: str = 'FC',# selection: [FC, attention]
context_length: Optional[int] = None,
input_size: int = 1,
dropout: float = 0.1,
d_model: int = 512,
d_inner_hid: int = 512,
d_k: int = 128,
d_v: int = 128,
d_v:int = 128,
num_heads: int = 4,
n_layer: int = 4,
# loss: DistributionLoss = LossFactory,
ignore_zero: bool = True,
single_step: bool = True, # if False, Multistep=True
single_step: bool = True,#if False, Multistep=True
inner_size: int = 3,
use_tvm: bool = False,
num_feat_dynamic_real: int = 0,
@@ -100,7 +98,7 @@ class PyraformerEstimator(PyTorchLightningEstimator):
trainer_kwargs: Optional[Dict[str, Any]] = dict(),
train_sampler: Optional[InstanceSampler] = None,
validation_sampler: Optional[InstanceSampler] = None,
window_size: int = [4, 4, 4],
window_size: int = [4, 4, 4]
) -> None:
trainer_kwargs = {
"max_epochs": 10,
@@ -127,15 +125,11 @@ class PyraformerEstimator(PyTorchLightningEstimator):
self.n_layer = n_layer
self.single_step = single_step
self.ignore_zero = ignore_zero
self.loss = (
LossFactory(self.ignore_zero)
if self.single_step == True
else torch.nn.MSELoss(reduction="none")
)
self.loss = LossFactory(self.ignore_zero) if self.single_step==True else torch.nn.MSELoss(reduction='none')
self.batch_size = batch_size
self.distr_output = distr_output
self.window_size = window_size # [4,4,4]#window_size
self.window_size = window_size#[4,4,4]#window_size
self.inner_size = inner_size
self.use_tvm = use_tvm
self.prediction_length = prediction_length
@@ -325,35 +319,18 @@ class PyraformerEstimator(PyTorchLightningEstimator):
def create_lightning_module(self) -> PyraformerLightningModule:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if self.single_step:
model = PyraformerSSModel(
freq=self.freq,
covariate_size=self.covariate_size,
num_seq=self.num_seq,
input_size=self.input_size,
dropout=self.dropout,
d_model=self.d_model,
d_inner_hid=self.d_inner_hid,
d_k=self.d_k,
d_v=self.d_v,
num_heads=self.num_heads,
n_layer=self.n_layer,
loss=self.loss,
window_size=self.window_size,
inner_size=self.inner_size,
use_tvm=self.use_tvm,
prediction_length=self.prediction_length,
context_length=self.context_length,
lags_seq=self.lags_seq,
num_feat_dynamic_real=self.num_feat_dynamic_real,
num_feat_static_cat=self.num_feat_static_cat,
num_feat_static_real=self.num_feat_static_real,
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
distr_output=self.distr_output,
scaling=self.scaling,
num_parallel_samples=self.num_parallel_samples,
device=device,
)
model = PyraformerSSModel(freq= self.freq, covariate_size = self.covariate_size,
num_seq=self.num_seq, input_size = self.input_size, dropout = self.dropout, d_model = self.d_model,
d_inner_hid = self.d_inner_hid, d_k = self.d_k, d_v = self.d_v,
num_heads = self.num_heads, n_layer = self.n_layer, loss = self.loss,
window_size = self.window_size, inner_size = self.inner_size,
use_tvm = self.use_tvm, prediction_length = self.prediction_length,context_length = self.context_length, lags_seq = self.lags_seq, num_feat_dynamic_real= self.num_feat_dynamic_real,
num_feat_static_cat = self.num_feat_static_cat,
num_feat_static_real = self.num_feat_static_real,
cardinality = self.cardinality,
embedding_dimension = self.embedding_dimension,
distr_output=self.distr_output,
scaling=self.scaling,num_parallel_samples=self.num_parallel_samples, device=device)
# else:
# model = PyraformerLRModel(freq= self.freq, covariate_size = self.covariate_size,
# num_seq=self.num_seq, input_size = self.input_size, dropout = self.dropout, d_model = self.d_model,
@@ -362,3 +339,5 @@ class PyraformerEstimator(PyTorchLightningEstimator):
# window_size = self.window_size, inner_size = self.inner_size,
# use_tvm = self.use_tvm, prediction_length = self.prediction_length,context_length = self.context_length, lags_seq = self.lags_seq, device=device)
return PyraformerLightningModule(model=model, loss=self.loss)
+64 -67
View File
@@ -2,84 +2,81 @@ import pytorch_lightning as pl
import torch
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import weighted_average
from module import PyraformerLRModel, PyraformerSSModel
from tools import AE_loss
from module import PyraformerSSModel
from module import PyraformerLRModel
from tools import SingleStepLoss as LossFactory
from tools import AE_loss
# from module import PyraformerModel
class PyraformerLightningModule(pl.LightningModule):
    """Lightning module that trains a Pyraformer model against a loss.

    The wrapped ``model`` must provide ``create_network_inputs``,
    ``output_params``, ``output_distribution`` and ``target_shape``; these
    are the only members this class calls or reads.
    """

    def __init__(
        self,
        model: PyraformerSSModel,
        loss: DistributionLoss = LossFactory,
        lr: float = 1e-5,
        weight_decay: float = 1e-8,
    ) -> None:
        """Store the model, the loss and the Adam hyper-parameters.

        Parameters
        ----------
        model
            Network whose parameters are optimised.
        loss
            Callable invoked as ``loss(distr, future_target)``.
            NOTE(review): the default is ``LossFactory`` itself rather than
            an instantiated loss -- presumably callers always pass a
            ready-made loss object; confirm.
        lr
            Learning rate handed to Adam.
        weight_decay
            Weight decay handed to Adam.
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = model
        self.loss = loss
        self.lr = lr
        self.weight_decay = weight_decay

    def training_step(self, batch, batch_idx: int):
        """Execute training step"""
        train_loss = self(batch)
        self.log(
            "train_loss",
            train_loss,
            on_epoch=True,
            on_step=False,
            prog_bar=True,
        )
        return train_loss

    def validation_step(self, batch, batch_idx: int):
        """Execute validation step"""
        with torch.inference_mode():
            val_loss = self(batch)
        self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        """Returns the optimizer to use"""
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

    def forward(self, batch):
        """Compute the weighted average loss for one batch.

        ``batch`` is indexed with the keys ``feat_static_cat``,
        ``feat_static_real``, ``past_time_feat``, ``past_target``,
        ``future_time_feat``, ``future_target``, ``past_observed_values``
        and ``future_observed_values``.
        """
        feat_static_cat = batch["feat_static_cat"]
        feat_static_real = batch["feat_static_real"]
        past_time_feat = batch["past_time_feat"]
        past_target = batch["past_target"]
        future_time_feat = batch["future_time_feat"]
        future_target = batch["future_target"]
        past_observed_values = batch["past_observed_values"]
        future_observed_values = batch["future_observed_values"]

        pyraformer_inputs, scale, _ = self.model.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
            future_time_feat,
            future_target,
        )
        params = self.model.output_params(pyraformer_inputs)
        distr = self.model.output_distribution(params, scale)

        loss_values = self.loss(distr, future_target)

        if len(self.model.target_shape) == 0:
            loss_weights = future_observed_values
        else:
            # Tensor.min(dim=...) returns a (values, indices) pair; only the
            # values are usable as weights.
            loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)

        return weighted_average(loss_values, weights=loss_weights)
+93 -617
View File
@@ -1,5 +1,4 @@
from typing import List, Optional
import torch
import torch.nn as nn
from gluonts.core.component import validated
@@ -8,45 +7,20 @@ from gluonts.torch.modules.distribution_output import DistributionOutput, Studen
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
from pyraformer.embed import CustomEmbedding, DataEmbedding, SingleStepEmbedding
from pyraformer.Layers import (
AvgPooling_Construct,
Bottleneck_Construct,
Conv_Construct,
Decoder,
EncoderLayer,
MaxPooling_Construct,
Predictor,
get_k_q,
get_mask,
get_q_k,
get_subsequent_mask,
refer_points,
)
from pyraformer.Layers import EncoderLayer, Predictor, Decoder
from pyraformer.Layers import Bottleneck_Construct, Conv_Construct, MaxPooling_Construct, AvgPooling_Construct
from pyraformer.Layers import get_mask, refer_points, get_k_q, get_q_k, get_subsequent_mask
from pyraformer.embed import SingleStepEmbedding, DataEmbedding, CustomEmbedding
class EncoderSS(nn.Module):
"""A encoder model with self attention mechanism."""
def __init__(
self,
covariate_size,
num_seq,
input_size,
dropout,
d_model,
d_inner_hid,
d_k,
d_v,
num_heads,
n_layer,
loss,
window_size,
inner_size,
use_tvm,
prediction_length,
device,
):
@validated()
def __init__(self, covariate_size,
num_seq, input_size ,dropout , d_model,
d_inner_hid, d_k, d_v ,
num_heads , n_layer, loss ,
window_size , inner_size,
use_tvm, prediction_length, device):
super().__init__()
self.d_model = d_model
@@ -56,49 +30,21 @@ class EncoderSS(nn.Module):
self.indexes = refer_points(self.all_size, window_size, device)
if use_tvm:
assert (
len(set(self.window_size)) == 1
), "Only constant window size is supported."
assert len(set(self.window_size)) == 1, "Only constant window size is supported."
q_k_mask = get_q_k(input_size, inner_size, window_size[0], device)
k_q_mask = get_k_q(q_k_mask)
self.layers = nn.ModuleList(
[
EncoderLayer(
d_model,
d_inner_hid,
num_heads,
d_k,
d_v,
dropout=dropout,
normalize_before=False,
use_tvm=True,
q_k_mask=q_k_mask,
k_q_mask=k_q_mask,
)
for i in range(n_layer)
]
)
self.layers = nn.ModuleList([
EncoderLayer(d_model, d_inner_hid, num_heads, d_k, d_v, dropout=dropout, \
normalize_before=False, use_tvm=True, q_k_mask=q_k_mask, k_q_mask=k_q_mask) for i in range(n_layer)
])
else:
self.layers = nn.ModuleList(
[
EncoderLayer(
d_model,
d_inner_hid,
num_heads,
d_k,
d_v,
dropout=dropout,
normalize_before=False,
)
for i in range(n_layer)
]
)
self.embedding = SingleStepEmbedding(
covariate_size, num_seq, d_model, input_size, device
)
self.layers = nn.ModuleList([
EncoderLayer(d_model, d_inner_hid, num_heads, d_k, d_v, dropout=dropout, \
normalize_before=False) for i in range(n_layer)
])
self.embedding = SingleStepEmbedding(covariate_size, num_seq, d_model, input_size, device)
self.conv_layers = Bottleneck_Construct(d_model, window_size, d_k)
def forward(self, sequence):
@@ -111,38 +57,21 @@ class EncoderSS(nn.Module):
for i in range(len(self.layers)):
seq_enc, _ = self.layers[i](seq_enc, mask)
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(
seq_enc.device
)
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
all_enc = torch.gather(seq_enc, 1, indexes)
all_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
return all_enc
class PyraformerSSModel(nn.Module):
@validated()
def __init__(
self,
freq,
covariate_size,
num_seq,
input_size,
dropout,
d_model,
d_inner_hid,
d_k,
d_v,
num_heads,
n_layer,
loss,
window_size,
inner_size,
use_tvm,
prediction_length,
context_length,
lags_seq,
def __init__(self, freq, covariate_size,
num_seq, input_size ,dropout, d_model,
d_inner_hid, d_k, d_v,
num_heads , n_layer, loss ,
window_size , inner_size,
use_tvm, prediction_length,context_length,lags_seq,
num_feat_dynamic_real,
num_feat_static_cat,
num_feat_static_real,
@@ -150,39 +79,25 @@ class PyraformerSSModel(nn.Module):
embedding_dimension,
distr_output,
# loss: DistributionLoss = NegativeLogLikelihood(),
scaling,
num_parallel_samples,
device,
):
scaling,num_parallel_samples,device):
super().__init__()
self.context_length = context_length
self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
self.encoder = EncoderSS(
covariate_size,
num_seq,
input_size,
dropout,
d_model,
d_inner_hid,
d_k,
d_v,
num_heads,
n_layer,
loss,
window_size,
inner_size,
use_tvm,
prediction_length,
device,
)
self.encoder = EncoderSS(covariate_size,
num_seq, input_size ,dropout , d_model,
d_inner_hid, d_k, d_v ,
num_heads , n_layer, loss ,
window_size , inner_size,
use_tvm, prediction_length, device)
# convert hidden vectors into two scalar
# convert hidden vectors into two scalar
self.mean_hidden = Predictor(4 * d_model, 1)
self.var_hidden = Predictor(4 * d_model, 1)
self.softplus = nn.Softplus()
def forward(self, data):
enc_output = self.encoder(data)
@@ -199,11 +114,10 @@ class PyraformerSSModel(nn.Module):
sample_mu = mu[:, -1] * v
sample_sigma = sigma[:, -1] * v
return sample_mu, sample_sigma
@property
def _past_length(self) -> int:
return self.context_length + max(self.lags_seq)
@property
def _number_of_features(self) -> int:
return (
@@ -212,7 +126,6 @@ class PyraformerSSModel(nn.Module):
+ self.num_feat_static_real
+ 1 # the log(scale)
)
def get_lagged_subsequences(
self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
) -> torch.Tensor:
@@ -265,7 +178,7 @@ class PyraformerSSModel(nn.Module):
assert (
features is None or features.shape[2] == self._number_of_features
), f"{features.shape[2]}, expected {self._number_of_features}"
def create_network_inputs(
self,
feat_static_cat: torch.Tensor,
@@ -361,149 +274,70 @@ class PyraformerSSModel(nn.Module):
if trailing_n is not None:
sliced_params = [p[:, -trailing_n:] for p in params]
return self.distr_output.distribution(sliced_params, scale=scale)
class Encoder(nn.Module):
    """An encoder model with self attention mechanism.

    Builds an attention mask with ``get_mask``, a stack of ``EncoderLayer``
    modules, an input embedding and a coarser-scale construction module
    (``CSCM``), and applies them in ``forward``.
    """

    @validated()
    def __init__(
        self,
        model,
        window_size,
        truncate,
        input_size,
        inner_size,
        decoder,
        d_model,
        d_k,
        d_v,
        d_inner_hid,
        dropout,
        n_layer,
        enc_in,
        covariate_size,
        seq_num,
        CSCM,
        d_bottleneck,
        num_head,
        use_tvm,
        device,
        embed_type="DataEmbedding",
    ):
        """Assemble the mask, attention layers, embedding and CSCM.

        Parameters
        ----------
        decoder
            ``"FC"`` or ``"attention"``. For anything other than
            ``"attention"`` the mask is built for ``input_size + 1``
            positions; for ``"FC"`` the gather indexes from ``refer_points``
            are also prepared.
        CSCM
            Name of the construction module: one of
            ``"Bottleneck_Construct"``, ``"Conv_Construct"``,
            ``"MaxPooling_Construct"`` or ``"AvgPooling_Construct"``.
        use_tvm
            When true, layers are built with ``use_tvm=True`` and the
            ``q_k``/``k_q`` masks; requires one constant window size.
        embed_type
            ``"CustomEmbedding"`` selects ``CustomEmbedding``; any other
            value selects ``DataEmbedding``. This replaces a lookup on a
            global ``opt`` object that is not defined in this module.
            NOTE(review): the default value is an assumption -- confirm.

        Raises
        ------
        ValueError
            If ``CSCM`` is not one of the known construction modules.
        """
        super().__init__()

        self.d_model = d_model
        self.model_type = model
        self.window_size = window_size
        self.truncate = truncate
        if decoder == "attention":
            self.mask, self.all_size = get_mask(
                input_size, window_size, inner_size, device
            )
        else:
            self.mask, self.all_size = get_mask(
                input_size + 1, window_size, inner_size, device
            )
        self.decoder_type = decoder
        if decoder == "FC":
            self.indexes = refer_points(self.all_size, window_size, device)

        if use_tvm:
            assert (
                len(set(self.window_size)) == 1
            ), "Only constant window size is supported."
            padding = 1 if decoder == "FC" else 0
            q_k_mask = get_q_k(input_size + padding, inner_size, window_size[0], device)
            k_q_mask = get_k_q(q_k_mask)
            self.layers = nn.ModuleList(
                [
                    EncoderLayer(
                        d_model,
                        d_inner_hid,
                        num_head,
                        d_k,
                        d_v,
                        dropout=dropout,
                        normalize_before=False,
                        use_tvm=True,
                        q_k_mask=q_k_mask,
                        k_q_mask=k_q_mask,
                    )
                    for _ in range(n_layer)
                ]
            )
        else:
            self.layers = nn.ModuleList(
                [
                    EncoderLayer(
                        d_model,
                        d_inner_hid,
                        num_head,
                        d_k,
                        d_v,
                        dropout=dropout,
                        normalize_before=False,
                    )
                    for _ in range(n_layer)
                ]
            )

        if embed_type == "CustomEmbedding":
            self.enc_embedding = CustomEmbedding(
                enc_in, d_model, covariate_size, seq_num, dropout
            )
        else:
            self.enc_embedding = DataEmbedding(enc_in, d_model, dropout)

        # Resolve the construction module by name from a fixed table rather
        # than eval(), so an arbitrary string can never be executed.
        cscm_constructs = {
            "Bottleneck_Construct": Bottleneck_Construct,
            "Conv_Construct": Conv_Construct,
            "MaxPooling_Construct": MaxPooling_Construct,
            "AvgPooling_Construct": AvgPooling_Construct,
        }
        try:
            construct_cls = cscm_constructs[CSCM]
        except KeyError:
            raise ValueError(
                f"Unknown CSCM {CSCM!r}; expected one of {sorted(cscm_constructs)}"
            ) from None
        self.conv_layers = construct_cls(d_model, window_size, d_bottleneck)

    def forward(self, x_enc, x_mark_enc):
        """Embed the inputs, run the CSCM and the attention layers.

        With the ``"FC"`` decoder the output is gathered through
        ``self.indexes`` and reshaped to ``self.all_size[0]`` positions.
        With the ``"attention"`` decoder and ``truncate`` set, only the
        first ``self.all_size[0]`` positions are returned.
        """
        seq_enc = self.enc_embedding(x_enc, x_mark_enc)

        mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)
        seq_enc = self.conv_layers(seq_enc)

        for layer in self.layers:
            seq_enc, _ = layer(seq_enc, mask)

        if self.decoder_type == "FC":
            indexes = self.indexes.repeat(
                seq_enc.size(0), 1, 1, seq_enc.size(2)
            ).to(seq_enc.device)
            indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
            all_enc = torch.gather(seq_enc, 1, indexes)
            seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
        elif self.decoder_type == "attention" and self.truncate:
            seq_enc = seq_enc[:, : self.all_size[0]]

        return seq_enc
class PyraformerLRModel(nn.Module):
@validated()
def __init__(
self,
predict_step,
d_model,
input_size,
decoder,
window_size,
truncate,
model,
d_inner_hid,
num_head,
d_k,
d_v,
dropout,
enc_in,
covariate_size,
seq_num,
CSCM,
d_bottleneck,
num_head,
use_tvm,
device,
):
def __init__(self, predict_step, d_model, input_size, decoder, window_size, truncate, model,d_inner_hid,d_k,d_v,dropout,enc_in,covariate_size,seq_num,CSCM,d_bottleneck,num_head,use_tvm,device):
super().__init__()
self.predict_step = predict_step
@@ -511,46 +345,13 @@ class PyraformerLRModel(nn.Module):
self.input_size = input_size
self.decoder_type = decoder
self.channels = enc_in
self.encoder = Encoder(
model,
window_size,
truncate,
input_size,
inner_size,
decoder,
d_model,
d_k,
d_v,
d_inner_hid,
dropout,
n_layer,
enc_in,
covariate_size,
seq_num,
CSCM,
d_bottleneck,
num_head,
use_tvm,
device,
)
if decoder == "attention":
self.encoder = Encoder(model, window_size, truncate, input_size, inner_size,decoder,d_model, d_k,d_v,d_inner_hid, dropout,n_layer, enc_in,covariate_size,seq_num,CSCM,d_bottleneck,num_head,use_tvm,device)
if decoder == 'attention':
mask = get_subsequent_mask(input_size, window_size, predict_step, truncate)
self.decoder = Decoder(
model,
d_model,
d_inner_hid,
num_head,
d_k,
d_v,
dropout,
enc_in,
covariate_size,
seq_num,
mask,
)
self.decoder = Decoder(model,d_model,d_inner_hid,num_head,d_k,d_v,dropout,enc_in,covariate_size,seq_num, mask)
self.predictor = Predictor(d_model, enc_in)
elif opt.decoder == "FC":
elif opt.decoder == 'FC':
self.predictor = Predictor(4 * d_model, predict_step * enc_in)
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, pretrain):
@@ -563,343 +364,18 @@ class PyraformerLRModel(nn.Module):
type_prediction: batch*seq_len*num_classes (not normalized);
time_prediction: batch*seq_len.
"""
if self.decoder_type == "attention":
if self.decoder_type == 'attention':
enc_output = self.encoder(x_enc, x_mark_enc)
dec_enc = self.decoder(x_dec, x_mark_dec, enc_output)
if pretrain:
dec_enc = torch.cat([enc_output[:, : self.input_size], dec_enc], dim=1)
dec_enc = torch.cat([enc_output[:, :self.input_size], dec_enc], dim=1)
pred = self.predictor(dec_enc)
else:
pred = self.predictor(dec_enc)
elif self.decoder_type == "FC":
elif self.decoder_type == 'FC':
enc_output = self.encoder(x_enc, x_mark_enc)[:, -1, :]
pred = self.predictor(enc_output).view(
enc_output.size(0), self.predict_step, -1
)
pred = self.predictor(enc_output).view(enc_output.size(0), self.predict_step, -1)
return pred
class TransformerModel(nn.Module):
    """Encoder-decoder ``nn.Transformer`` that emits distribution parameters.

    Network inputs are lagged, scaled target values concatenated with static
    and time features. The decoder output is projected by
    ``distr_output.get_args_proj`` to the arguments of the output
    distribution.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        context_length: int,
        prediction_length: int,
        num_feat_dynamic_real: int,
        num_feat_static_real: int,
        num_feat_static_cat: int,
        cardinality: List[int],
        # transformer arguments
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        activation: str = "gelu",
        dropout: float = 0.1,
        # univariate input
        input_size: int = 1,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        lags_seq: Optional[List[int]] = None,
        scaling: bool = True,
        num_parallel_samples: int = 100,
    ) -> None:
        """Build the embedder, scaler, transformer and parameter projection.

        When ``embedding_dimension`` is ``None`` and ``cardinality`` is given,
        each embedding size defaults to ``min(50, (cat + 1) // 2)``. When
        ``lags_seq`` is empty or ``None``, the lags are taken from
        ``get_lags_for_frequency(freq_str=freq)``. ``scaling`` selects
        ``MeanScaler`` over ``NOPScaler``.
        """
        super().__init__()

        self.input_size = input_size

        self.target_shape = distr_output.event_shape
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None or cardinality is None
            else [min(50, (cat + 1) // 2) for cat in cardinality]
        )
        self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
        self.num_parallel_samples = num_parallel_samples
        self.history_length = context_length + max(self.lags_seq)
        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=self.embedding_dimension,
        )
        if scaling:
            self.scaler = MeanScaler(dim=1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=1, keepdim=True)

        # total feature size
        d_model = self.input_size * len(self.lags_seq) + self._number_of_features

        self.context_length = context_length
        self.prediction_length = prediction_length
        self.distr_output = distr_output
        self.param_proj = distr_output.get_args_proj(d_model)

        # transformer enc-decoder and mask initializer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )

        # causal decoder tgt mask
        self.register_buffer(
            "tgt_mask",
            self.transformer.generate_square_subsequent_mask(prediction_length),
        )

    @property
    def _number_of_features(self) -> int:
        """Width of the feature vector appended to the lagged targets."""
        return (
            sum(self.embedding_dimension)
            + self.num_feat_dynamic_real
            + self.num_feat_static_real
            + 1  # the log(scale)
        )

    @property
    def _past_length(self) -> int:
        """Number of past steps required: context plus the largest lag."""
        return self.context_length + max(self.lags_seq)

    def get_lagged_subsequences(
        self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.

        Parameters
        ----------
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        subsequences_length : int
            length of the subsequences to be extracted.
        shift: int
            shift the lags by this amount back.

        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I), where S = subsequences_length and
            I = len(indices), containing lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
        """
        sequence_length = sequence.shape[1]
        indices = [lag - shift for lag in self.lags_seq]

        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            # A zero lag must slice to the end; ``-0`` would give an empty slice.
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...])
        return torch.stack(lagged_values, dim=-1)

    def _check_shapes(
        self,
        prior_input: torch.Tensor,
        inputs: torch.Tensor,
        features: Optional[torch.Tensor],
    ) -> None:
        """Assert that the inputs and features agree with the configured sizes."""
        assert len(prior_input.shape) == len(inputs.shape)
        assert (
            len(prior_input.shape) == 2 and self.input_size == 1
        ) or prior_input.shape[2] == self.input_size
        assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
            -1
        ] == self.input_size
        assert (
            features is None or features.shape[2] == self._number_of_features
        ), f"{features.shape[2]}, expected {self._number_of_features}"

    def create_network_inputs(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
    ):
        """Assemble the transformer input tensor.

        When ``future_target`` is given, the future time features and targets
        are appended to the past ones; otherwise only the context window is
        covered.

        Returns
        -------
        tuple
            ``(transformer_inputs, scale, static_feat)``: the lagged targets
            concatenated with the features, the scale produced by the
            scaler, and the static feature vector (embedded categoricals,
            real static features and ``log(scale)``).
        """
        # time feature
        time_feat = (
            torch.cat(
                (
                    past_time_feat[:, self._past_length - self.context_length :, ...],
                    future_time_feat,
                ),
                dim=1,
            )
            if future_target is not None
            else past_time_feat[:, self._past_length - self.context_length :, ...]
        )

        # target
        context = past_target[:, -self.context_length :]
        observed_context = past_observed_values[:, -self.context_length :]
        _, scale = self.scaler(context, observed_context)

        inputs = (
            torch.cat((past_target, future_target), dim=1) / scale
            if future_target is not None
            else past_target / scale
        )

        inputs_length = (
            self._past_length + self.prediction_length
            if future_target is not None
            else self._past_length
        )
        assert inputs.shape[1] == inputs_length

        subsequences_length = (
            self.context_length + self.prediction_length
            if future_target is not None
            else self.context_length
        )

        # embeddings
        embedded_cat = self.embedder(feat_static_cat)
        static_feat = torch.cat(
            (embedded_cat, feat_static_real, scale.log()),
            dim=1,
        )
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, time_feat.shape[1], -1
        )

        features = torch.cat((expanded_static_feat, time_feat), dim=-1)

        # self._check_shapes(prior_input, inputs, features)

        # sequence = torch.cat((prior_input, inputs), dim=1)
        lagged_sequence = self.get_lagged_subsequences(
            sequence=inputs,
            subsequences_length=subsequences_length,
        )

        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)

        return transformer_inputs, scale, static_feat

    def output_params(self, transformer_inputs):
        """Run encoder and masked decoder, then project to distribution args.

        The first ``context_length`` steps feed the encoder; the remaining
        steps feed the decoder together with ``self.tgt_mask``.
        """
        enc_input = transformer_inputs[:, : self.context_length, ...]
        dec_input = transformer_inputs[:, self.context_length :, ...]

        enc_out = self.transformer.encoder(enc_input)
        dec_output = self.transformer.decoder(
            dec_input, enc_out, tgt_mask=self.tgt_mask
        )

        return self.param_proj(dec_output)

    @torch.jit.ignore
    def output_distribution(
        self, params, scale=None, trailing_n=None
    ) -> torch.distributions.Distribution:
        """Build the output distribution, optionally from the last ``trailing_n`` steps."""
        sliced_params = params
        if trailing_n is not None:
            sliced_params = [p[:, -trailing_n:] for p in params]
        return self.distr_output.distribution(sliced_params, scale=scale)

    # for prediction
    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
        num_parallel_samples: Optional[int] = None,
    ) -> torch.Tensor:
        """Draw sample paths for the prediction range, one step at a time.

        Parameters
        ----------
        num_parallel_samples
            Number of sample paths per series. ``None`` falls back to the
            value given at construction time.

        Returns
        -------
        Tensor
            Samples reshaped to
            ``(-1, num_parallel_samples, prediction_length) + target_shape``.
        """
        if num_parallel_samples is None:
            num_parallel_samples = self.num_parallel_samples

        encoder_inputs, scale, static_feat = self.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
        )

        enc_out = self.transformer.encoder(encoder_inputs)

        # Every repeat below uses the resolved ``num_parallel_samples`` so
        # that a per-call override is honoured and not silently ignored.
        repeated_scale = scale.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        repeated_past_target = (
            past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
            / repeated_scale
        )

        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, future_time_feat.shape[1], -1
        )
        features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
        repeated_features = features.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        repeated_enc_out = enc_out.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        future_samples = []

        # step-by-step decoding: each sampled value is fed back as history
        for k in range(self.prediction_length):
            # self._check_shapes(repeated_past_target, next_sample, next_features)
            # sequence = torch.cat((repeated_past_target, next_sample), dim=1)

            lagged_sequence = self.get_lagged_subsequences(
                sequence=repeated_past_target,
                subsequences_length=1 + k,
                shift=1,
            )

            lags_shape = lagged_sequence.shape
            reshaped_lagged_sequence = lagged_sequence.reshape(
                lags_shape[0], lags_shape[1], -1
            )

            decoder_input = torch.cat(
                (reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1
            )

            output = self.transformer.decoder(decoder_input, repeated_enc_out)

            # Only the newest step is projected and sampled.
            params = self.param_proj(output[:, -1:])
            distr = self.output_distribution(params, scale=repeated_scale)
            next_sample = distr.sample()

            repeated_past_target = torch.cat(
                (repeated_past_target, next_sample / repeated_scale), dim=1
            )
            future_samples.append(next_sample)

        concat_future_samples = torch.cat(future_samples, dim=1)
        return concat_future_samples.reshape(
            (-1, num_parallel_samples, self.prediction_length) + self.target_shape,
        )
File diff suppressed because one or more lines are too long
BIN
View File
Binary file not shown.
+96 -209
View File
@@ -1,12 +1,12 @@
from torch.functional import align_tensors
import torch.nn as nn
from torch.nn.modules.linear import Linear
from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
import torch
from .embed import DataEmbedding, CustomEmbedding
import math
import torch
import torch.nn as nn
from torch.functional import align_tensors
from torch.nn.modules.linear import Linear
from .embed import CustomEmbedding, DataEmbedding
from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
def get_mask(input_size, window_size, inner_size, device):
@@ -34,15 +34,11 @@ def get_mask(input_size, window_size, inner_size, device):
for layer_idx in range(1, len(all_size)):
start = sum(all_size[:layer_idx])
for i in range(start, start + all_size[layer_idx]):
left_side = (start - all_size[layer_idx - 1]) + (i - start) * window_size[
layer_idx - 1
]
if i == (start + all_size[layer_idx] - 1):
left_side = (start - all_size[layer_idx - 1]) + (i - start) * window_size[layer_idx - 1]
if i == ( start + all_size[layer_idx] - 1):
right_side = start
else:
right_side = (start - all_size[layer_idx - 1]) + (
i - start + 1
) * window_size[layer_idx - 1]
right_side = (start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1]
mask[i, left_side:right_side] = 1
mask[left_side:right_side, i] = 1
@@ -62,9 +58,7 @@ def refer_points(all_sizes, window_size, device):
for j in range(1, len(all_sizes)):
start = sum(all_sizes[:j])
inner_layer_idx = former_index - (start - all_sizes[j - 1])
former_index = start + min(
inner_layer_idx // window_size[j - 1], all_sizes[j] - 1
)
former_index = start + min(inner_layer_idx // window_size[j - 1], all_sizes[j] - 1)
indexes[i][j] = former_index
indexes = indexes.unsqueeze(0).unsqueeze(3)
@@ -77,7 +71,7 @@ def get_subsequent_mask(input_size, window_size, predict_step, truncate):
if truncate:
mask = torch.zeros(predict_step, input_size + predict_step)
for i in range(predict_step):
mask[i][: input_size + i + 1] = 1
mask[i][:input_size+i+1] = 1
mask = (1 - mask).bool().unsqueeze(0)
else:
all_size = []
@@ -88,7 +82,7 @@ def get_subsequent_mask(input_size, window_size, predict_step, truncate):
all_size = sum(all_size)
mask = torch.zeros(predict_step, all_size + predict_step)
for i in range(predict_step):
mask[i][: all_size + i + 1] = 1
mask[i][:all_size+i+1] = 1
mask = (1 - mask).bool().unsqueeze(0)
return mask
@@ -120,56 +114,38 @@ def get_q_k(input_size, window_size, stride, device):
mask[i, -1] = i // stride + input_size
mask[i][mask[i] > third_start - 1] = third_start - 1
for i in range(second_length):
mask[input_size + i, 0:window_size] = (
input_size + i + torch.arange(window_size) - window_size // 2
)
mask[input_size + i, mask[input_size + i] < input_size] = -1
mask[input_size + i, mask[input_size + i] > third_start - 1] = -1
mask[input_size+i, 0:window_size] = input_size + i + torch.arange(window_size) - window_size // 2
mask[input_size+i, mask[input_size+i] < input_size] = -1
mask[input_size+i, mask[input_size+i] > third_start - 1] = -1
if i < second_length - 1:
mask[input_size + i, window_size : (window_size + stride)] = (
torch.arange(stride) + i * stride
)
mask[input_size+i, window_size:(window_size+stride)] = torch.arange(stride) + i * stride
else:
mask[input_size + i, window_size : (window_size + second_last)] = (
torch.arange(second_last) + i * stride
)
mask[input_size+i, window_size:(window_size+second_last)] = torch.arange(second_last) + i * stride
mask[input_size + i, -1] = i // stride + third_start
mask[input_size + i, mask[input_size + i] > fourth_start - 1] = fourth_start - 1
mask[input_size+i, -1] = i // stride + third_start
mask[input_size+i, mask[input_size+i] > fourth_start - 1] = fourth_start - 1
for i in range(third_length):
mask[third_start + i, 0:window_size] = (
third_start + i + torch.arange(window_size) - window_size // 2
)
mask[third_start + i, mask[third_start + i] < third_start] = -1
mask[third_start + i, mask[third_start + i] > fourth_start - 1] = -1
mask[third_start+i, 0:window_size] = third_start + i + torch.arange(window_size) - window_size // 2
mask[third_start+i, mask[third_start+i] < third_start] = -1
mask[third_start+i, mask[third_start+i] > fourth_start - 1] = -1
if i < third_length - 1:
mask[third_start + i, window_size : (window_size + stride)] = (
input_size + torch.arange(stride) + i * stride
)
mask[third_start+i, window_size:(window_size+stride)] = input_size + torch.arange(stride) + i * stride
else:
mask[third_start + i, window_size : (window_size + third_last)] = (
input_size + torch.arange(third_last) + i * stride
)
mask[third_start+i, window_size:(window_size+third_last)] = input_size + torch.arange(third_last) + i * stride
mask[third_start + i, -1] = i // stride + fourth_start
mask[third_start + i, mask[third_start + i] > full_length - 1] = full_length - 1
mask[third_start+i, -1] = i // stride + fourth_start
mask[third_start+i, mask[third_start+i] > full_length - 1] = full_length - 1
for i in range(fourth_length):
mask[fourth_start + i, 0:window_size] = (
fourth_start + i + torch.arange(window_size) - window_size // 2
)
mask[fourth_start + i, mask[fourth_start + i] < fourth_start] = -1
mask[fourth_start + i, mask[fourth_start + i] > full_length - 1] = -1
mask[fourth_start+i, 0:window_size] = fourth_start + i + torch.arange(window_size) - window_size // 2
mask[fourth_start+i, mask[fourth_start+i] < fourth_start] = -1
mask[fourth_start+i, mask[fourth_start+i] > full_length - 1] = -1
if i < fourth_length - 1:
mask[fourth_start + i, window_size : (window_size + stride)] = (
third_start + torch.arange(stride) + i * stride
)
mask[fourth_start+i, window_size:(window_size+stride)] = third_start + torch.arange(stride) + i * stride
else:
mask[fourth_start + i, window_size : (window_size + fourth_last)] = (
third_start + torch.arange(fourth_last) + i * stride
)
mask[fourth_start+i, window_size:(window_size+fourth_last)] = third_start + torch.arange(fourth_last) + i * stride
return mask
@@ -182,64 +158,32 @@ def get_k_q(q_k_mask):
for i in range(len(q_k_mask)):
for j in range(len(q_k_mask[0])):
if q_k_mask[i, j] >= 0:
k_q_mask[i, j] = torch.where(q_k_mask[q_k_mask[i, j]] == i)[0]
k_q_mask[i, j] = torch.where(q_k_mask[q_k_mask[i, j]] ==i )[0]
return k_q_mask
class EncoderLayer(nn.Module):
"""Compose with two layers"""
""" Compose with two layers """
def __init__(
self,
d_model,
d_inner,
n_head,
d_k,
d_v,
dropout=0.1,
normalize_before=True,
use_tvm=False,
q_k_mask=None,
k_q_mask=None,
):
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1, normalize_before=True, use_tvm=False, q_k_mask=None, k_q_mask=None):
super(EncoderLayer, self).__init__()
self.use_tvm = use_tvm
if use_tvm:
from .PAM_TVM import PyramidalAttention
self.slf_attn = PyramidalAttention(
n_head,
d_model,
d_k,
d_v,
dropout=dropout,
normalize_before=normalize_before,
q_k_mask=q_k_mask,
k_q_mask=k_q_mask,
)
self.slf_attn = PyramidalAttention(n_head, d_model, d_k, d_v, dropout=dropout, normalize_before=normalize_before, q_k_mask=q_k_mask, k_q_mask=k_q_mask)
else:
self.slf_attn = MultiHeadAttention(
n_head,
d_model,
d_k,
d_v,
dropout=dropout,
normalize_before=normalize_before,
)
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout, normalize_before=normalize_before)
self.pos_ffn = PositionwiseFeedForward(
d_model, d_inner, dropout=dropout, normalize_before=normalize_before
)
d_model, d_inner, dropout=dropout, normalize_before=normalize_before)
def forward(self, enc_input, slf_attn_mask=None):
if self.use_tvm:
enc_output = self.slf_attn(enc_input)
enc_slf_attn = None
else:
enc_output, enc_slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask
)
enc_output, enc_slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask)
enc_output = self.pos_ffn(enc_output)
@@ -247,26 +191,18 @@ class EncoderLayer(nn.Module):
class DecoderLayer(nn.Module):
"""Compose with two layers"""
""" Compose with two layers """
def __init__(
self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1, normalize_before=True
):
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1, normalize_before=True):
super(DecoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(
n_head,
d_model,
d_k,
d_v,
dropout=dropout,
normalize_before=normalize_before,
)
n_head, d_model, d_k, d_v, dropout=dropout, normalize_before=normalize_before)
self.pos_ffn = PositionwiseFeedForward(
d_model, d_inner, dropout=dropout, normalize_before=normalize_before
)
d_model, d_inner, dropout=dropout, normalize_before=normalize_before)
def forward(self, Q, K, V, slf_attn_mask=None):
enc_output, enc_slf_attn = self.slf_attn(Q, K, V, mask=slf_attn_mask)
enc_output, enc_slf_attn = self.slf_attn(
Q, K, V, mask=slf_attn_mask)
enc_output = self.pos_ffn(enc_output)
@@ -276,12 +212,10 @@ class DecoderLayer(nn.Module):
class ConvLayer(nn.Module):
def __init__(self, c_in, window_size):
super(ConvLayer, self).__init__()
self.downConv = nn.Conv1d(
in_channels=c_in,
out_channels=c_in,
kernel_size=window_size,
stride=window_size,
)
self.downConv = nn.Conv1d(in_channels=c_in,
out_channels=c_in,
kernel_size=window_size,
stride=window_size)
self.norm = nn.BatchNorm1d(c_in)
self.activation = nn.ELU()
@@ -294,25 +228,20 @@ class ConvLayer(nn.Module):
class Conv_Construct(nn.Module):
"""Convolution CSCM"""
def __init__(self, d_model, window_size, d_inner):
super(Conv_Construct, self).__init__()
if not isinstance(window_size, list):
self.conv_layers = nn.ModuleList(
[
ConvLayer(d_model, window_size),
ConvLayer(d_model, window_size),
ConvLayer(d_model, window_size),
]
)
self.conv_layers = nn.ModuleList([
ConvLayer(d_model, window_size),
ConvLayer(d_model, window_size),
ConvLayer(d_model, window_size)
])
else:
self.conv_layers = nn.ModuleList(
[
ConvLayer(d_model, window_size[0]),
ConvLayer(d_model, window_size[1]),
ConvLayer(d_model, window_size[2]),
]
)
self.conv_layers = nn.ModuleList([
ConvLayer(d_model, window_size[0]),
ConvLayer(d_model, window_size[1]),
ConvLayer(d_model, window_size[2])
])
self.norm = nn.LayerNorm(d_model)
def forward(self, enc_input):
@@ -332,17 +261,14 @@ class Conv_Construct(nn.Module):
class Bottleneck_Construct(nn.Module):
"""Bottleneck convolution CSCM"""
def __init__(self, d_model, window_size, d_inner):
super(Bottleneck_Construct, self).__init__()
if not isinstance(window_size, list):
self.conv_layers = nn.ModuleList(
[
ConvLayer(d_inner, window_size),
ConvLayer(d_inner, window_size),
ConvLayer(d_inner, window_size),
]
)
self.conv_layers = nn.ModuleList([
ConvLayer(d_inner, window_size),
ConvLayer(d_inner, window_size),
ConvLayer(d_inner, window_size)
])
else:
self.conv_layers = []
for i in range(len(window_size)):
@@ -371,25 +297,20 @@ class Bottleneck_Construct(nn.Module):
class MaxPooling_Construct(nn.Module):
"""Max pooling CSCM"""
def __init__(self, d_model, window_size, d_inner):
super(MaxPooling_Construct, self).__init__()
if not isinstance(window_size, list):
self.pooling_layers = nn.ModuleList(
[
nn.MaxPool1d(kernel_size=window_size),
nn.MaxPool1d(kernel_size=window_size),
nn.MaxPool1d(kernel_size=window_size),
]
)
self.pooling_layers = nn.ModuleList([
nn.MaxPool1d(kernel_size=window_size),
nn.MaxPool1d(kernel_size=window_size),
nn.MaxPool1d(kernel_size=window_size)
])
else:
self.pooling_layers = nn.ModuleList(
[
nn.MaxPool1d(kernel_size=window_size[0]),
nn.MaxPool1d(kernel_size=window_size[1]),
nn.MaxPool1d(kernel_size=window_size[2]),
]
)
self.pooling_layers = nn.ModuleList([
nn.MaxPool1d(kernel_size=window_size[0]),
nn.MaxPool1d(kernel_size=window_size[1]),
nn.MaxPool1d(kernel_size=window_size[2])
])
self.norm = nn.LayerNorm(d_model)
def forward(self, enc_input):
@@ -409,25 +330,20 @@ class MaxPooling_Construct(nn.Module):
class AvgPooling_Construct(nn.Module):
"""Average pooling CSCM"""
def __init__(self, d_model, window_size, d_inner):
super(AvgPooling_Construct, self).__init__()
if not isinstance(window_size, list):
self.pooling_layers = nn.ModuleList(
[
nn.AvgPool1d(kernel_size=window_size),
nn.AvgPool1d(kernel_size=window_size),
nn.AvgPool1d(kernel_size=window_size),
]
)
self.pooling_layers = nn.ModuleList([
nn.AvgPool1d(kernel_size=window_size),
nn.AvgPool1d(kernel_size=window_size),
nn.AvgPool1d(kernel_size=window_size)
])
else:
self.pooling_layers = nn.ModuleList(
[
nn.AvgPool1d(kernel_size=window_size[0]),
nn.AvgPool1d(kernel_size=window_size[1]),
nn.AvgPool1d(kernel_size=window_size[2]),
]
)
self.pooling_layers = nn.ModuleList([
nn.AvgPool1d(kernel_size=window_size[0]),
nn.AvgPool1d(kernel_size=window_size[1]),
nn.AvgPool1d(kernel_size=window_size[2])
])
self.norm = nn.LayerNorm(d_model)
def forward(self, enc_input):
@@ -446,6 +362,7 @@ class AvgPooling_Construct(nn.Module):
class Predictor(nn.Module):
def __init__(self, dim, num_types):
super().__init__()
@@ -459,54 +376,23 @@ class Predictor(nn.Module):
class Decoder(nn.Module):
"""A encoder model with self attention mechanism."""
""" A encoder model with self attention mechanism. """
def __init__(
self,
model,
d_model,
d_inner_hid,
num_head,
d_k,
d_v,
dropout,
enc_in,
covariate_size,
seq_num,
mask,
):
def __init__(self, model,d_model,d_inner_hid,num_head,d_k,d_v,dropout,enc_in,covariate_size,seq_num, mask):
super().__init__()
self.model_type = model
self.mask = mask
self.layers = nn.ModuleList(
[
DecoderLayer(
d_model,
d_inner_hid,
num_head,
d_k,
d_v,
dropout=dropout,
normalize_before=False,
),
DecoderLayer(
d_model,
d_inner_hid,
num_head,
d_k,
d_v,
dropout=dropout,
normalize_before=False,
),
]
)
self.layers = nn.ModuleList([
DecoderLayer(d_model, d_inner_hid, num_head, d_k, d_v, dropout=dropout, \
normalize_before=False),
DecoderLayer(d_model, d_inner_hid, num_head, d_k, d_v, dropout=dropout, \
normalize_before=False)
])
if opt.embed_type == "CustomEmbedding":
self.dec_embedding = CustomEmbedding(
enc_in, d_model, covariate_size, seq_num, dropout
)
if opt.embed_type == 'CustomEmbedding':
self.dec_embedding = CustomEmbedding(enc_in, d_model, covariate_size, seq_num, dropout)
else:
self.dec_embedding = DataEmbedding(enc_in, d_model, dropout)
@@ -519,3 +405,4 @@ class Decoder(nn.Module):
dec_enc, _ = self.layers[1](dec_enc, refer_enc, refer_enc, slf_attn_mask=mask)
return dec_enc
+2 -1
View File
@@ -4,7 +4,7 @@ import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
"""Scaled Dot-Product Attention"""
""" Scaled Dot-Product Attention """
def __init__(self, temperature, attn_dropout=0.2):
super().__init__()
@@ -22,3 +22,4 @@ class ScaledDotProductAttention(nn.Module):
output = torch.matmul(attn, v)
return output, attn
+4 -7
View File
@@ -1,15 +1,11 @@
import math
import torch.nn as nn
import torch.nn.functional as F
import math
from .hierarchical_mm_tvm import graph_mm as graph_mm_tvm
class PyramidalAttention(nn.Module):
def __init__(
self, n_head, d_model, d_k, d_v, dropout, normalize_before, q_k_mask, k_q_mask
):
def __init__(self, n_head, d_model, d_k, d_v, dropout, normalize_before, q_k_mask, k_q_mask):
super(PyramidalAttention, self).__init__()
self.normalize_before = normalize_before
self.n_head = n_head
@@ -35,7 +31,7 @@ class PyramidalAttention(nn.Module):
residual = hidden_states
hidden_states = hidden_states
bsz, seq_len, _ = hidden_states.size()
bsz, seq_len, _ = hidden_states.size()
q = hidden_states
if self.normalize_before:
@@ -66,3 +62,4 @@ class PyramidalAttention(nn.Module):
context = self.layer_norm(context)
return context
+5 -6
View File
@@ -5,7 +5,7 @@ from .Modules import ScaledDotProductAttention
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention module"""
""" Multi-Head Attention module """
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, normalize_before=True):
super().__init__()
@@ -25,9 +25,7 @@ class MultiHeadAttention(nn.Module):
self.fc = nn.Linear(d_v * n_head, d_model)
nn.init.xavier_uniform_(self.fc.weight)
self.attention = ScaledDotProductAttention(
temperature=d_k**0.5, attn_dropout=dropout
)
self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5, attn_dropout=dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.dropout = nn.Dropout(dropout)
@@ -67,7 +65,7 @@ class MultiHeadAttention(nn.Module):
class PositionwiseFeedForward(nn.Module):
"""Two-layer position-wise feed-forward neural network."""
""" Two-layer position-wise feed-forward neural network. """
def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True):
super().__init__()
@@ -78,7 +76,7 @@ class PositionwiseFeedForward(nn.Module):
self.w_2 = nn.Linear(d_hid, d_in)
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
# self.layer_norm = GraphNorm(d_in)
#self.layer_norm = GraphNorm(d_in)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
@@ -95,3 +93,4 @@ class PositionwiseFeedForward(nn.Module):
if not self.normalize_before:
x = self.layer_norm(x)
return x
+22 -68
View File
@@ -9,11 +9,11 @@ Modified based on Informer.
}
"""
import math
import torch
import torch.nn as nn
import math
class PositionalEmbedding(nn.Module):
def __init__(self, d_model, max_len=5000):
@@ -23,42 +23,31 @@ class PositionalEmbedding(nn.Module):
pe.require_grad = False
position = torch.arange(0, max_len).float().unsqueeze(1)
div_term = (
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
).exp()
div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
self.register_buffer('pe', pe)
def forward(self, x):
    """Return the slice of the stored ``pe`` buffer covering ``x.size(1)`` positions.

    Only the size of ``x`` along dimension 1 is read; the values of ``x``
    are not used.
    """
    seq_len = x.size(1)
    return self.pe[:, :seq_len]
class TokenEmbedding(nn.Module):
    """Map ``c_in`` input channels to ``d_model`` channels with a circular Conv1d.

    Args:
        c_in: number of input channels (size of the last dimension of the
            tensor passed to ``forward``).
        d_model: number of output channels of the convolution.
    """

    def __init__(self, c_in, d_model):
        super().__init__()
        # Padding is 1 when the torch version string compares >= "1.5.0", else 2.
        if torch.__version__ >= "1.5.0":
            pad = 1
        else:
            pad = 2
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=pad,
            padding_mode="circular",
        )
        # Re-initialise the weights of every Conv1d registered on this module.
        conv_layers = [mod for mod in self.modules() if isinstance(mod, nn.Conv1d)]
        for conv in conv_layers:
            nn.init.kaiming_normal_(
                conv.weight, mode="fan_in", nonlinearity="leaky_relu"
            )

    def forward(self, x):
        """Convolve ``x`` over dimension 1 and return it in the input layout.

        The last two dimensions are swapped before the convolution so the
        ``c_in`` axis is in Conv1d's channel position, then swapped back.
        """
        channels_first = x.permute(0, 2, 1)
        embedded = self.tokenConv(channels_first)
        return embedded.transpose(1, 2)
class FixedEmbedding(nn.Module):
def __init__(self, c_in, d_model):
super(FixedEmbedding, self).__init__()
@@ -67,9 +56,7 @@ class FixedEmbedding(nn.Module):
w.require_grad = False
position = torch.arange(0, c_in).float().unsqueeze(1)
div_term = (
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
).exp()
div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
w[:, 0::2] = torch.sin(position * div_term)
w[:, 1::2] = torch.cos(position * div_term)
@@ -80,21 +67,17 @@ class FixedEmbedding(nn.Module):
def forward(self, x):
    """Look ``x`` up in ``self.emb`` and return the result detached from autograd."""
    looked_up = self.emb(x)
    return looked_up.detach()
class TimeFeatureEmbedding(nn.Module):
    """Linear projection of per-step time features to ``d_model`` dimensions.

    Args:
        d_model: output size of the projection.
        d_inp: size of the last dimension of the input passed to ``forward``.
            Defaults to 4, the value that was previously hard-coded, so
            existing ``TimeFeatureEmbedding(d_model)`` calls are unaffected.
    """

    def __init__(self, d_model, d_inp=4):
        super(TimeFeatureEmbedding, self).__init__()
        self.embed = nn.Linear(d_inp, d_model)

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        return self.embed(x)
"""Embedding modules. The DataEmbedding is used by the ETT dataset for long range forecasting."""
class DataEmbedding(nn.Module):
def __init__(self, c_in, d_model, dropout=0.1):
super(DataEmbedding, self).__init__()
@@ -106,18 +89,11 @@ class DataEmbedding(nn.Module):
self.dropout = nn.Dropout(p=dropout)
def forward(self, x, x_mark):
    """Sum the value, positional and temporal embeddings, then apply dropout.

    ``x`` feeds the value and positional embeddings; ``x_mark`` feeds the
    temporal embedding.
    """
    value_part = self.value_embedding(x)
    position_part = self.position_embedding(x)
    temporal_part = self.temporal_embedding(x_mark)
    combined = value_part + position_part + temporal_part
    return self.dropout(combined)
"""The CustomEmbedding is used by the electricity dataset and app flow dataset for long range forecasting."""
class CustomEmbedding(nn.Module):
def __init__(self, c_in, d_model, temporal_size, seq_num, dropout=0.1):
super(CustomEmbedding, self).__init__()
@@ -130,46 +106,28 @@ class CustomEmbedding(nn.Module):
self.dropout = nn.Dropout(p=dropout)
def forward(self, x, x_mark):
    """Sum value, positional, temporal and sequence-id embeddings, then apply dropout.

    The last channel of ``x_mark`` is cast to ``long`` and fed to the
    sequence-id embedding; all preceding channels feed the temporal embedding.
    """
    value_part = self.value_embedding(x)
    position_part = self.position_embedding(x)
    temporal_part = self.temporal_embedding(x_mark[:, :, :-1])
    seq_ids = x_mark[:, :, -1].long()
    seqid_part = self.seqid_embedding(seq_ids)
    combined = value_part + position_part + temporal_part + seqid_part
    return self.dropout(combined)
"""The SingleStepEmbedding is used by all datasets for single step forecasting."""
class SingleStepEmbedding(nn.Module):
def __init__(self, cov_size, num_seq, d_model, input_size, device):
super().__init__()
self.cov_size = cov_size
self.num_class = num_seq
self.cov_emb = nn.Linear(cov_size + 1, d_model)
padding = 1 if torch.__version__ >= "1.5.0" else 2
self.data_emb = nn.Conv1d(
in_channels=1,
out_channels=d_model,
kernel_size=3,
padding=padding,
padding_mode="circular",
)
self.cov_emb = nn.Linear(cov_size+1, d_model)
padding = 1 if torch.__version__>='1.5.0' else 2
self.data_emb = nn.Conv1d(in_channels=1, out_channels=d_model, kernel_size=3, padding=padding, padding_mode='circular')
self.position = torch.arange(input_size, device=device).unsqueeze(0)
self.position_vec = torch.tensor(
[math.pow(10000.0, 2.0 * (i // 2) / d_model) for i in range(d_model)],
device=device,
)
self.position_vec = torch.tensor([math.pow(10000.0, 2.0 * (i // 2) / d_model) for i in range(d_model)], device=device)
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(
m.weight, mode="fan_in", nonlinearity="leaky_relu"
)
nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='leaky_relu')
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
@@ -185,19 +143,15 @@ class SingleStepEmbedding(nn.Module):
return result
def forward(self, x):
covs = x[:, :, 1 : (1 + self.cov_size)]
covs = x[:, :, 1:(1+self.cov_size)]
seq_ids = ((x[:, :, -1] / self.num_class) - 0.5).unsqueeze(2)
covs = torch.cat([covs, seq_ids], dim=-1)
cov_embedding = self.cov_emb(covs)
data_embedding = self.data_emb(
x[:, :, 0].unsqueeze(2).permute(0, 2, 1)
).transpose(1, 2)
data_embedding = self.data_emb(x[:, :, 0].unsqueeze(2).permute(0, 2, 1)).transpose(1,2)
embedding = cov_embedding + data_embedding
position = self.position.repeat(len(x), 1).to(x.device)
position_emb = self.transformer_embedding(
position, self.position_vec.to(x.device)
)
position_emb = self.transformer_embedding(position, self.position_vec.to(x.device))
embedding += position_emb
+81 -130
View File
@@ -2,23 +2,21 @@
Test the time and CUDA memory consumption of different attention mechanisms.
"""
import argparse
import math
import time
from math import sqrt
from typing import List
import numpy as np
import math
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from hierarchical_mm_tvm import graph_mm as graph_mm_tvm
from torch import nn
import argparse
import time
import numpy as np
from math import sqrt
torch.cuda.set_device(0)
print("Using device: {}".format(torch.cuda.get_device_name()))
print('Using device: {}'.format(torch.cuda.get_device_name()))
import pynvml
pynvml.nvmlInit()
@@ -50,64 +48,46 @@ def get_q_k(input_size, window_size, stride, device):
mask[i][mask[i] > third_start - 1] = third_start - 1
# Second layer
for i in range(second_length):
mask[input_size + i, 0:window_size] = (
input_size + i + torch.arange(window_size) - window_size // 2
)
mask[input_size+i, 0:window_size] = input_size + i + torch.arange(window_size) - window_size // 2
# Window positions that fall off the left end of this layer are set to -1
mask[input_size + i, mask[input_size + i] < input_size] = -1
mask[input_size+i, mask[input_size+i] < input_size] = -1
# Window positions that fall off the right end of this layer are set to -1
mask[input_size + i, mask[input_size + i] > third_start - 1] = -1
mask[input_size+i, mask[input_size+i] > third_start - 1] = -1
if i < second_length - 1:
mask[input_size + i, window_size : (window_size + stride)] = (
torch.arange(stride) + i * stride
)
mask[input_size+i, window_size:(window_size+stride)] = torch.arange(stride) + i * stride
else:
mask[input_size + i, window_size : (window_size + second_last)] = (
torch.arange(second_last) + i * stride
)
mask[input_size+i, window_size:(window_size+second_last)] = torch.arange(second_last) + i * stride
mask[input_size + i, -1] = i // stride + third_start
mask[input_size + i, mask[input_size + i] > fourth_start - 1] = fourth_start - 1
mask[input_size+i, -1] = i // stride + third_start
mask[input_size+i, mask[input_size+i] > fourth_start - 1] = fourth_start - 1
# Third layer
for i in range(third_length):
mask[third_start + i, 0:window_size] = (
third_start + i + torch.arange(window_size) - window_size // 2
)
mask[third_start+i, 0:window_size] = third_start + i + torch.arange(window_size) - window_size // 2
# Window positions that fall off the left end of this layer are set to -1
mask[third_start + i, mask[third_start + i] < third_start] = -1
mask[third_start+i, mask[third_start+i] < third_start] = -1
# Window positions that fall off the right end of this layer are set to -1
mask[third_start + i, mask[third_start + i] > fourth_start - 1] = -1
mask[third_start+i, mask[third_start+i] > fourth_start - 1] = -1
if i < third_length - 1:
mask[third_start + i, window_size : (window_size + stride)] = (
input_size + torch.arange(stride) + i * stride
)
mask[third_start+i, window_size:(window_size+stride)] = input_size + torch.arange(stride) + i * stride
else:
mask[third_start + i, window_size : (window_size + third_last)] = (
input_size + torch.arange(third_last) + i * stride
)
mask[third_start+i, window_size:(window_size+third_last)] = input_size + torch.arange(third_last) + i * stride
mask[third_start + i, -1] = i // stride + fourth_start
mask[third_start + i, mask[third_start + i] > full_length - 1] = full_length - 1
mask[third_start+i, -1] = i // stride + fourth_start
mask[third_start+i, mask[third_start+i] > full_length - 1] = full_length - 1
# Fourth layer
for i in range(fourth_length):
mask[fourth_start + i, 0:window_size] = (
fourth_start + i + torch.arange(window_size) - window_size // 2
)
mask[fourth_start+i, 0:window_size] = fourth_start + i + torch.arange(window_size) - window_size // 2
# Window positions that fall off the left end of this layer are set to -1
mask[fourth_start + i, mask[fourth_start + i] < fourth_start] = -1
mask[fourth_start+i, mask[fourth_start+i] < fourth_start] = -1
# Window positions that fall off the right end of this layer are set to -1
mask[fourth_start + i, mask[fourth_start + i] > full_length - 1] = -1
mask[fourth_start+i, mask[fourth_start+i] > full_length - 1] = -1
if i < fourth_length - 1:
mask[fourth_start + i, window_size : (window_size + stride)] = (
third_start + torch.arange(stride) + i * stride
)
mask[fourth_start+i, window_size:(window_size+stride)] = third_start + torch.arange(stride) + i * stride
else:
mask[fourth_start + i, window_size : (window_size + fourth_last)] = (
third_start + torch.arange(fourth_last) + i * stride
)
mask[fourth_start+i, window_size:(window_size+fourth_last)] = third_start + torch.arange(fourth_last) + i * stride
return mask
@@ -118,8 +98,8 @@ def get_k_q(q_k_mask):
for i in range(len(q_k_mask)):
for j in range(len(q_k_mask[0])):
if q_k_mask[i, j] >= 0:
k_q_mask[i, j] = torch.where(q_k_mask[q_k_mask[i, j]] == i)[0]
k_q_mask[i, j] = torch.where(q_k_mask[q_k_mask[i, j]] ==i )[0]
return k_q_mask
@@ -201,8 +181,6 @@ def get_mask(input_size, window_size, inner_size, device):
"""PAM"""
class GraphSelfAttention(nn.Module):
def __init__(self, opt):
super(GraphSelfAttention, self).__init__()
@@ -226,16 +204,15 @@ class GraphSelfAttention(nn.Module):
self.seq_len = opt.seq_len
self.window_size = opt.window_size
self.stride_size = opt.stride_size
self.q_k_mask = get_q_k(
self.seq_len, self.window_size, self.stride_size, opt.device
)
self.q_k_mask = get_q_k(self.seq_len, self.window_size, self.stride_size, opt.device)
self.k_q_mask = get_k_q(self.q_k_mask)
def forward(self, hidden_states):
residual = hidden_states
hidden_states = hidden_states
bsz, seq_len, _ = hidden_states.size()
bsz, seq_len, _ = hidden_states.size()
q = hidden_states
if self.normalize_before:
@@ -269,8 +246,6 @@ class GraphSelfAttention(nn.Module):
"""Multi-head self attention"""
class NormalSelfAttention(nn.Module):
def __init__(self, opt):
super(NormalSelfAttention, self).__init__()
@@ -295,12 +270,11 @@ class NormalSelfAttention(nn.Module):
self.window_size = opt.window_size
self.stride_size = opt.stride_size
if opt.mask:
self.mask, _ = get_mask(
self.seq_len, self.stride_size, self.window_size, opt.device
)
self.mask, _ = get_mask(self.seq_len, self.stride_size, self.window_size, opt.device)
else:
self.mask = None
def forward(self, hidden_states):
residual = hidden_states
@@ -342,8 +316,6 @@ class NormalSelfAttention(nn.Module):
"""Prob-sparse attention"""
class ProbSparseAttention(nn.Module):
def __init__(self, opt):
super(ProbSparseAttention, self).__init__()
@@ -367,16 +339,14 @@ class ProbSparseAttention(nn.Module):
self.seq_len = opt.seq_len
self.factor = opt.factor
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
# Q [B, H, L, D]
B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape
# calculate the sampled Q_K
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(
L_K, (L_Q, sample_k)
) # real U = U_part(factor*ln(L_k))*L_q
index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
@@ -385,10 +355,10 @@ class ProbSparseAttention(nn.Module):
M_top = M.topk(n_top, sorted=False)[1]
# use the reduced Q to calculate Q_K
Q_reduce = Q[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
] # factor*ln(L_q)
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
Q_reduce = Q[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None],
M_top, :] # factor*ln(L_q)
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
return Q_K, M_top
@@ -402,11 +372,11 @@ class ProbSparseAttention(nn.Module):
def _update_context(self, context_in, V, scores, index, L_Q):
B, H, L_V, D = V.shape
attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
context_in[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = torch.matmul(attn, V).type_as(context_in)
context_in[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None],
index, :] = torch.matmul(attn, V).type_as(context_in)
return context_in
def forward(self, hidden_states):
@@ -431,23 +401,17 @@ class ProbSparseAttention(nn.Module):
k = k.float().contiguous()
v = v.float().contiguous()
u = U_part = (
self.factor * np.ceil(np.log(seq_len)).astype("int").item()
) # c*ln(L_k)
u = U_part = self.factor * np.ceil(np.log(seq_len)).astype('int').item() # c*ln(L_k)
U_part = U_part if U_part < seq_len else seq_len
U_part = U_part if U_part<seq_len else seq_len
u = u if u < seq_len else seq_len
scores_top, index = self._prob_QK(q, k, sample_k=U_part, n_top=u)
scores_top, index = self._prob_QK(q, k, sample_k=U_part, n_top=u)
# get the context
context = self._get_initial_context(v, seq_len)
# update the context with selected top_k queries
context = (
self._update_context(context, v, scores_top, index, seq_len)
.transpose(1, 2)
.contiguous()
)
context = self._update_context(context, v, scores_top, index, seq_len).transpose(1, 2).contiguous()
context = context.view(bsz, seq_len, self.n_head * self.d_k)
@@ -461,24 +425,24 @@ class ProbSparseAttention(nn.Module):
def parsing(argv=None):
    """Parse the command-line options of the attention benchmark script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            ``argparse`` read ``sys.argv[1:]``, which is the previous behavior.

    Returns:
        The ``argparse.Namespace`` holding ``d_model``, ``d_k``,
        ``normalize_before``, ``n_head``, ``dropout``, ``window_size``,
        ``stride_size``, ``factor``, ``mask`` and ``seq_len``.
    """

    def _str2bool(value):
        # ``type=bool`` would turn every non-empty string (including
        # "False") into True, so the text is interpreted explicitly.
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y"):
            return True
        if text in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value)
        )

    # Each option is registered exactly once; registering the same option
    # string twice makes argparse raise a conflicting-option error.
    parser = argparse.ArgumentParser(description="Needed for graph self attention.")
    parser.add_argument("-d_model", type=int, default=256)
    parser.add_argument("-d_k", type=int, default=64)
    parser.add_argument("-normalize_before", type=_str2bool, default=False)
    parser.add_argument("-n_head", type=int, default=4)
    parser.add_argument("-dropout", type=float, default=0.1)

    # arguments for Multiformer
    parser.add_argument("-window_size", type=int, default=3)
    parser.add_argument("-stride_size", type=int, default=25)

    # arguments for ProbSparse
    parser.add_argument("-factor", type=int, default=5)

    # arguments for full-attention
    parser.add_argument("-mask", type=int, default=0)

    parser.add_argument("-seq_len", type=int, default=1000)
    args = parser.parse_args(argv)
    return args
@@ -493,14 +457,12 @@ def test_NSA(args, input_len):
NSA_Layer = NormalSelfAttention(args).to(args.device)
optimizer = optim.Adam(NSA_Layer.parameters(), 1e-4)
optimizer.zero_grad()
hidden_state = torch.ones(4, input_len, args.d_model, dtype=torch.float32).to(
args.device
)
hidden_state = torch.ones(4, input_len, args.d_model, dtype=torch.float32).to(args.device)
fake_gt = torch.zeros(4, input_len, args.d_model).to(args.device)
# Preload the layer
result = NSA_Layer(hidden_state)
loss = ((fake_gt - result) ** 2).mean()
loss = ((fake_gt - result) ** 2).mean()
loss.backward()
optimizer.step()
@@ -511,17 +473,13 @@ def test_NSA(args, input_len):
handle = pynvml.nvmlDeviceGetHandleByIndex(1)
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
used_memory += meminfo.used / 1024**3
loss = ((fake_gt - result) ** 2).mean()
loss = ((fake_gt - result) ** 2).mean()
loss.backward()
optimizer.step()
print(
"NSA used average time: {} s".format(
round((time.time() - start_time) / 1000, 4)
)
)
print('NSA used average time: {} s'.format(round((time.time() - start_time) / 1000, 4)))
used_memory = used_memory / 1000
print("NSA used average memory: {} GB".format(round(used_memory - init_mem, 4)))
print('NSA used average memory: {} GB'.format(round(used_memory-init_mem, 4)))
def test_GSA(args, input_len):
@@ -533,14 +491,12 @@ def test_GSA(args, input_len):
GSA_Layer = GraphSelfAttention(args).to(args.device)
optimizer = optim.Adam(GSA_Layer.parameters(), 1e-4)
optimizer.zero_grad()
hidden_state = torch.ones(
4, input_len, args.d_model, dtype=torch.float32, device=args.device
)
hidden_state = torch.ones(4, input_len, args.d_model, dtype=torch.float32, device=args.device)
fake_gt = torch.zeros(4, input_len, args.d_model, device=args.device)
# Preload the layer
result = GSA_Layer(hidden_state)
loss = ((fake_gt - result) ** 2).mean()
loss = ((fake_gt - result) ** 2).mean()
loss.backward()
optimizer.step()
@@ -552,15 +508,13 @@ def test_GSA(args, input_len):
handle = pynvml.nvmlDeviceGetHandleByIndex(1)
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
used_memory += meminfo.used / 1024**3
loss = ((fake_gt - result) ** 2).mean()
loss = ((fake_gt - result) ** 2).mean()
loss.backward()
optimizer.step()
print(
"GSA used time:{} s".format(round((time.time() - start_time) / repeat_times, 4))
)
print('GSA used time:{} s'.format(round((time.time() - start_time) / repeat_times, 4)))
used_memory = used_memory / repeat_times
print("GSA used average memory: {} GB".format(round(used_memory - init_mem, 4)))
print('GSA used average memory: {} GB'.format(round(used_memory-init_mem, 4)))
def test_PSA(args, input_len):
@@ -572,14 +526,12 @@ def test_PSA(args, input_len):
LSA_Layer = ProbSparseAttention(args).to(args.device)
optimizer = optim.Adam(LSA_Layer.parameters(), 1e-4)
optimizer.zero_grad()
hidden_state = torch.ones(
4, input_len, args.d_model, dtype=torch.float32, device=args.device
)
hidden_state = torch.ones(4, input_len, args.d_model, dtype=torch.float32, device=args.device)
fake_gt = torch.zeros(4, input_len, args.d_model, device=args.device)
# Preload the layer
result = LSA_Layer(hidden_state)
loss = ((fake_gt - result) ** 2).mean()
loss = ((fake_gt - result) ** 2).mean()
loss.backward()
optimizer.step()
@@ -591,23 +543,21 @@ def test_PSA(args, input_len):
handle = pynvml.nvmlDeviceGetHandleByIndex(1)
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
used_memory += meminfo.used / 1024**3
loss = ((fake_gt - result) ** 2).mean()
loss = ((fake_gt - result) ** 2).mean()
loss.backward()
optimizer.step()
print(
"LSA used time:{} s".format(round((time.time() - start_time) / repeat_times, 4))
)
print('LSA used time:{} s'.format(round((time.time() - start_time) / repeat_times, 4)))
used_memory = used_memory / repeat_times
print("LSA used average memory: {} GB".format(round(used_memory - init_mem, 4)))
print('LSA used average memory: {} GB'.format(round(used_memory-init_mem, 4)))
if __name__ == "__main__":
if __name__ == '__main__':
args = parsing()
if torch.cuda.is_available():
args.device = torch.device("cuda")
args.device = torch.device('cuda')
else:
args.device = torch.device("cpu")
args.device = torch.device('cpu')
input_size = args.seq_len
stride = args.stride_size
@@ -617,13 +567,14 @@ if __name__ == "__main__":
input_len = input_size + second_length + third_length + fourth_length
if args.mask:
print("sequence length: {}".format(input_len))
print('sequence length: {}'.format(input_len))
test_NSA(args, input_len)
else:
print("sequence length: {}".format(input_size))
print('sequence length: {}'.format(input_size))
test_NSA(args, input_size)
print("sequence length: {}".format(input_len))
print('sequence length: {}'.format(input_len))
test_GSA(args, input_len)
print("sequence length: {}".format(input_size))
print('sequence length: {}'.format(input_size))
test_PSA(args, input_size)
+78 -183
View File
@@ -8,112 +8,86 @@ Modified based on Longformer.
}
"""
import os.path
import sys
from functools import lru_cache
from typing import Union
from functools import lru_cache
import torch
sys.path.append("pyraformer/tvm/python")
import os.path
import sys
sys.path.append('pyraformer/tvm/python')
class GraphMM(torch.autograd.Function):
"""Class to encapsulate tvm code for compiling a diagonal_mm function, in addition to calling
'''Class to encapsulate tvm code for compiling a diagonal_mm function, in addition to calling
this function from PyTorch
"""
'''
function_dict = (
{}
) # save a list of functions, each has a different set of parameters
function_dict = {} # save a list of functions, each has a different set of parameters
@staticmethod
def _compile_function(
dtype: str, device: str, b0: int = 4, b1: int = 8, b2: int = 8
):
"""Compiles a tvm function that computes diagonal_mm
def _compile_function(dtype: str, device: str, b0: int = 4, b1: int = 8, b2: int = 8):
'''Compiles a tvm function that computes diagonal_mm
args:
dtype: str in ['float64', 'float32', 'float16']
device: str in ['cpu' or 'cuda']
b0, b1, b2: size of tensor tiles. Very important for good performance
"""
'''
import tvm # import the full tvm library here for compilation. Don't import at the top of the file in case we don't need to compile
from tvm.contrib import nvcc
@tvm.register_func
def tvm_callback_cuda_compile(code):
"""Use nvcc compiler for better perf."""
ptx = nvcc.compile_cuda(
code, target="ptx", arch="sm_52"
) # use old arch for this to work on old GPUs
ptx = nvcc.compile_cuda(code, target="ptx", arch='sm_52') # use old arch for this to work on old GPUs
return ptx
assert dtype in ["float16", "float32", "float64"]
assert device in ["cpu", "cuda"]
device = None if device == "cpu" else device
tgt_host = "llvm"
assert dtype in ['float16', 'float32', 'float64']
assert device in ['cpu', 'cuda']
device = None if device == 'cpu' else device
tgt_host="llvm"
b = tvm.te.var("b") # batch size
n = tvm.te.var("n") # sequence length
h = tvm.te.var("h") # number of heads
m = tvm.te.var("m") # hidden dimension
w = tvm.te.var("w") # window size
padding = tvm.te.var("padding") # padding
transpose_t1 = tvm.te.var("transpose_t1") # t1 should be transposed
t1d3 = tvm.te.var("t1d3") # last dimension of t1
t3d3 = tvm.te.var("t3d3") # last dimension of t3 (the result tensor)
max_attn = tvm.te.var("max_attn")
X = tvm.te.placeholder((b, n, h, t1d3), name="X", dtype=dtype) # first tensor
Y = tvm.te.placeholder((b, n, h, m), name="Y", dtype=dtype) # second tensor
k = tvm.te.reduce_axis((0, t1d3), name="k") # dimension to sum over
q_k_mask = tvm.te.placeholder(
(n, max_attn), name="q_k", dtype="int"
) # dilation per head
k_q_mask = tvm.te.placeholder((n, max_attn), name="k_q", dtype="int") #
b = tvm.te.var('b') # batch size
n = tvm.te.var('n') # sequence length
h = tvm.te.var('h') # number of heads
m = tvm.te.var('m') # hidden dimension
w = tvm.te.var('w') # window size
padding = tvm.te.var('padding') # padding
transpose_t1 = tvm.te.var('transpose_t1') # t1 should be transposed
t1d3 = tvm.te.var('t1d3') # last dimension of t1
t3d3 = tvm.te.var('t3d3') # last dimension of t3 (the result tensor)
max_attn = tvm.te.var('max_attn')
X = tvm.te.placeholder((b, n, h, t1d3), name='X', dtype=dtype) # first tensor
Y = tvm.te.placeholder((b, n, h, m), name='Y', dtype=dtype) # second tensor
k = tvm.te.reduce_axis((0, t1d3), name='k') # dimension to sum over
q_k_mask = tvm.te.placeholder((n, max_attn), name='q_k', dtype='int') # dilation per head
k_q_mask = tvm.te.placeholder((n, max_attn), name='k_q', dtype='int') #
output_shape = (b, n, h, t3d3) # shape of the result tensor
algorithm = lambda l, i, q, j: tvm.te.sum(
tvm.te.if_then_else(
t3d3
== m, # if output dimension == m, then t1 is diagonaled (FIXME: This breaks if t3d3 == m == t1d3)
t3d3 == m, # if output dimension == m, then t1 is diagonaled (FIXME: This breaks if t3d3 == m == t1d3)
tvm.te.if_then_else(
transpose_t1 == 0,
tvm.te.if_then_else(
q_k_mask[i, k] >= 0,
q_k_mask[i, k]>=0,
X[l, i, q, k] * Y[l, q_k_mask[i, k], q, j], # t1 is diagonaled
padding,
padding
),
tvm.te.if_then_else(
q_k_mask[i, k] >= 0,
X[l, q_k_mask[i, k], q, k_q_mask[i, k]]
* Y[
l, q_k_mask[i, k], q, j
], # # t1 is diagonaled and should be transposed
padding,
q_k_mask[i, k]>=0,
X[l, q_k_mask[i, k], q, k_q_mask[i, k]] * Y[l, q_k_mask[i, k], q, j], # # t1 is diagonaled and should be transposed
padding
),
),
tvm.te.if_then_else(
q_k_mask[i, j] >= 0,
X[l, i, q, k]
* Y[
l, q_k_mask[i, j], q, k
], # t1 is not diagonaled, but the output tensor is going to be
padding,
),
),
axis=k,
)
q_k_mask[i, j]>=0,
X[l, i, q, k] * Y[l, q_k_mask[i, j], q, k], # t1 is not diagonaled, but the output tensor is going to be
padding
)
), axis=k)
Z = tvm.te.compute(
output_shape, algorithm, name="Z"
) # automatically generate cuda code
Z = tvm.te.compute(output_shape, algorithm, name='Z') # automatically generate cuda code
s = tvm.te.create_schedule(Z.op)
print(
"Lowering: \n ===================== \n{}".format(
tvm.lower(s, [X, Y, q_k_mask, k_q_mask], simple_mode=True)
)
)
print('Lowering: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, q_k_mask, k_q_mask], simple_mode=True)))
# split long axis into smaller chunks and assign each one to a separate GPU thread/block
ko, ki = s[Z].split(Z.op.reduce_axis[0], factor=b0)
@@ -133,31 +107,21 @@ class GraphMM(torch.autograd.Function):
s[ZF].compute_at(s[Z], s[Z].op.reduce_axis[0])
s[Z].set_store_predicate(tx.var.equal(0))
print(
"Lowering with GPU splits: \n ===================== \n{}".format(
tvm.lower(s, [X, Y, q_k_mask, k_q_mask], simple_mode=True)
)
)
print('Lowering with GPU splits: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, q_k_mask, k_q_mask], simple_mode=True)))
# compiling the automatically generated cuda code
graph_mm = tvm.build(
s,
[X, Y, Z, q_k_mask, k_q_mask, max_attn, padding, transpose_t1, t3d3],
target=device,
target_host=tgt_host,
name="graph_mm",
)
graph_mm = tvm.build(s, [X, Y, Z, q_k_mask, k_q_mask, max_attn, padding, transpose_t1, t3d3], target=device, target_host=tgt_host, name='graph_mm')
return graph_mm
@staticmethod
def _get_lib_filename(dtype: str, device: str):
base_filename = "lib/lib_hierarchical_mm"
return "{}_{}_{}.so".format(base_filename, dtype, device)
base_filename = 'lib/lib_hierarchical_mm'
return '{}_{}_{}.so'.format(base_filename, dtype, device)
@staticmethod
def _save_compiled_function(f, dtype: str, device: str):
    """Export the compiled function ``f`` to its library file on disk.

    Args:
        f: Object exposing ``export_library(path)`` (the value returned by
            ``_compile_function``).
        dtype: Data type string used to build the library file name.
        device: Device string used to build the library file name.
    """
    lib_path = GraphMM._get_lib_filename(dtype, device)
    # Create the directory that the file name actually points into.
    # ``exist_ok=True`` replaces the racy exists()/makedirs() pair.
    lib_dir = os.path.dirname(lib_path)
    if lib_dir:
        os.makedirs(lib_dir, exist_ok=True)
    f.export_library(lib_path)
@staticmethod
@@ -167,63 +131,43 @@ class GraphMM(torch.autograd.Function):
filename = GraphMM._get_lib_filename(dtype, device)
current_dir = os.path.dirname(os.path.abspath(__file__))
potential_dirs = [
"../../",
"../",
"./",
f"{current_dir}/",
f"{current_dir}/../",
]
for potential_dir in potential_dirs:
filepath = "{}{}".format(potential_dir, filename)
potential_dirs = ['../../', '../', './', f'{current_dir}/', f'{current_dir}/../']
for potential_dir in potential_dirs:
filepath = '{}{}'.format(potential_dir, filename)
if os.path.isfile(filepath):
print("Loading tvm binary from: {}".format(filepath))
print('Loading tvm binary from: {}'.format(filepath))
return load(filepath)
return None
@staticmethod
def _get_function(dtype: str, device: str):
    """Return the PyTorch-callable ``graph_mm`` kernel for ``(dtype, device)``.

    The kernel is looked up in ``GraphMM.function_dict``. On a miss the
    compiled library is loaded from disk; when none is found it is compiled
    and saved. The result is wrapped with
    ``tvm.contrib.dlpack.to_pytorch_func`` and cached for later calls.
    """
    # The pair of arguments that identifies one compiled function.
    args = (dtype, device)
    if args not in GraphMM.function_dict:
        # Try the disk first; the helper returns None when no library exists.
        graph_mm = GraphMM._load_compiled_function(dtype, device)
        if graph_mm is None:
            print('Tvm binary not found. Compiling ...')
            graph_mm = GraphMM._compile_function(dtype, device)  # compile
            GraphMM._save_compiled_function(graph_mm, dtype, device)  # save to disk
        # Imported lazily so that tvm is only required when a kernel is needed.
        from tvm.contrib import dlpack

        # Wrap the tvm function as a PyTorch function and keep it for reuse.
        GraphMM.function_dict[args] = dlpack.to_pytorch_func(graph_mm)
    return GraphMM.function_dict[args]
@staticmethod
def _graph_mm(
t1: torch.Tensor,
t2: torch.Tensor,
q_k_mask: torch.Tensor,
k_q_mask: torch.Tensor,
is_t1_diagonaled: bool = False,
transpose_t1: bool = False,
padding: int = 0,
autoregressive: bool = False,
):
"""Calls the compiled function after checking the input format. This function is called in three different modes.
def _graph_mm(t1: torch.Tensor, t2: torch.Tensor, q_k_mask: torch.Tensor, k_q_mask: torch.Tensor,
is_t1_diagonaled: bool = False, transpose_t1: bool = False, padding: int = 0,
autoregressive: bool = False):
'''Calls the compiled function after checking the input format. This function is called in three different modes.
t1 x t2 = r ==> t1 and t2 are not diagonaled, but r is. Useful for query x key = attention_scores
t1 x t2 = r ==> t1 is diagonaled, but t2 and r are not. Useful to compute attention_scores x value = context
t1 x t2 = r ==> t1 is diagonaled and it should be transposed, but t2 and r are not diagonaled. Useful in some of
the calculations in the backward pass.
"""
dtype = str(t1.dtype).split(".")[1]
'''
dtype = str(t1.dtype).split('.')[1]
device = t1.device.type
assert len(t1.shape) == 4
assert len(t1.shape) == len(t2.shape)
@@ -252,26 +196,14 @@ class GraphMM(torch.autograd.Function):
# This functions computes diagonal_mm then saves the result in `r`
if m == max_attn:
# FIXME
print(
"Error: the hidden dimension {m} shouldn't match number of diagonals {c}"
)
print('Error: the hidden dimension {m} shouldn\'t match number of diagonals {c}')
assert False
_graph_mm_function(
t1,
t2,
r,
q_k_mask,
k_q_mask,
max_attn,
padding,
transpose_t1,
m if is_t1_diagonaled else max_attn,
)
_graph_mm_function(t1, t2, r, q_k_mask, k_q_mask, max_attn, padding, transpose_t1, m if is_t1_diagonaled else max_attn)
return r
@staticmethod
def _prepare_tensors(t):
"""Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch.
'''Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch.
For a tensor t, if t.size(0) == 1, then the value of t.stride()[0] doesn't matter.
TVM expects this value to be the `product(t.size()[1:])` but PyTorch sometimes sets it to `t.stride()[1]`.
Here's an example to reproduce this issue:
@@ -282,7 +214,7 @@ class GraphMM(torch.autograd.Function):
> (1, 1) # expected it to be (10, 1) as above
print(torch.randn(10, 2).t().contiguous().stride())
> (10, 1) # but gets the expected stride if the first dimension is > 1
"""
'''
assert t.is_contiguous()
t_stride = list(t.stride())
t_size = list(t.size())
@@ -297,17 +229,9 @@ class GraphMM(torch.autograd.Function):
min_seq_len = 16 # unexpected output if seq_len < 16
@staticmethod
def forward(
ctx,
t1: torch.Tensor,
t2: torch.Tensor,
q_k_mask,
k_q_mask,
is_t1_diagonaled: bool = False,
padding: int = 0,
) -> torch.Tensor:
"""Compuates diagonal_mm of t1 and t2.
args:
def forward(ctx, t1: torch.Tensor, t2: torch.Tensor, q_k_mask, k_q_mask, is_t1_diagonaled: bool = False, padding: int = 0) -> torch.Tensor:
'''Computes diagonal_mm of t1 and t2.
args:
t1: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals).
t1 can be a regular tensor (e.g. `query_layer`) or a diagonaled one (e.g. `attention_scores`)
t2: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size). This is always a non-diagonaled
@@ -322,13 +246,9 @@ class GraphMM(torch.autograd.Function):
autoregressive: if true, return only the lower triangle
returns: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals)
if t1 is diagonaled, result is non-diagonaled, and vice versa
"""
'''
seq_len = t1.size(1)
assert (
seq_len >= GraphMM.min_seq_len
), "avoid splitting errors by using seq_len >= {}".format(
GraphMM.min_seq_len
) # FIXME
assert seq_len >= GraphMM.min_seq_len, 'avoid splitting errors by using seq_len >= {}'.format(GraphMM.min_seq_len) # FIXME
t1 = GraphMM._prepare_tensors(t1)
t2 = GraphMM._prepare_tensors(t2)
@@ -337,14 +257,7 @@ class GraphMM(torch.autograd.Function):
ctx.save_for_backward(t1, t2, q_k_mask, k_q_mask)
ctx.is_t1_diagonaled = is_t1_diagonaled
# output = t1.mm(t2) # what would have been called if this was a regular matmul
output = GraphMM._graph_mm(
t1,
t2,
q_k_mask,
k_q_mask,
is_t1_diagonaled=is_t1_diagonaled,
padding=padding,
)
output = GraphMM._graph_mm(t1, t2, q_k_mask, k_q_mask, is_t1_diagonaled=is_t1_diagonaled, padding=padding)
return output
@staticmethod
@@ -352,35 +265,17 @@ class GraphMM(torch.autograd.Function):
t1, t2, q_k_mask, k_q_mask = ctx.saved_tensors
is_t1_diagonaled = ctx.is_t1_diagonaled
if not grad_output.is_contiguous():
grad_output = (
grad_output.contiguous()
) # tvm requires all input tensors to be contiguous
grad_output = grad_output.contiguous() # tvm requires all input tensors to be contiguous
grad_output = GraphMM._prepare_tensors(grad_output)
# http://cs231n.github.io/optimization-2/
# https://pytorch.org/docs/master/notes/extending.html
# grad_t1 = grad_output.mm(t2) # what would have been called if this was a regular matmul
grad_t1 = GraphMM._graph_mm(
grad_output, t2, q_k_mask, k_q_mask, is_t1_diagonaled=not is_t1_diagonaled
)
grad_t1 = GraphMM._graph_mm(grad_output, t2, q_k_mask, k_q_mask, is_t1_diagonaled=not is_t1_diagonaled)
# grad_t2 = grad_output.t().mm(t1) # or `grad_t2 = t1.t().mm(grad_output).t()` because `(AB)^T = B^TA^T`
if is_t1_diagonaled:
grad_t2 = GraphMM._graph_mm(
t1,
grad_output,
q_k_mask,
k_q_mask,
is_t1_diagonaled=True,
transpose_t1=True,
)
grad_t2 = GraphMM._graph_mm(t1, grad_output, q_k_mask, k_q_mask, is_t1_diagonaled=True, transpose_t1=True)
else:
grad_t2 = GraphMM._graph_mm(
grad_output,
t1,
q_k_mask,
k_q_mask,
is_t1_diagonaled=True,
transpose_t1=True,
)
grad_t2 = GraphMM._graph_mm(grad_output, t1, q_k_mask, k_q_mask, is_t1_diagonaled=True, transpose_t1=True)
return grad_t1, grad_t2, None, None, None, None, None
+18 -43
View File
@@ -1,28 +1,23 @@
import numpy as np
import torch
from torch.nn.modules import loss
import torch
import numpy as np
def MAE(pred, true):
    """Return the mean absolute error between ``pred`` and ``true``.

    Both arguments may be NumPy arrays or plain array-likes (lists, tuples,
    scalars) with broadcast-compatible shapes.
    """
    # np.subtract (rather than ``-``) also accepts Python sequences.
    return np.mean(np.abs(np.subtract(pred, true)))
def MSE(pred, true):
    """Return the mean squared error between ``pred`` and ``true``.

    Both arguments may be NumPy arrays or plain array-likes (lists, tuples,
    scalars) with broadcast-compatible shapes.
    """
    # np.subtract (rather than ``-``) also accepts Python sequences.
    return np.mean(np.subtract(pred, true) ** 2)
def RMSE(pred, true):
    """Return the root mean squared error between ``pred`` and ``true``.

    Both arguments may be NumPy arrays or plain array-likes (lists, tuples,
    scalars) with broadcast-compatible shapes.
    """
    # Same quantity as sqrt(MSE(pred, true)), computed directly here.
    squared_error = np.subtract(pred, true) ** 2
    return np.sqrt(np.mean(squared_error))
def MAPE(pred, true):
    """Return the mean absolute percentage error of ``pred`` against ``true``.

    The error is expressed as a fraction (not multiplied by 100). Both
    arguments may be NumPy arrays or plain array-likes. Because every error
    is divided by ``true``, zeros in ``true`` produce ``inf``/``nan``.
    """
    relative_error = np.subtract(pred, true) / true
    return np.mean(np.abs(relative_error))
def MSPE(pred, true):
    """Return the mean squared percentage error of ``pred`` against ``true``.

    The error is expressed as a fraction (not multiplied by 100). Both
    arguments may be NumPy arrays or plain array-likes. Because every error
    is divided by ``true``, zeros in ``true`` produce ``inf``/``nan``.
    """
    relative_error = np.subtract(pred, true) / true
    return np.mean(np.square(relative_error))
def metric(pred, true):
mae = MAE(pred, true)
mse = MSE(pred, true)
@@ -30,50 +25,32 @@ def metric(pred, true):
mape = MAPE(pred, true)
mspe = MSPE(pred, true)
return mae, mse, rmse, mape, mspe
return mae,mse,rmse,mape,mspe
class StandardScaler:
    """Standardize data with a mean and standard deviation learned by ``fit``.

    Until ``fit`` is called the scaler holds ``mean = 0.0`` and ``std = 1.0``,
    so ``transform`` and ``inverse_transform`` leave the data unchanged.
    """

    def __init__(self):
        self.mean = 0.0
        self.std = 1.0

    def fit(self, data):
        """Store the mean and standard deviation of ``data`` along axis 0."""
        self.mean = data.mean(0)
        self.std = data.std(0)

    def _stats_like(self, data):
        """Return ``(mean, std)`` in a form that can be combined with ``data``.

        For tensor input the statistics are converted to tensors with the
        dtype and device of ``data``. ``torch.as_tensor`` is used instead of
        ``torch.from_numpy`` so that the float defaults and NumPy scalars
        are accepted as well as arrays. Other input gets the stored values.
        """
        if torch.is_tensor(data):
            mean = torch.as_tensor(self.mean, dtype=data.dtype, device=data.device)
            std = torch.as_tensor(self.std, dtype=data.dtype, device=data.device)
            return mean, std
        return self.mean, self.std

    def transform(self, data):
        """Return ``(data - mean) / std``."""
        mean, std = self._stats_like(data)
        return (data - mean) / std

    def inverse_transform(self, data):
        """Return ``data * std + mean``, undoing ``transform``."""
        mean, std = self._stats_like(data)
        return (data * std) + mean
class TopkMSELoss(torch.nn.Module):
def __init__(self, topk) -> None:
super().__init__()
self.topk = topk
self.criterion = torch.nn.MSELoss(reduction="none")
self.criterion = torch.nn.MSELoss(reduction='none')
def forward(self, output, label):
losses = self.criterion(output, label).mean(2).mean(1)
@@ -81,9 +58,8 @@ class TopkMSELoss(torch.nn.Module):
return losses
class SingleStepLoss(torch.nn.Module):
"""Compute top-k log-likelihood and mse."""
""" Compute top-k log-likelihood and mse. """
def __init__(self, ignore_zero):
super().__init__()
@@ -91,9 +67,9 @@ class SingleStepLoss(torch.nn.Module):
def forward(self, mu, sigma, labels, topk=0):
if self.ignore_zero:
indexes = labels != 0
indexes = (labels != 0)
else:
indexes = labels >= 0
indexes = (labels >= 0)
distribution = torch.distributions.normal.Normal(mu[indexes], sigma[indexes])
likelihood = -distribution.log_prob(labels[indexes])
@@ -107,12 +83,11 @@ class SingleStepLoss(torch.nn.Module):
return likelihood, se
def AE_loss(mu, labels, ignore_zero):
    """Return the absolute errors ``|labels - mu|`` of the selected entries.

    Args:
        mu: Predicted values, indexable with the same mask as ``labels``.
        labels: Target values.
        ignore_zero: When truthy, entries whose label equals 0 are dropped.
            Otherwise only entries with ``labels >= 0`` are kept.

    Returns:
        The unreduced absolute errors of the kept entries.
    """
    keep = (labels != 0) if ignore_zero else (labels >= 0)
    return torch.abs(labels[keep] - mu[keep])