mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:46:28 +08:00
Clean up(still not working)
This commit is contained in:
Vendored
BIN
Binary file not shown.
@@ -1,10 +1,9 @@
|
||||
from .estimator import PyraformerEstimator
|
||||
from .lightning_module import PyraformerLightningModule
|
||||
from .module import PyraformerLRModel, PyraformerSSModel
|
||||
from .estimator import TransformerEstimator
|
||||
from .lightning_module import TransformerLightningModule
|
||||
from .module import TransformerModel
|
||||
|
||||
__all__ = [
|
||||
"PyraformerSSModel",
|
||||
"PyraformerLRModel",
|
||||
"PyraformerLightningModule",
|
||||
"PyraformerEstimator",
|
||||
"TransformerModel",
|
||||
"TransformerLightningModule",
|
||||
"TransformerEstimator",
|
||||
]
|
||||
|
||||
+32
-53
@@ -5,15 +5,12 @@ from gluonts.core.component import validated
|
||||
from gluonts.dataset.common import Dataset
|
||||
from gluonts.dataset.field_names import FieldName
|
||||
from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
|
||||
from gluonts.time_feature import (
|
||||
TimeFeature,
|
||||
get_lags_for_frequency,
|
||||
time_features_from_frequency_str,
|
||||
)
|
||||
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
|
||||
from gluonts.torch.model.estimator import PyTorchLightningEstimator
|
||||
from gluonts.torch.model.predictor import PyTorchPredictor
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.util import IterableDataset
|
||||
from gluonts.transform import (
|
||||
AddAgeFeature,
|
||||
@@ -32,13 +29,12 @@ from gluonts.transform import (
|
||||
VstackFeatures,
|
||||
)
|
||||
from gluonts.transform.sampler import InstanceSampler
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.sampler import RandomSampler
|
||||
|
||||
from lightning_module import PyraformerLightningModule
|
||||
from module import PyraformerLRModel, PyraformerSSModel
|
||||
from module import PyraformerSSModel
|
||||
from module import PyraformerLRModel
|
||||
from torch.utils.data import DataLoader
|
||||
from tools import SingleStepLoss as LossFactory
|
||||
|
||||
from torch.utils.data.sampler import RandomSampler
|
||||
PREDICTION_INPUT_NAMES = [
|
||||
"feat_static_cat",
|
||||
"feat_static_real",
|
||||
@@ -60,28 +56,30 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
self,
|
||||
freq: str,
|
||||
prediction_length: int,
|
||||
# Train parameters
|
||||
#Train parameters
|
||||
inner_batch: int = 8,
|
||||
lr: float = 1e-5,
|
||||
visualize_fre: int = 2000,
|
||||
pretrain: bool = True,
|
||||
hard_sample_mining: bool = True,
|
||||
hard_sample_mining:bool=True,
|
||||
covariate_size: int = 3,
|
||||
# Model parameters
|
||||
num_seq: int = 370, #
|
||||
decoder: str = "FC", # selection: [FC, attention]
|
||||
|
||||
# Model parameters
|
||||
num_seq: int = 370,#
|
||||
decoder: str = 'FC',# selection: [FC, attention]
|
||||
context_length: Optional[int] = None,
|
||||
input_size: int = 1,
|
||||
dropout: float = 0.1,
|
||||
d_model: int = 512,
|
||||
d_inner_hid: int = 512,
|
||||
d_k: int = 128,
|
||||
d_v: int = 128,
|
||||
d_v:int = 128,
|
||||
num_heads: int = 4,
|
||||
n_layer: int = 4,
|
||||
# loss: DistributionLoss = LossFactory,
|
||||
ignore_zero: bool = True,
|
||||
single_step: bool = True, # if False, Multistep=True
|
||||
single_step: bool = True,#if False, Multistep=True
|
||||
|
||||
inner_size: int = 3,
|
||||
use_tvm: bool = False,
|
||||
num_feat_dynamic_real: int = 0,
|
||||
@@ -100,7 +98,7 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
trainer_kwargs: Optional[Dict[str, Any]] = dict(),
|
||||
train_sampler: Optional[InstanceSampler] = None,
|
||||
validation_sampler: Optional[InstanceSampler] = None,
|
||||
window_size: int = [4, 4, 4],
|
||||
window_size: int = [4, 4, 4]
|
||||
) -> None:
|
||||
trainer_kwargs = {
|
||||
"max_epochs": 10,
|
||||
@@ -127,15 +125,11 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
self.n_layer = n_layer
|
||||
self.single_step = single_step
|
||||
self.ignore_zero = ignore_zero
|
||||
self.loss = (
|
||||
LossFactory(self.ignore_zero)
|
||||
if self.single_step == True
|
||||
else torch.nn.MSELoss(reduction="none")
|
||||
)
|
||||
self.loss = LossFactory(self.ignore_zero) if self.single_step==True else torch.nn.MSELoss(reduction='none')
|
||||
self.batch_size = batch_size
|
||||
self.distr_output = distr_output
|
||||
|
||||
self.window_size = window_size # [4,4,4]#window_size
|
||||
self.window_size = window_size#[4,4,4]#window_size
|
||||
self.inner_size = inner_size
|
||||
self.use_tvm = use_tvm
|
||||
self.prediction_length = prediction_length
|
||||
@@ -325,35 +319,18 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
def create_lightning_module(self) -> PyraformerLightningModule:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if self.single_step:
|
||||
model = PyraformerSSModel(
|
||||
freq=self.freq,
|
||||
covariate_size=self.covariate_size,
|
||||
num_seq=self.num_seq,
|
||||
input_size=self.input_size,
|
||||
dropout=self.dropout,
|
||||
d_model=self.d_model,
|
||||
d_inner_hid=self.d_inner_hid,
|
||||
d_k=self.d_k,
|
||||
d_v=self.d_v,
|
||||
num_heads=self.num_heads,
|
||||
n_layer=self.n_layer,
|
||||
loss=self.loss,
|
||||
window_size=self.window_size,
|
||||
inner_size=self.inner_size,
|
||||
use_tvm=self.use_tvm,
|
||||
prediction_length=self.prediction_length,
|
||||
context_length=self.context_length,
|
||||
lags_seq=self.lags_seq,
|
||||
num_feat_dynamic_real=self.num_feat_dynamic_real,
|
||||
num_feat_static_cat=self.num_feat_static_cat,
|
||||
num_feat_static_real=self.num_feat_static_real,
|
||||
cardinality=self.cardinality,
|
||||
embedding_dimension=self.embedding_dimension,
|
||||
distr_output=self.distr_output,
|
||||
scaling=self.scaling,
|
||||
num_parallel_samples=self.num_parallel_samples,
|
||||
device=device,
|
||||
)
|
||||
model = PyraformerSSModel(freq= self.freq, covariate_size = self.covariate_size,
|
||||
num_seq=self.num_seq, input_size = self.input_size, dropout = self.dropout, d_model = self.d_model,
|
||||
d_inner_hid = self.d_inner_hid, d_k = self.d_k, d_v = self.d_v,
|
||||
num_heads = self.num_heads, n_layer = self.n_layer, loss = self.loss,
|
||||
window_size = self.window_size, inner_size = self.inner_size,
|
||||
use_tvm = self.use_tvm, prediction_length = self.prediction_length,context_length = self.context_length, lags_seq = self.lags_seq, num_feat_dynamic_real= self.num_feat_dynamic_real,
|
||||
num_feat_static_cat = self.num_feat_static_cat,
|
||||
num_feat_static_real = self.num_feat_static_real,
|
||||
cardinality = self.cardinality,
|
||||
embedding_dimension = self.embedding_dimension,
|
||||
distr_output=self.distr_output,
|
||||
scaling=self.scaling,num_parallel_samples=self.num_parallel_samples, device=device)
|
||||
# else:
|
||||
# model = PyraformerLRModel(freq= self.freq, covariate_size = self.covariate_size,
|
||||
# num_seq=self.num_seq, input_size = self.input_size, dropout = self.dropout, d_model = self.d_model,
|
||||
@@ -362,3 +339,5 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
# window_size = self.window_size, inner_size = self.inner_size,
|
||||
# use_tvm = self.use_tvm, prediction_length = self.prediction_length,context_length = self.context_length, lags_seq = self.lags_seq, device=device)
|
||||
return PyraformerLightningModule(model=model, loss=self.loss)
|
||||
|
||||
|
||||
|
||||
@@ -2,84 +2,81 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import weighted_average
|
||||
|
||||
from module import PyraformerLRModel, PyraformerSSModel
|
||||
from tools import AE_loss
|
||||
from module import PyraformerSSModel
|
||||
from module import PyraformerLRModel
|
||||
from tools import SingleStepLoss as LossFactory
|
||||
|
||||
from tools import AE_loss
|
||||
# from module import PyraformerModel
|
||||
|
||||
|
||||
class PyraformerLightningModule(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
model: PyraformerSSModel,
|
||||
loss: DistributionLoss = LossFactory,
|
||||
lr: float = 1e-5,
|
||||
weight_decay: float = 1e-8,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lr = lr
|
||||
self.weight_decay = weight_decay
|
||||
def __init__(self, model: PyraformerSSModel, loss: DistributionLoss = LossFactory, lr: float = 1e-5, weight_decay: float = 1e-8,) -> None:
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lr = lr
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
def training_step(self, batch, batch_idx: int):
|
||||
def training_step(self, batch, batch_idx: int):
|
||||
|
||||
"""Execute training step"""
|
||||
train_loss = self(batch)
|
||||
self.log(
|
||||
"train_loss",
|
||||
train_loss,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
prog_bar=True,
|
||||
)
|
||||
return train_loss
|
||||
"""Execute training step"""
|
||||
train_loss = self(batch)
|
||||
self.log(
|
||||
"train_loss",
|
||||
train_loss,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
prog_bar=True,
|
||||
)
|
||||
return train_loss
|
||||
|
||||
def validation_step(self, batch, batch_idx: int):
|
||||
"""Execute validation step"""
|
||||
with torch.inference_mode():
|
||||
val_loss = self(batch)
|
||||
self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
|
||||
return val_loss
|
||||
def validation_step(self, batch, batch_idx: int):
|
||||
"""Execute validation step"""
|
||||
with torch.inference_mode():
|
||||
val_loss = self(batch)
|
||||
self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
|
||||
return val_loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Returns the optimizer to use"""
|
||||
return torch.optim.Adam(
|
||||
self.model.parameters(),
|
||||
lr=self.lr,
|
||||
weight_decay=self.weight_decay,
|
||||
)
|
||||
def configure_optimizers(self):
|
||||
"""Returns the optimizer to use"""
|
||||
return torch.optim.Adam(
|
||||
self.model.parameters(),
|
||||
lr=self.lr,
|
||||
weight_decay=self.weight_decay,
|
||||
)
|
||||
|
||||
def forward(self, batch):
|
||||
feat_static_cat = batch["feat_static_cat"]
|
||||
feat_static_real = batch["feat_static_real"]
|
||||
past_time_feat = batch["past_time_feat"]
|
||||
past_target = batch["past_target"]
|
||||
future_time_feat = batch["future_time_feat"]
|
||||
future_target = batch["future_target"]
|
||||
past_observed_values = batch["past_observed_values"]
|
||||
future_observed_values = batch["future_observed_values"]
|
||||
def forward(self, batch):
|
||||
feat_static_cat = batch["feat_static_cat"]
|
||||
feat_static_real = batch["feat_static_real"]
|
||||
past_time_feat = batch["past_time_feat"]
|
||||
past_target = batch["past_target"]
|
||||
future_time_feat = batch["future_time_feat"]
|
||||
future_target = batch["future_target"]
|
||||
past_observed_values = batch["past_observed_values"]
|
||||
future_observed_values = batch["future_observed_values"]
|
||||
|
||||
Pyraformer_inputs, scale, _ = self.model.create_network_inputs(
|
||||
feat_static_cat,
|
||||
feat_static_real,
|
||||
past_time_feat,
|
||||
past_target,
|
||||
past_observed_values,
|
||||
future_time_feat,
|
||||
future_target,
|
||||
)
|
||||
params = self.model.output_params(Pyraformer_inputs)
|
||||
distr = self.model.output_distribution(params, scale)
|
||||
|
||||
loss_values = self.loss(distr, future_target)
|
||||
|
||||
if len(self.model.target_shape) == 0:
|
||||
loss_weights = future_observed_values
|
||||
else:
|
||||
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
|
||||
|
||||
return weighted_average(loss_values, weights=loss_weights)
|
||||
|
||||
Pyraformer_inputs, scale, _ = self.model.create_network_inputs(
|
||||
feat_static_cat,
|
||||
feat_static_real,
|
||||
past_time_feat,
|
||||
past_target,
|
||||
past_observed_values,
|
||||
future_time_feat,
|
||||
future_target,
|
||||
)
|
||||
params = self.model.output_params(Pyraformer_inputs)
|
||||
distr = self.model.output_distribution(params, scale)
|
||||
|
||||
loss_values = self.loss(distr, future_target)
|
||||
|
||||
if len(self.model.target_shape) == 0:
|
||||
loss_weights = future_observed_values
|
||||
else:
|
||||
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
|
||||
|
||||
return weighted_average(loss_values, weights=loss_weights)
|
||||
|
||||
+93
-617
@@ -1,5 +1,4 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gluonts.core.component import validated
|
||||
@@ -8,45 +7,20 @@ from gluonts.torch.modules.distribution_output import DistributionOutput, Studen
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
|
||||
from pyraformer.embed import CustomEmbedding, DataEmbedding, SingleStepEmbedding
|
||||
from pyraformer.Layers import (
|
||||
AvgPooling_Construct,
|
||||
Bottleneck_Construct,
|
||||
Conv_Construct,
|
||||
Decoder,
|
||||
EncoderLayer,
|
||||
MaxPooling_Construct,
|
||||
Predictor,
|
||||
get_k_q,
|
||||
get_mask,
|
||||
get_q_k,
|
||||
get_subsequent_mask,
|
||||
refer_points,
|
||||
)
|
||||
from pyraformer.Layers import EncoderLayer, Predictor, Decoder
|
||||
from pyraformer.Layers import Bottleneck_Construct, Conv_Construct, MaxPooling_Construct, AvgPooling_Construct
|
||||
from pyraformer.Layers import get_mask, refer_points, get_k_q, get_q_k, get_subsequent_mask
|
||||
from pyraformer.embed import SingleStepEmbedding, DataEmbedding, CustomEmbedding
|
||||
|
||||
|
||||
class EncoderSS(nn.Module):
|
||||
"""A encoder model with self attention mechanism."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
covariate_size,
|
||||
num_seq,
|
||||
input_size,
|
||||
dropout,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
d_k,
|
||||
d_v,
|
||||
num_heads,
|
||||
n_layer,
|
||||
loss,
|
||||
window_size,
|
||||
inner_size,
|
||||
use_tvm,
|
||||
prediction_length,
|
||||
device,
|
||||
):
|
||||
@validated()
|
||||
def __init__(self, covariate_size,
|
||||
num_seq, input_size ,dropout , d_model,
|
||||
d_inner_hid, d_k, d_v ,
|
||||
num_heads , n_layer, loss ,
|
||||
window_size , inner_size,
|
||||
use_tvm, prediction_length, device):
|
||||
super().__init__()
|
||||
|
||||
self.d_model = d_model
|
||||
@@ -56,49 +30,21 @@ class EncoderSS(nn.Module):
|
||||
self.indexes = refer_points(self.all_size, window_size, device)
|
||||
|
||||
if use_tvm:
|
||||
|
||||
assert (
|
||||
len(set(self.window_size)) == 1
|
||||
), "Only constant window size is supported."
|
||||
|
||||
assert len(set(self.window_size)) == 1, "Only constant window size is supported."
|
||||
q_k_mask = get_q_k(input_size, inner_size, window_size[0], device)
|
||||
k_q_mask = get_k_q(q_k_mask)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
EncoderLayer(
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_heads,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=False,
|
||||
use_tvm=True,
|
||||
q_k_mask=q_k_mask,
|
||||
k_q_mask=k_q_mask,
|
||||
)
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(d_model, d_inner_hid, num_heads, d_k, d_v, dropout=dropout, \
|
||||
normalize_before=False, use_tvm=True, q_k_mask=q_k_mask, k_q_mask=k_q_mask) for i in range(n_layer)
|
||||
])
|
||||
else:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
EncoderLayer(
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_heads,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=False,
|
||||
)
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
|
||||
self.embedding = SingleStepEmbedding(
|
||||
covariate_size, num_seq, d_model, input_size, device
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(d_model, d_inner_hid, num_heads, d_k, d_v, dropout=dropout, \
|
||||
normalize_before=False) for i in range(n_layer)
|
||||
])
|
||||
|
||||
self.embedding = SingleStepEmbedding(covariate_size, num_seq, d_model, input_size, device)
|
||||
self.conv_layers = Bottleneck_Construct(d_model, window_size, d_k)
|
||||
|
||||
def forward(self, sequence):
|
||||
@@ -111,38 +57,21 @@ class EncoderSS(nn.Module):
|
||||
for i in range(len(self.layers)):
|
||||
seq_enc, _ = self.layers[i](seq_enc, mask)
|
||||
|
||||
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(
|
||||
seq_enc.device
|
||||
)
|
||||
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
|
||||
indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
|
||||
all_enc = torch.gather(seq_enc, 1, indexes)
|
||||
all_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
|
||||
|
||||
return all_enc
|
||||
|
||||
|
||||
class PyraformerSSModel(nn.Module):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
freq,
|
||||
covariate_size,
|
||||
num_seq,
|
||||
input_size,
|
||||
dropout,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
d_k,
|
||||
d_v,
|
||||
num_heads,
|
||||
n_layer,
|
||||
loss,
|
||||
window_size,
|
||||
inner_size,
|
||||
use_tvm,
|
||||
prediction_length,
|
||||
context_length,
|
||||
lags_seq,
|
||||
def __init__(self, freq, covariate_size,
|
||||
num_seq, input_size ,dropout, d_model,
|
||||
d_inner_hid, d_k, d_v,
|
||||
num_heads , n_layer, loss ,
|
||||
window_size , inner_size,
|
||||
use_tvm, prediction_length,context_length,lags_seq,
|
||||
num_feat_dynamic_real,
|
||||
num_feat_static_cat,
|
||||
num_feat_static_real,
|
||||
@@ -150,39 +79,25 @@ class PyraformerSSModel(nn.Module):
|
||||
embedding_dimension,
|
||||
distr_output,
|
||||
# loss: DistributionLoss = NegativeLogLikelihood(),
|
||||
scaling,
|
||||
num_parallel_samples,
|
||||
device,
|
||||
):
|
||||
scaling,num_parallel_samples,device):
|
||||
|
||||
|
||||
super().__init__()
|
||||
self.context_length = context_length
|
||||
self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
|
||||
self.encoder = EncoderSS(
|
||||
covariate_size,
|
||||
num_seq,
|
||||
input_size,
|
||||
dropout,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
d_k,
|
||||
d_v,
|
||||
num_heads,
|
||||
n_layer,
|
||||
loss,
|
||||
window_size,
|
||||
inner_size,
|
||||
use_tvm,
|
||||
prediction_length,
|
||||
device,
|
||||
)
|
||||
self.encoder = EncoderSS(covariate_size,
|
||||
num_seq, input_size ,dropout , d_model,
|
||||
d_inner_hid, d_k, d_v ,
|
||||
num_heads , n_layer, loss ,
|
||||
window_size , inner_size,
|
||||
use_tvm, prediction_length, device)
|
||||
|
||||
# convert hidden vectors into two scalar
|
||||
# convert hidden vectors into two scalar
|
||||
self.mean_hidden = Predictor(4 * d_model, 1)
|
||||
self.var_hidden = Predictor(4 * d_model, 1)
|
||||
|
||||
self.softplus = nn.Softplus()
|
||||
|
||||
|
||||
def forward(self, data):
|
||||
enc_output = self.encoder(data)
|
||||
|
||||
@@ -199,11 +114,10 @@ class PyraformerSSModel(nn.Module):
|
||||
sample_mu = mu[:, -1] * v
|
||||
sample_sigma = sigma[:, -1] * v
|
||||
return sample_mu, sample_sigma
|
||||
|
||||
|
||||
@property
|
||||
def _past_length(self) -> int:
|
||||
return self.context_length + max(self.lags_seq)
|
||||
|
||||
@property
|
||||
def _number_of_features(self) -> int:
|
||||
return (
|
||||
@@ -212,7 +126,6 @@ class PyraformerSSModel(nn.Module):
|
||||
+ self.num_feat_static_real
|
||||
+ 1 # the log(scale)
|
||||
)
|
||||
|
||||
def get_lagged_subsequences(
|
||||
self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
|
||||
) -> torch.Tensor:
|
||||
@@ -265,7 +178,7 @@ class PyraformerSSModel(nn.Module):
|
||||
assert (
|
||||
features is None or features.shape[2] == self._number_of_features
|
||||
), f"{features.shape[2]}, expected {self._number_of_features}"
|
||||
|
||||
|
||||
def create_network_inputs(
|
||||
self,
|
||||
feat_static_cat: torch.Tensor,
|
||||
@@ -361,149 +274,70 @@ class PyraformerSSModel(nn.Module):
|
||||
if trailing_n is not None:
|
||||
sliced_params = [p[:, -trailing_n:] for p in params]
|
||||
return self.distr_output.distribution(sliced_params, scale=scale)
|
||||
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""A encoder model with self attention mechanism."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
window_size,
|
||||
truncate,
|
||||
input_size,
|
||||
inner_size,
|
||||
decoder,
|
||||
num_head,
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
d_inner_hid,
|
||||
dropout,
|
||||
n_layer,
|
||||
enc_in,
|
||||
covariate_size,
|
||||
seq_num,
|
||||
CSCM,
|
||||
d_bottleneck,
|
||||
num_head,
|
||||
use_tvm,
|
||||
device,
|
||||
):
|
||||
@validated()
|
||||
def __init__(self, model, window_size, truncate, input_size, inner_size,decoder,d_model, d_k,d_v,d_inner_hid, dropout,n_layer ,enc_in,covariate_size,seq_num,CSCM,d_bottleneck,num_head,use_tvm,device):
|
||||
super().__init__()
|
||||
|
||||
self.d_model = d_model
|
||||
self.model_type = model
|
||||
self.window_size = window_size
|
||||
self.truncate = truncate
|
||||
if decoder == "attention":
|
||||
self.mask, self.all_size = get_mask(
|
||||
input_size, window_size, inner_size, device
|
||||
)
|
||||
if decoder == 'attention':
|
||||
self.mask, self.all_size = get_mask(input_size, window_size, inner_size, device)
|
||||
else:
|
||||
self.mask, self.all_size = get_mask(
|
||||
input_size + 1, window_size, inner_size, device
|
||||
)
|
||||
self.mask, self.all_size = get_mask(input_size+1, window_size, inner_size, device)
|
||||
self.decoder_type = decoder
|
||||
if decoder == "FC":
|
||||
if decoder == 'FC':
|
||||
self.indexes = refer_points(self.all_size, window_size, device)
|
||||
|
||||
if use_tvm:
|
||||
assert (
|
||||
len(set(self.window_size)) == 1
|
||||
), "Only constant window size is supported."
|
||||
padding = 1 if decoder == "FC" else 0
|
||||
q_k_mask = get_q_k(input_size + padding, inner_size, window_size[0], device)
|
||||
assert len(set(self.window_size)) == 1, "Only constant window size is supported."
|
||||
padding = 1 if decoder == 'FC' else 0
|
||||
q_k_mask = get_q_k(input_size + padding, inner_size, window_size[0],device)
|
||||
k_q_mask = get_k_q(q_k_mask)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
EncoderLayer(
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_head,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=False,
|
||||
use_tvm=True,
|
||||
q_k_mask=q_k_mask,
|
||||
k_q_mask=k_q_mask,
|
||||
)
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(d_model, d_inner_hid, num_head, d_k, d_v, dropout=dropout, \
|
||||
normalize_before=False, use_tvm=True, q_k_mask=q_k_mask, k_q_mask=k_q_mask) for i in range(n_layer)
|
||||
])
|
||||
else:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
EncoderLayer(
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_head,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=False,
|
||||
)
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
|
||||
if opt.embed_type == "CustomEmbedding":
|
||||
self.enc_embedding = CustomEmbedding(
|
||||
enc_in, d_model, covariate_size, seq_num, dropout
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(d_model, d_inner_hid, num_head, d_k, d_v, dropout=dropout, \
|
||||
normalize_before=False) for i in range(n_layer)
|
||||
])
|
||||
|
||||
if opt.embed_type == 'CustomEmbedding':
|
||||
self.enc_embedding = CustomEmbedding(enc_in, d_model, covariate_size, seq_num, dropout)
|
||||
else:
|
||||
self.enc_embedding = DataEmbedding(enc_in, d_model, dropout)
|
||||
|
||||
|
||||
self.conv_layers = eval(CSCM)(d_model, window_size, d_bottleneck)
|
||||
|
||||
|
||||
def forward(self, x_enc, x_mark_enc):
|
||||
seq_enc = self.enc_embedding(x_enc, x_mark_enc)
|
||||
|
||||
mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)
|
||||
seq_enc = self.conv_layers(seq_enc)
|
||||
|
||||
def forward(self, x_enc, x_mark_enc):
|
||||
seq_enc = self.enc_embedding(x_enc, x_mark_enc)
|
||||
|
||||
mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)
|
||||
seq_enc = self.conv_layers(seq_enc)
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
seq_enc, _ = self.layers[i](seq_enc, mask)
|
||||
|
||||
if self.decoder_type == "FC":
|
||||
indexes = self.indexes.repeat(
|
||||
seq_enc.size(0), 1, 1, seq_enc.size(2)
|
||||
).to(seq_enc.device)
|
||||
indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
|
||||
all_enc = torch.gather(seq_enc, 1, indexes)
|
||||
seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
|
||||
elif self.decoder_type == "attention" and self.truncate:
|
||||
seq_enc = seq_enc[:, : self.all_size[0]]
|
||||
|
||||
return seq_enc
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
seq_enc, _ = self.layers[i](seq_enc, mask)
|
||||
|
||||
if self.decoder_type == 'FC':
|
||||
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
|
||||
indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
|
||||
all_enc = torch.gather(seq_enc, 1, indexes)
|
||||
seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
|
||||
elif self.decoder_type == 'attention' and self.truncate:
|
||||
seq_enc = seq_enc[:, :self.all_size[0]]
|
||||
|
||||
return seq_enc
|
||||
|
||||
|
||||
class PyraformerLRModel(nn.Module):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
predict_step,
|
||||
d_model,
|
||||
input_size,
|
||||
decoder,
|
||||
window_size,
|
||||
truncate,
|
||||
model,
|
||||
d_inner_hid,
|
||||
num_head,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout,
|
||||
enc_in,
|
||||
covariate_size,
|
||||
seq_num,
|
||||
CSCM,
|
||||
d_bottleneck,
|
||||
num_head,
|
||||
use_tvm,
|
||||
device,
|
||||
):
|
||||
def __init__(self, predict_step, d_model, input_size, decoder, window_size, truncate, model,d_inner_hid,d_k,d_v,dropout,enc_in,covariate_size,seq_num,CSCM,d_bottleneck,num_head,use_tvm,device):
|
||||
super().__init__()
|
||||
|
||||
self.predict_step = predict_step
|
||||
@@ -511,46 +345,13 @@ class PyraformerLRModel(nn.Module):
|
||||
self.input_size = input_size
|
||||
self.decoder_type = decoder
|
||||
self.channels = enc_in
|
||||
|
||||
self.encoder = Encoder(
|
||||
model,
|
||||
window_size,
|
||||
truncate,
|
||||
input_size,
|
||||
inner_size,
|
||||
decoder,
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
d_inner_hid,
|
||||
dropout,
|
||||
n_layer,
|
||||
enc_in,
|
||||
covariate_size,
|
||||
seq_num,
|
||||
CSCM,
|
||||
d_bottleneck,
|
||||
num_head,
|
||||
use_tvm,
|
||||
device,
|
||||
)
|
||||
if decoder == "attention":
|
||||
|
||||
self.encoder = Encoder(model, window_size, truncate, input_size, inner_size,decoder,d_model, d_k,d_v,d_inner_hid, dropout,n_layer, enc_in,covariate_size,seq_num,CSCM,d_bottleneck,num_head,use_tvm,device)
|
||||
if decoder == 'attention':
|
||||
mask = get_subsequent_mask(input_size, window_size, predict_step, truncate)
|
||||
self.decoder = Decoder(
|
||||
model,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_head,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout,
|
||||
enc_in,
|
||||
covariate_size,
|
||||
seq_num,
|
||||
mask,
|
||||
)
|
||||
self.decoder = Decoder(model,d_model,d_inner_hid,num_head,d_k,d_v,dropout,enc_in,covariate_size,seq_num, mask)
|
||||
self.predictor = Predictor(d_model, enc_in)
|
||||
elif opt.decoder == "FC":
|
||||
elif opt.decoder == 'FC':
|
||||
self.predictor = Predictor(4 * d_model, predict_step * enc_in)
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, pretrain):
|
||||
@@ -563,343 +364,18 @@ class PyraformerLRModel(nn.Module):
|
||||
type_prediction: batch*seq_len*num_classes (not normalized);
|
||||
time_prediction: batch*seq_len.
|
||||
"""
|
||||
if self.decoder_type == "attention":
|
||||
if self.decoder_type == 'attention':
|
||||
enc_output = self.encoder(x_enc, x_mark_enc)
|
||||
dec_enc = self.decoder(x_dec, x_mark_dec, enc_output)
|
||||
|
||||
if pretrain:
|
||||
dec_enc = torch.cat([enc_output[:, : self.input_size], dec_enc], dim=1)
|
||||
dec_enc = torch.cat([enc_output[:, :self.input_size], dec_enc], dim=1)
|
||||
pred = self.predictor(dec_enc)
|
||||
else:
|
||||
pred = self.predictor(dec_enc)
|
||||
elif self.decoder_type == "FC":
|
||||
elif self.decoder_type == 'FC':
|
||||
enc_output = self.encoder(x_enc, x_mark_enc)[:, -1, :]
|
||||
pred = self.predictor(enc_output).view(
|
||||
enc_output.size(0), self.predict_step, -1
|
||||
)
|
||||
pred = self.predictor(enc_output).view(enc_output.size(0), self.predict_step, -1)
|
||||
|
||||
return pred
|
||||
|
||||
|
||||
class TransformerModel(nn.Module):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
freq: str,
|
||||
context_length: int,
|
||||
prediction_length: int,
|
||||
num_feat_dynamic_real: int,
|
||||
num_feat_static_real: int,
|
||||
num_feat_static_cat: int,
|
||||
cardinality: List[int],
|
||||
# transformer arguments
|
||||
nhead: int,
|
||||
num_encoder_layers: int,
|
||||
num_decoder_layers: int,
|
||||
dim_feedforward: int,
|
||||
activation: str = "gelu",
|
||||
dropout: float = 0.1,
|
||||
# univariate input
|
||||
input_size: int = 1,
|
||||
embedding_dimension: Optional[List[int]] = None,
|
||||
distr_output: DistributionOutput = StudentTOutput(),
|
||||
lags_seq: Optional[List[int]] = None,
|
||||
scaling: bool = True,
|
||||
num_parallel_samples: int = 100,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
|
||||
self.target_shape = distr_output.event_shape
|
||||
self.num_feat_dynamic_real = num_feat_dynamic_real
|
||||
self.num_feat_static_cat = num_feat_static_cat
|
||||
self.num_feat_static_real = num_feat_static_real
|
||||
self.embedding_dimension = (
|
||||
embedding_dimension
|
||||
if embedding_dimension is not None or cardinality is None
|
||||
else [min(50, (cat + 1) // 2) for cat in cardinality]
|
||||
)
|
||||
self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
|
||||
self.num_parallel_samples = num_parallel_samples
|
||||
self.history_length = context_length + max(self.lags_seq)
|
||||
self.embedder = FeatureEmbedder(
|
||||
cardinalities=cardinality,
|
||||
embedding_dims=self.embedding_dimension,
|
||||
)
|
||||
if scaling:
|
||||
self.scaler = MeanScaler(dim=1, keepdim=True)
|
||||
else:
|
||||
self.scaler = NOPScaler(dim=1, keepdim=True)
|
||||
|
||||
# total feature size
|
||||
d_model = self.input_size * len(self.lags_seq) + self._number_of_features
|
||||
|
||||
self.context_length = context_length
|
||||
self.prediction_length = prediction_length
|
||||
self.distr_output = distr_output
|
||||
self.param_proj = distr_output.get_args_proj(d_model)
|
||||
|
||||
# transformer enc-decoder and mask initializer
|
||||
self.transformer = nn.Transformer(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
num_decoder_layers=num_decoder_layers,
|
||||
dim_feedforward=dim_feedforward,
|
||||
dropout=dropout,
|
||||
activation=activation,
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
# causal decoder tgt mask
|
||||
self.register_buffer(
|
||||
"tgt_mask",
|
||||
self.transformer.generate_square_subsequent_mask(prediction_length),
|
||||
)
|
||||
|
||||
@property
|
||||
def _number_of_features(self) -> int:
|
||||
return (
|
||||
sum(self.embedding_dimension)
|
||||
+ self.num_feat_dynamic_real
|
||||
+ self.num_feat_static_real
|
||||
+ 1 # the log(scale)
|
||||
)
|
||||
|
||||
@property
|
||||
def _past_length(self) -> int:
|
||||
return self.context_length + max(self.lags_seq)
|
||||
|
||||
def get_lagged_subsequences(
    self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
) -> torch.Tensor:
    """
    Return lagged subsequences of a given sequence.

    Parameters
    ----------
    sequence : Tensor
        the sequence from which lagged subsequences should be extracted.
        Shape: (N, T, C); any dimensions after the time axis are passed
        through unchanged.
    subsequences_length : int
        length of the subsequences to be extracted.
    shift : int
        shift the lags by this amount back.

    Returns
    -------
    lagged : Tensor
        a tensor of shape (N, S, C, I), where S = subsequences_length and
        I = len(indices), containing lagged subsequences. Specifically,
        lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
    """
    sequence_length = sequence.shape[1]
    indices = [lag - shift for lag in self.lags_seq]

    # A negative effective lag would slice past the end of the sequence and
    # yield windows shorter than `subsequences_length`.
    assert min(indices) >= 0, (
        f"shift {shift} cannot exceed the smallest lag {min(self.lags_seq)}"
    )
    assert max(indices) + subsequences_length <= sequence_length, (
        f"lags cannot go further than history length, found lag {max(indices)} "
        f"while history length is only {sequence_length}"
    )

    lagged_values = []
    for lag_index in indices:
        begin_index = -lag_index - subsequences_length
        # A lag of 0 ends at the last time step; `-0` would give an empty slice.
        end_index = -lag_index if lag_index > 0 else None
        lagged_values.append(sequence[:, begin_index:end_index, ...])
    return torch.stack(lagged_values, dim=-1)
|
||||
|
||||
def _check_shapes(
    self,
    prior_input: torch.Tensor,
    inputs: torch.Tensor,
    features: Optional[torch.Tensor],
) -> None:
    """Assert that the given tensors agree with the configured sizes.

    Checks that ``prior_input`` and ``inputs`` have the same rank, that each
    is either 2-D with ``self.input_size == 1`` or carries ``self.input_size``
    channels, and that ``features`` (when given) has
    ``self._number_of_features`` channels on axis 2.

    Raises
    ------
    AssertionError
        If any of the checks fails.
    """
    assert len(prior_input.shape) == len(inputs.shape)
    assert (
        len(prior_input.shape) == 2 and self.input_size == 1
    ) or prior_input.shape[2] == self.input_size
    assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
        -1
    ] == self.input_size
    assert (
        features is None or features.shape[2] == self._number_of_features
    ), f"{features.shape[2]}, expected {self._number_of_features}"
|
||||
|
||||
def create_network_inputs(
    self,
    feat_static_cat: torch.Tensor,
    feat_static_real: torch.Tensor,
    past_time_feat: torch.Tensor,
    past_target: torch.Tensor,
    past_observed_values: torch.Tensor,
    future_time_feat: Optional[torch.Tensor] = None,
    future_target: Optional[torch.Tensor] = None,
):
    """Assemble the per-time-step input vectors fed to the transformer.

    The target is divided by a scale computed over the context window,
    expanded into lagged values, and concatenated with the static features
    (embedded categoricals, static reals, ``log(scale)``) and the time
    features. When ``future_target`` is given, the future window is appended
    to both the target and the time features.

    Parameters
    ----------
    feat_static_cat, feat_static_real
        Static categorical / real features, one row per series.
    past_time_feat, past_target, past_observed_values
        History tensors whose time axis (dim 1) has ``self._past_length``
        steps.
    future_time_feat, future_target
        Optional future window; must be supplied together.

    Returns
    -------
    tuple
        ``(transformer_inputs, scale, static_feat)``.

    Raises
    ------
    ValueError
        If ``future_target`` is given without ``future_time_feat``.
    AssertionError
        If the target length along dim 1 is not the expected one.
    """
    has_future = future_target is not None
    if has_future and future_time_feat is None:
        # Fail with a clear message instead of an opaque torch.cat error.
        raise ValueError(
            "future_time_feat must be provided together with future_target"
        )

    # time feature: only the context part of the history is kept
    context_time_feat = past_time_feat[
        :, self._past_length - self.context_length :, ...
    ]
    time_feat = (
        torch.cat((context_time_feat, future_time_feat), dim=1)
        if has_future
        else context_time_feat
    )

    # target: the scale is estimated on the context window only
    context = past_target[:, -self.context_length :]
    observed_context = past_observed_values[:, -self.context_length :]
    _, scale = self.scaler(context, observed_context)

    inputs = (
        torch.cat((past_target, future_target), dim=1) / scale
        if has_future
        else past_target / scale
    )

    inputs_length = (
        self._past_length + self.prediction_length
        if has_future
        else self._past_length
    )
    assert inputs.shape[1] == inputs_length

    subsequences_length = (
        self.context_length + self.prediction_length
        if has_future
        else self.context_length
    )

    # embeddings
    embedded_cat = self.embedder(feat_static_cat)
    static_feat = torch.cat(
        (embedded_cat, feat_static_real, scale.log()),
        dim=1,
    )
    # broadcast the static features over every time step
    expanded_static_feat = static_feat.unsqueeze(1).expand(
        -1, time_feat.shape[1], -1
    )

    features = torch.cat((expanded_static_feat, time_feat), dim=-1)

    lagged_sequence = self.get_lagged_subsequences(
        sequence=inputs,
        subsequences_length=subsequences_length,
    )

    # flatten the trailing (channel, lag) axes into one feature axis
    lags_shape = lagged_sequence.shape
    reshaped_lagged_sequence = lagged_sequence.reshape(
        lags_shape[0], lags_shape[1], -1
    )

    transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)

    return transformer_inputs, scale, static_feat
|
||||
|
||||
def output_params(self, transformer_inputs):
    """Run the encoder-decoder and project to distribution parameters.

    The first ``self.context_length`` steps of ``transformer_inputs`` (along
    dim 1) go to the encoder; the remaining steps go to the decoder, which
    is called with the ``self.tgt_mask`` buffer as its target mask. The
    decoder output is passed through ``self.param_proj``.
    """
    enc_input = transformer_inputs[:, : self.context_length, ...]
    dec_input = transformer_inputs[:, self.context_length :, ...]

    enc_out = self.transformer.encoder(enc_input)
    dec_output = self.transformer.decoder(
        dec_input, enc_out, tgt_mask=self.tgt_mask
    )

    return self.param_proj(dec_output)
|
||||
|
||||
@torch.jit.ignore
def output_distribution(
    self, params, scale=None, trailing_n=None
) -> torch.distributions.Distribution:
    """Build the output distribution from projected parameters.

    Parameters
    ----------
    params
        Iterable of parameter tensors, as produced by ``self.param_proj``.
    scale
        Optional scale forwarded to ``self.distr_output.distribution``.
    trailing_n
        When given, only the last ``trailing_n`` entries along dim 1 of each
        parameter tensor are kept.
    """
    sliced_params = params
    if trailing_n is not None:
        sliced_params = [p[:, -trailing_n:] for p in params]
    return self.distr_output.distribution(sliced_params, scale=scale)
|
||||
|
||||
# for prediction
|
||||
def forward(
    self,
    feat_static_cat: torch.Tensor,
    feat_static_real: torch.Tensor,
    past_time_feat: torch.Tensor,
    past_target: torch.Tensor,
    past_observed_values: torch.Tensor,
    future_time_feat: torch.Tensor,
    num_parallel_samples: Optional[int] = None,
) -> torch.Tensor:
    """Draw sample paths for the prediction window.

    The history is encoded once; the decoder is then run step by step, each
    time sampling the next value from the output distribution and appending
    it (rescaled) to the running target so that later steps can use it as a
    lag.

    Parameters
    ----------
    feat_static_cat, feat_static_real, past_time_feat, past_target,
    past_observed_values
        History inputs, forwarded to ``create_network_inputs``.
    future_time_feat
        Time features for the prediction window.
    num_parallel_samples
        Number of sample paths per series; defaults to
        ``self.num_parallel_samples`` when ``None``.

    Returns
    -------
    Tensor
        Samples reshaped to
        ``(-1, num_parallel_samples, prediction_length) + self.target_shape``.
    """
    if num_parallel_samples is None:
        num_parallel_samples = self.num_parallel_samples

    encoder_inputs, scale, static_feat = self.create_network_inputs(
        feat_static_cat,
        feat_static_real,
        past_time_feat,
        past_target,
        past_observed_values,
    )

    enc_out = self.transformer.encoder(encoder_inputs)

    # Every per-series tensor is repeated so that each sample path gets its
    # own row. The resolved local count is used throughout, so an explicit
    # `num_parallel_samples` argument is honoured.
    repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)

    repeated_past_target = (
        past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
        / repeated_scale
    )

    expanded_static_feat = static_feat.unsqueeze(1).expand(
        -1, future_time_feat.shape[1], -1
    )
    features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
    repeated_features = features.repeat_interleave(
        repeats=num_parallel_samples, dim=0
    )

    repeated_enc_out = enc_out.repeat_interleave(
        repeats=num_parallel_samples, dim=0
    )

    future_samples = []

    # autoregressive decoding: one new step per iteration
    for k in range(self.prediction_length):
        lagged_sequence = self.get_lagged_subsequences(
            sequence=repeated_past_target,
            subsequences_length=1 + k,
            shift=1,
        )

        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        decoder_input = torch.cat(
            (reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1
        )

        # NOTE(review): output_params passes tgt_mask to the decoder but this
        # call does not; confirm whether a causal mask is wanted here too.
        output = self.transformer.decoder(decoder_input, repeated_enc_out)

        # only the newest position is needed to sample the next value
        params = self.param_proj(output[:, -1:])
        distr = self.output_distribution(params, scale=repeated_scale)
        next_sample = distr.sample()

        repeated_past_target = torch.cat(
            (repeated_past_target, next_sample / repeated_scale), dim=1
        )
        future_samples.append(next_sample)

    concat_future_samples = torch.cat(future_samples, dim=1)
    return concat_future_samples.reshape(
        (-1, num_parallel_samples, self.prediction_length) + self.target_shape,
    )
|
||||
|
||||
+1
-272
File diff suppressed because one or more lines are too long
Vendored
BIN
Binary file not shown.
+96
-209
@@ -1,12 +1,12 @@
|
||||
from torch.functional import align_tensors
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.nn.modules.linear import Linear
|
||||
from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
|
||||
import torch
|
||||
from .embed import DataEmbedding, CustomEmbedding
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.functional import align_tensors
|
||||
from torch.nn.modules.linear import Linear
|
||||
|
||||
from .embed import CustomEmbedding, DataEmbedding
|
||||
from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
|
||||
|
||||
|
||||
def get_mask(input_size, window_size, inner_size, device):
|
||||
@@ -34,15 +34,11 @@ def get_mask(input_size, window_size, inner_size, device):
|
||||
for layer_idx in range(1, len(all_size)):
|
||||
start = sum(all_size[:layer_idx])
|
||||
for i in range(start, start + all_size[layer_idx]):
|
||||
left_side = (start - all_size[layer_idx - 1]) + (i - start) * window_size[
|
||||
layer_idx - 1
|
||||
]
|
||||
if i == (start + all_size[layer_idx] - 1):
|
||||
left_side = (start - all_size[layer_idx - 1]) + (i - start) * window_size[layer_idx - 1]
|
||||
if i == ( start + all_size[layer_idx] - 1):
|
||||
right_side = start
|
||||
else:
|
||||
right_side = (start - all_size[layer_idx - 1]) + (
|
||||
i - start + 1
|
||||
) * window_size[layer_idx - 1]
|
||||
right_side = (start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1]
|
||||
mask[i, left_side:right_side] = 1
|
||||
mask[left_side:right_side, i] = 1
|
||||
|
||||
@@ -62,9 +58,7 @@ def refer_points(all_sizes, window_size, device):
|
||||
for j in range(1, len(all_sizes)):
|
||||
start = sum(all_sizes[:j])
|
||||
inner_layer_idx = former_index - (start - all_sizes[j - 1])
|
||||
former_index = start + min(
|
||||
inner_layer_idx // window_size[j - 1], all_sizes[j] - 1
|
||||
)
|
||||
former_index = start + min(inner_layer_idx // window_size[j - 1], all_sizes[j] - 1)
|
||||
indexes[i][j] = former_index
|
||||
|
||||
indexes = indexes.unsqueeze(0).unsqueeze(3)
|
||||
@@ -77,7 +71,7 @@ def get_subsequent_mask(input_size, window_size, predict_step, truncate):
|
||||
if truncate:
|
||||
mask = torch.zeros(predict_step, input_size + predict_step)
|
||||
for i in range(predict_step):
|
||||
mask[i][: input_size + i + 1] = 1
|
||||
mask[i][:input_size+i+1] = 1
|
||||
mask = (1 - mask).bool().unsqueeze(0)
|
||||
else:
|
||||
all_size = []
|
||||
@@ -88,7 +82,7 @@ def get_subsequent_mask(input_size, window_size, predict_step, truncate):
|
||||
all_size = sum(all_size)
|
||||
mask = torch.zeros(predict_step, all_size + predict_step)
|
||||
for i in range(predict_step):
|
||||
mask[i][: all_size + i + 1] = 1
|
||||
mask[i][:all_size+i+1] = 1
|
||||
mask = (1 - mask).bool().unsqueeze(0)
|
||||
|
||||
return mask
|
||||
@@ -120,56 +114,38 @@ def get_q_k(input_size, window_size, stride, device):
|
||||
mask[i, -1] = i // stride + input_size
|
||||
mask[i][mask[i] > third_start - 1] = third_start - 1
|
||||
for i in range(second_length):
|
||||
mask[input_size + i, 0:window_size] = (
|
||||
input_size + i + torch.arange(window_size) - window_size // 2
|
||||
)
|
||||
mask[input_size + i, mask[input_size + i] < input_size] = -1
|
||||
mask[input_size + i, mask[input_size + i] > third_start - 1] = -1
|
||||
mask[input_size+i, 0:window_size] = input_size + i + torch.arange(window_size) - window_size // 2
|
||||
mask[input_size+i, mask[input_size+i] < input_size] = -1
|
||||
mask[input_size+i, mask[input_size+i] > third_start - 1] = -1
|
||||
|
||||
if i < second_length - 1:
|
||||
mask[input_size + i, window_size : (window_size + stride)] = (
|
||||
torch.arange(stride) + i * stride
|
||||
)
|
||||
mask[input_size+i, window_size:(window_size+stride)] = torch.arange(stride) + i * stride
|
||||
else:
|
||||
mask[input_size + i, window_size : (window_size + second_last)] = (
|
||||
torch.arange(second_last) + i * stride
|
||||
)
|
||||
mask[input_size+i, window_size:(window_size+second_last)] = torch.arange(second_last) + i * stride
|
||||
|
||||
mask[input_size + i, -1] = i // stride + third_start
|
||||
mask[input_size + i, mask[input_size + i] > fourth_start - 1] = fourth_start - 1
|
||||
mask[input_size+i, -1] = i // stride + third_start
|
||||
mask[input_size+i, mask[input_size+i] > fourth_start - 1] = fourth_start - 1
|
||||
for i in range(third_length):
|
||||
mask[third_start + i, 0:window_size] = (
|
||||
third_start + i + torch.arange(window_size) - window_size // 2
|
||||
)
|
||||
mask[third_start + i, mask[third_start + i] < third_start] = -1
|
||||
mask[third_start + i, mask[third_start + i] > fourth_start - 1] = -1
|
||||
mask[third_start+i, 0:window_size] = third_start + i + torch.arange(window_size) - window_size // 2
|
||||
mask[third_start+i, mask[third_start+i] < third_start] = -1
|
||||
mask[third_start+i, mask[third_start+i] > fourth_start - 1] = -1
|
||||
|
||||
if i < third_length - 1:
|
||||
mask[third_start + i, window_size : (window_size + stride)] = (
|
||||
input_size + torch.arange(stride) + i * stride
|
||||
)
|
||||
mask[third_start+i, window_size:(window_size+stride)] = input_size + torch.arange(stride) + i * stride
|
||||
else:
|
||||
mask[third_start + i, window_size : (window_size + third_last)] = (
|
||||
input_size + torch.arange(third_last) + i * stride
|
||||
)
|
||||
mask[third_start+i, window_size:(window_size+third_last)] = input_size + torch.arange(third_last) + i * stride
|
||||
|
||||
mask[third_start + i, -1] = i // stride + fourth_start
|
||||
mask[third_start + i, mask[third_start + i] > full_length - 1] = full_length - 1
|
||||
mask[third_start+i, -1] = i // stride + fourth_start
|
||||
mask[third_start+i, mask[third_start+i] > full_length - 1] = full_length - 1
|
||||
for i in range(fourth_length):
|
||||
mask[fourth_start + i, 0:window_size] = (
|
||||
fourth_start + i + torch.arange(window_size) - window_size // 2
|
||||
)
|
||||
mask[fourth_start + i, mask[fourth_start + i] < fourth_start] = -1
|
||||
mask[fourth_start + i, mask[fourth_start + i] > full_length - 1] = -1
|
||||
mask[fourth_start+i, 0:window_size] = fourth_start + i + torch.arange(window_size) - window_size // 2
|
||||
mask[fourth_start+i, mask[fourth_start+i] < fourth_start] = -1
|
||||
mask[fourth_start+i, mask[fourth_start+i] > full_length - 1] = -1
|
||||
|
||||
if i < fourth_length - 1:
|
||||
mask[fourth_start + i, window_size : (window_size + stride)] = (
|
||||
third_start + torch.arange(stride) + i * stride
|
||||
)
|
||||
mask[fourth_start+i, window_size:(window_size+stride)] = third_start + torch.arange(stride) + i * stride
|
||||
else:
|
||||
mask[fourth_start + i, window_size : (window_size + fourth_last)] = (
|
||||
third_start + torch.arange(fourth_last) + i * stride
|
||||
)
|
||||
mask[fourth_start+i, window_size:(window_size+fourth_last)] = third_start + torch.arange(fourth_last) + i * stride
|
||||
|
||||
return mask
|
||||
|
||||
@@ -182,64 +158,32 @@ def get_k_q(q_k_mask):
|
||||
for i in range(len(q_k_mask)):
|
||||
for j in range(len(q_k_mask[0])):
|
||||
if q_k_mask[i, j] >= 0:
|
||||
k_q_mask[i, j] = torch.where(q_k_mask[q_k_mask[i, j]] == i)[0]
|
||||
k_q_mask[i, j] = torch.where(q_k_mask[q_k_mask[i, j]] ==i )[0]
|
||||
|
||||
return k_q_mask
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
"""Compose with two layers"""
|
||||
""" Compose with two layers """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
d_inner,
|
||||
n_head,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=0.1,
|
||||
normalize_before=True,
|
||||
use_tvm=False,
|
||||
q_k_mask=None,
|
||||
k_q_mask=None,
|
||||
):
|
||||
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1, normalize_before=True, use_tvm=False, q_k_mask=None, k_q_mask=None):
|
||||
super(EncoderLayer, self).__init__()
|
||||
self.use_tvm = use_tvm
|
||||
if use_tvm:
|
||||
from .PAM_TVM import PyramidalAttention
|
||||
|
||||
self.slf_attn = PyramidalAttention(
|
||||
n_head,
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
q_k_mask=q_k_mask,
|
||||
k_q_mask=k_q_mask,
|
||||
)
|
||||
self.slf_attn = PyramidalAttention(n_head, d_model, d_k, d_v, dropout=dropout, normalize_before=normalize_before, q_k_mask=q_k_mask, k_q_mask=k_q_mask)
|
||||
else:
|
||||
self.slf_attn = MultiHeadAttention(
|
||||
n_head,
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout, normalize_before=normalize_before)
|
||||
|
||||
self.pos_ffn = PositionwiseFeedForward(
|
||||
d_model, d_inner, dropout=dropout, normalize_before=normalize_before
|
||||
)
|
||||
d_model, d_inner, dropout=dropout, normalize_before=normalize_before)
|
||||
|
||||
def forward(self, enc_input, slf_attn_mask=None):
|
||||
if self.use_tvm:
|
||||
enc_output = self.slf_attn(enc_input)
|
||||
enc_slf_attn = None
|
||||
else:
|
||||
enc_output, enc_slf_attn = self.slf_attn(
|
||||
enc_input, enc_input, enc_input, mask=slf_attn_mask
|
||||
)
|
||||
enc_output, enc_slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask)
|
||||
|
||||
enc_output = self.pos_ffn(enc_output)
|
||||
|
||||
@@ -247,26 +191,18 @@ class EncoderLayer(nn.Module):
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
"""Compose with two layers"""
|
||||
""" Compose with two layers """
|
||||
|
||||
def __init__(
|
||||
self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1, normalize_before=True
|
||||
):
|
||||
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1, normalize_before=True):
|
||||
super(DecoderLayer, self).__init__()
|
||||
self.slf_attn = MultiHeadAttention(
|
||||
n_head,
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
n_head, d_model, d_k, d_v, dropout=dropout, normalize_before=normalize_before)
|
||||
self.pos_ffn = PositionwiseFeedForward(
|
||||
d_model, d_inner, dropout=dropout, normalize_before=normalize_before
|
||||
)
|
||||
d_model, d_inner, dropout=dropout, normalize_before=normalize_before)
|
||||
|
||||
def forward(self, Q, K, V, slf_attn_mask=None):
|
||||
enc_output, enc_slf_attn = self.slf_attn(Q, K, V, mask=slf_attn_mask)
|
||||
enc_output, enc_slf_attn = self.slf_attn(
|
||||
Q, K, V, mask=slf_attn_mask)
|
||||
|
||||
enc_output = self.pos_ffn(enc_output)
|
||||
|
||||
@@ -276,12 +212,10 @@ class DecoderLayer(nn.Module):
|
||||
class ConvLayer(nn.Module):
|
||||
def __init__(self, c_in, window_size):
|
||||
super(ConvLayer, self).__init__()
|
||||
self.downConv = nn.Conv1d(
|
||||
in_channels=c_in,
|
||||
out_channels=c_in,
|
||||
kernel_size=window_size,
|
||||
stride=window_size,
|
||||
)
|
||||
self.downConv = nn.Conv1d(in_channels=c_in,
|
||||
out_channels=c_in,
|
||||
kernel_size=window_size,
|
||||
stride=window_size)
|
||||
self.norm = nn.BatchNorm1d(c_in)
|
||||
self.activation = nn.ELU()
|
||||
|
||||
@@ -294,25 +228,20 @@ class ConvLayer(nn.Module):
|
||||
|
||||
class Conv_Construct(nn.Module):
|
||||
"""Convolution CSCM"""
|
||||
|
||||
def __init__(self, d_model, window_size, d_inner):
|
||||
super(Conv_Construct, self).__init__()
|
||||
if not isinstance(window_size, list):
|
||||
self.conv_layers = nn.ModuleList(
|
||||
[
|
||||
ConvLayer(d_model, window_size),
|
||||
ConvLayer(d_model, window_size),
|
||||
ConvLayer(d_model, window_size),
|
||||
]
|
||||
)
|
||||
self.conv_layers = nn.ModuleList([
|
||||
ConvLayer(d_model, window_size),
|
||||
ConvLayer(d_model, window_size),
|
||||
ConvLayer(d_model, window_size)
|
||||
])
|
||||
else:
|
||||
self.conv_layers = nn.ModuleList(
|
||||
[
|
||||
ConvLayer(d_model, window_size[0]),
|
||||
ConvLayer(d_model, window_size[1]),
|
||||
ConvLayer(d_model, window_size[2]),
|
||||
]
|
||||
)
|
||||
self.conv_layers = nn.ModuleList([
|
||||
ConvLayer(d_model, window_size[0]),
|
||||
ConvLayer(d_model, window_size[1]),
|
||||
ConvLayer(d_model, window_size[2])
|
||||
])
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, enc_input):
|
||||
@@ -332,17 +261,14 @@ class Conv_Construct(nn.Module):
|
||||
|
||||
class Bottleneck_Construct(nn.Module):
|
||||
"""Bottleneck convolution CSCM"""
|
||||
|
||||
def __init__(self, d_model, window_size, d_inner):
|
||||
super(Bottleneck_Construct, self).__init__()
|
||||
if not isinstance(window_size, list):
|
||||
self.conv_layers = nn.ModuleList(
|
||||
[
|
||||
ConvLayer(d_inner, window_size),
|
||||
ConvLayer(d_inner, window_size),
|
||||
ConvLayer(d_inner, window_size),
|
||||
]
|
||||
)
|
||||
self.conv_layers = nn.ModuleList([
|
||||
ConvLayer(d_inner, window_size),
|
||||
ConvLayer(d_inner, window_size),
|
||||
ConvLayer(d_inner, window_size)
|
||||
])
|
||||
else:
|
||||
self.conv_layers = []
|
||||
for i in range(len(window_size)):
|
||||
@@ -371,25 +297,20 @@ class Bottleneck_Construct(nn.Module):
|
||||
|
||||
class MaxPooling_Construct(nn.Module):
|
||||
"""Max pooling CSCM"""
|
||||
|
||||
def __init__(self, d_model, window_size, d_inner):
|
||||
super(MaxPooling_Construct, self).__init__()
|
||||
if not isinstance(window_size, list):
|
||||
self.pooling_layers = nn.ModuleList(
|
||||
[
|
||||
nn.MaxPool1d(kernel_size=window_size),
|
||||
nn.MaxPool1d(kernel_size=window_size),
|
||||
nn.MaxPool1d(kernel_size=window_size),
|
||||
]
|
||||
)
|
||||
self.pooling_layers = nn.ModuleList([
|
||||
nn.MaxPool1d(kernel_size=window_size),
|
||||
nn.MaxPool1d(kernel_size=window_size),
|
||||
nn.MaxPool1d(kernel_size=window_size)
|
||||
])
|
||||
else:
|
||||
self.pooling_layers = nn.ModuleList(
|
||||
[
|
||||
nn.MaxPool1d(kernel_size=window_size[0]),
|
||||
nn.MaxPool1d(kernel_size=window_size[1]),
|
||||
nn.MaxPool1d(kernel_size=window_size[2]),
|
||||
]
|
||||
)
|
||||
self.pooling_layers = nn.ModuleList([
|
||||
nn.MaxPool1d(kernel_size=window_size[0]),
|
||||
nn.MaxPool1d(kernel_size=window_size[1]),
|
||||
nn.MaxPool1d(kernel_size=window_size[2])
|
||||
])
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, enc_input):
|
||||
@@ -409,25 +330,20 @@ class MaxPooling_Construct(nn.Module):
|
||||
|
||||
class AvgPooling_Construct(nn.Module):
|
||||
"""Average pooling CSCM"""
|
||||
|
||||
def __init__(self, d_model, window_size, d_inner):
|
||||
super(AvgPooling_Construct, self).__init__()
|
||||
if not isinstance(window_size, list):
|
||||
self.pooling_layers = nn.ModuleList(
|
||||
[
|
||||
nn.AvgPool1d(kernel_size=window_size),
|
||||
nn.AvgPool1d(kernel_size=window_size),
|
||||
nn.AvgPool1d(kernel_size=window_size),
|
||||
]
|
||||
)
|
||||
self.pooling_layers = nn.ModuleList([
|
||||
nn.AvgPool1d(kernel_size=window_size),
|
||||
nn.AvgPool1d(kernel_size=window_size),
|
||||
nn.AvgPool1d(kernel_size=window_size)
|
||||
])
|
||||
else:
|
||||
self.pooling_layers = nn.ModuleList(
|
||||
[
|
||||
nn.AvgPool1d(kernel_size=window_size[0]),
|
||||
nn.AvgPool1d(kernel_size=window_size[1]),
|
||||
nn.AvgPool1d(kernel_size=window_size[2]),
|
||||
]
|
||||
)
|
||||
self.pooling_layers = nn.ModuleList([
|
||||
nn.AvgPool1d(kernel_size=window_size[0]),
|
||||
nn.AvgPool1d(kernel_size=window_size[1]),
|
||||
nn.AvgPool1d(kernel_size=window_size[2])
|
||||
])
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, enc_input):
|
||||
@@ -446,6 +362,7 @@ class AvgPooling_Construct(nn.Module):
|
||||
|
||||
|
||||
class Predictor(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_types):
|
||||
super().__init__()
|
||||
|
||||
@@ -459,54 +376,23 @@ class Predictor(nn.Module):
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""A encoder model with self attention mechanism."""
|
||||
""" A encoder model with self attention mechanism. """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_head,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout,
|
||||
enc_in,
|
||||
covariate_size,
|
||||
seq_num,
|
||||
mask,
|
||||
):
|
||||
def __init__(self, model,d_model,d_inner_hid,num_head,d_k,d_v,dropout,enc_in,covariate_size,seq_num, mask):
|
||||
super().__init__()
|
||||
|
||||
self.model_type = model
|
||||
self.mask = mask
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DecoderLayer(
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_head,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=False,
|
||||
),
|
||||
DecoderLayer(
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_head,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
DecoderLayer(d_model, d_inner_hid, num_head, d_k, d_v, dropout=dropout, \
|
||||
normalize_before=False),
|
||||
DecoderLayer(d_model, d_inner_hid, num_head, d_k, d_v, dropout=dropout, \
|
||||
normalize_before=False)
|
||||
])
|
||||
|
||||
if opt.embed_type == "CustomEmbedding":
|
||||
self.dec_embedding = CustomEmbedding(
|
||||
enc_in, d_model, covariate_size, seq_num, dropout
|
||||
)
|
||||
if opt.embed_type == 'CustomEmbedding':
|
||||
self.dec_embedding = CustomEmbedding(enc_in, d_model, covariate_size, seq_num, dropout)
|
||||
else:
|
||||
self.dec_embedding = DataEmbedding(enc_in, d_model, dropout)
|
||||
|
||||
@@ -519,3 +405,4 @@ class Decoder(nn.Module):
|
||||
dec_enc, _ = self.layers[1](dec_enc, refer_enc, refer_enc, slf_attn_mask=mask)
|
||||
|
||||
return dec_enc
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class ScaledDotProductAttention(nn.Module):
|
||||
"""Scaled Dot-Product Attention"""
|
||||
""" Scaled Dot-Product Attention """
|
||||
|
||||
def __init__(self, temperature, attn_dropout=0.2):
|
||||
super().__init__()
|
||||
@@ -22,3 +22,4 @@ class ScaledDotProductAttention(nn.Module):
|
||||
output = torch.matmul(attn, v)
|
||||
|
||||
return output, attn
|
||||
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import math
|
||||
from .hierarchical_mm_tvm import graph_mm as graph_mm_tvm
|
||||
|
||||
|
||||
class PyramidalAttention(nn.Module):
|
||||
def __init__(
|
||||
self, n_head, d_model, d_k, d_v, dropout, normalize_before, q_k_mask, k_q_mask
|
||||
):
|
||||
def __init__(self, n_head, d_model, d_k, d_v, dropout, normalize_before, q_k_mask, k_q_mask):
|
||||
super(PyramidalAttention, self).__init__()
|
||||
self.normalize_before = normalize_before
|
||||
self.n_head = n_head
|
||||
@@ -35,7 +31,7 @@ class PyramidalAttention(nn.Module):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states
|
||||
bsz, seq_len, _ = hidden_states.size()
|
||||
bsz, seq_len, _ = hidden_states.size()
|
||||
|
||||
q = hidden_states
|
||||
if self.normalize_before:
|
||||
@@ -66,3 +62,4 @@ class PyramidalAttention(nn.Module):
|
||||
context = self.layer_norm(context)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from .Modules import ScaledDotProductAttention
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""Multi-Head Attention module"""
|
||||
""" Multi-Head Attention module """
|
||||
|
||||
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, normalize_before=True):
|
||||
super().__init__()
|
||||
@@ -25,9 +25,7 @@ class MultiHeadAttention(nn.Module):
|
||||
self.fc = nn.Linear(d_v * n_head, d_model)
|
||||
nn.init.xavier_uniform_(self.fc.weight)
|
||||
|
||||
self.attention = ScaledDotProductAttention(
|
||||
temperature=d_k**0.5, attn_dropout=dropout
|
||||
)
|
||||
self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5, attn_dropout=dropout)
|
||||
|
||||
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
@@ -67,7 +65,7 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
"""Two-layer position-wise feed-forward neural network."""
|
||||
""" Two-layer position-wise feed-forward neural network. """
|
||||
|
||||
def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True):
|
||||
super().__init__()
|
||||
@@ -78,7 +76,7 @@ class PositionwiseFeedForward(nn.Module):
|
||||
self.w_2 = nn.Linear(d_hid, d_in)
|
||||
|
||||
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
|
||||
# self.layer_norm = GraphNorm(d_in)
|
||||
#self.layer_norm = GraphNorm(d_in)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -95,3 +93,4 @@ class PositionwiseFeedForward(nn.Module):
|
||||
if not self.normalize_before:
|
||||
x = self.layer_norm(x)
|
||||
return x
|
||||
|
||||
|
||||
@@ -9,11 +9,11 @@ Modified based on Informer.
|
||||
}
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import math
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, d_model, max_len=5000):
|
||||
@@ -23,42 +23,31 @@ class PositionalEmbedding(nn.Module):
|
||||
pe.require_grad = False
|
||||
|
||||
position = torch.arange(0, max_len).float().unsqueeze(1)
|
||||
div_term = (
|
||||
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
|
||||
).exp()
|
||||
div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
|
||||
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer("pe", pe)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
return self.pe[:, : x.size(1)]
|
||||
|
||||
return self.pe[:, :x.size(1)]
|
||||
|
||||
class TokenEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model):
|
||||
super(TokenEmbedding, self).__init__()
|
||||
padding = 1 if torch.__version__ >= "1.5.0" else 2
|
||||
self.tokenConv = nn.Conv1d(
|
||||
in_channels=c_in,
|
||||
out_channels=d_model,
|
||||
kernel_size=3,
|
||||
padding=padding,
|
||||
padding_mode="circular",
|
||||
)
|
||||
padding = 1 if torch.__version__>='1.5.0' else 2
|
||||
self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
|
||||
kernel_size=3, padding=padding, padding_mode='circular')
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode="fan_in", nonlinearity="leaky_relu"
|
||||
)
|
||||
nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='leaky_relu')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
|
||||
x = self.tokenConv(x.permute(0, 2, 1)).transpose(1,2)
|
||||
return x
|
||||
|
||||
|
||||
class FixedEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model):
|
||||
super(FixedEmbedding, self).__init__()
|
||||
@@ -67,9 +56,7 @@ class FixedEmbedding(nn.Module):
|
||||
w.require_grad = False
|
||||
|
||||
position = torch.arange(0, c_in).float().unsqueeze(1)
|
||||
div_term = (
|
||||
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
|
||||
).exp()
|
||||
div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
|
||||
|
||||
w[:, 0::2] = torch.sin(position * div_term)
|
||||
w[:, 1::2] = torch.cos(position * div_term)
|
||||
@@ -80,21 +67,17 @@ class FixedEmbedding(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.emb(x).detach()
|
||||
|
||||
|
||||
class TimeFeatureEmbedding(nn.Module):
    """Linear projection of time-feature vectors into the model dimension."""

    def __init__(self, d_model, d_inp=4):
        """Create the projection.

        Parameters
        ----------
        d_model
            Output width of the embedding.
        d_inp
            Number of input time features; defaults to 4, the value that was
            previously hard-coded.
        """
        super(TimeFeatureEmbedding, self).__init__()

        self.embed = nn.Linear(d_inp, d_model)

    def forward(self, x):
        """Apply the linear embedding to ``x``."""
        return self.embed(x)
|
||||
|
||||
|
||||
"""Embedding modules. The DataEmbedding is used by the ETT dataset for long range forecasting."""
|
||||
|
||||
|
||||
class DataEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model, dropout=0.1):
|
||||
super(DataEmbedding, self).__init__()
|
||||
@@ -106,18 +89,11 @@ class DataEmbedding(nn.Module):
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, x, x_mark):
|
||||
x = (
|
||||
self.value_embedding(x)
|
||||
+ self.position_embedding(x)
|
||||
+ self.temporal_embedding(x_mark)
|
||||
)
|
||||
x = self.value_embedding(x) + self.position_embedding(x) + self.temporal_embedding(x_mark)
|
||||
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
"""The CustomEmbedding is used by the electricity dataset and app flow dataset for long range forecasting."""
|
||||
|
||||
|
||||
class CustomEmbedding(nn.Module):
|
||||
def __init__(self, c_in, d_model, temporal_size, seq_num, dropout=0.1):
|
||||
super(CustomEmbedding, self).__init__()
|
||||
@@ -130,46 +106,28 @@ class CustomEmbedding(nn.Module):
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, x, x_mark):
|
||||
x = (
|
||||
self.value_embedding(x)
|
||||
+ self.position_embedding(x)
|
||||
+ self.temporal_embedding(x_mark[:, :, :-1])
|
||||
x = self.value_embedding(x) + self.position_embedding(x) + self.temporal_embedding(x_mark[:, :, :-1])\
|
||||
+ self.seqid_embedding(x_mark[:, :, -1].long())
|
||||
)
|
||||
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
"""The SingleStepEmbedding is used by all datasets for single step forecasting."""
|
||||
|
||||
|
||||
class SingleStepEmbedding(nn.Module):
|
||||
def __init__(self, cov_size, num_seq, d_model, input_size, device):
|
||||
super().__init__()
|
||||
|
||||
self.cov_size = cov_size
|
||||
self.num_class = num_seq
|
||||
self.cov_emb = nn.Linear(cov_size + 1, d_model)
|
||||
padding = 1 if torch.__version__ >= "1.5.0" else 2
|
||||
self.data_emb = nn.Conv1d(
|
||||
in_channels=1,
|
||||
out_channels=d_model,
|
||||
kernel_size=3,
|
||||
padding=padding,
|
||||
padding_mode="circular",
|
||||
)
|
||||
self.cov_emb = nn.Linear(cov_size+1, d_model)
|
||||
padding = 1 if torch.__version__>='1.5.0' else 2
|
||||
self.data_emb = nn.Conv1d(in_channels=1, out_channels=d_model, kernel_size=3, padding=padding, padding_mode='circular')
|
||||
|
||||
self.position = torch.arange(input_size, device=device).unsqueeze(0)
|
||||
self.position_vec = torch.tensor(
|
||||
[math.pow(10000.0, 2.0 * (i // 2) / d_model) for i in range(d_model)],
|
||||
device=device,
|
||||
)
|
||||
self.position_vec = torch.tensor([math.pow(10000.0, 2.0 * (i // 2) / d_model) for i in range(d_model)], device=device)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode="fan_in", nonlinearity="leaky_relu"
|
||||
)
|
||||
nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='leaky_relu')
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.xavier_normal_(m.weight)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
@@ -185,19 +143,15 @@ class SingleStepEmbedding(nn.Module):
|
||||
return result
|
||||
|
||||
def forward(self, x):
|
||||
covs = x[:, :, 1 : (1 + self.cov_size)]
|
||||
covs = x[:, :, 1:(1+self.cov_size)]
|
||||
seq_ids = ((x[:, :, -1] / self.num_class) - 0.5).unsqueeze(2)
|
||||
covs = torch.cat([covs, seq_ids], dim=-1)
|
||||
cov_embedding = self.cov_emb(covs)
|
||||
data_embedding = self.data_emb(
|
||||
x[:, :, 0].unsqueeze(2).permute(0, 2, 1)
|
||||
).transpose(1, 2)
|
||||
data_embedding = self.data_emb(x[:, :, 0].unsqueeze(2).permute(0, 2, 1)).transpose(1,2)
|
||||
embedding = cov_embedding + data_embedding
|
||||
|
||||
position = self.position.repeat(len(x), 1).to(x.device)
|
||||
position_emb = self.transformer_embedding(
|
||||
position, self.position_vec.to(x.device)
|
||||
)
|
||||
position_emb = self.transformer_embedding(position, self.position_vec.to(x.device))
|
||||
|
||||
embedding += position_emb
|
||||
|
||||
|
||||
@@ -2,23 +2,21 @@
|
||||
Test the time and CUDA memory consumption of different attention mechanisms.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import time
|
||||
from math import sqrt
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from hierarchical_mm_tvm import graph_mm as graph_mm_tvm
|
||||
from torch import nn
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
from math import sqrt
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
print("Using device: {}".format(torch.cuda.get_device_name()))
|
||||
print('Using device: {}'.format(torch.cuda.get_device_name()))
|
||||
import pynvml
|
||||
|
||||
pynvml.nvmlInit()
|
||||
|
||||
|
||||
@@ -50,64 +48,46 @@ def get_q_k(input_size, window_size, stride, device):
|
||||
mask[i][mask[i] > third_start - 1] = third_start - 1
|
||||
# 第二层
|
||||
for i in range(second_length):
|
||||
mask[input_size + i, 0:window_size] = (
|
||||
input_size + i + torch.arange(window_size) - window_size // 2
|
||||
)
|
||||
mask[input_size+i, 0:window_size] = input_size + i + torch.arange(window_size) - window_size // 2
|
||||
# 当window在序列左端时,置为-1
|
||||
mask[input_size + i, mask[input_size + i] < input_size] = -1
|
||||
mask[input_size+i, mask[input_size+i] < input_size] = -1
|
||||
# 当window在序列右端时,置为-1
|
||||
mask[input_size + i, mask[input_size + i] > third_start - 1] = -1
|
||||
mask[input_size+i, mask[input_size+i] > third_start - 1] = -1
|
||||
|
||||
if i < second_length - 1:
|
||||
mask[input_size + i, window_size : (window_size + stride)] = (
|
||||
torch.arange(stride) + i * stride
|
||||
)
|
||||
mask[input_size+i, window_size:(window_size+stride)] = torch.arange(stride) + i * stride
|
||||
else:
|
||||
mask[input_size + i, window_size : (window_size + second_last)] = (
|
||||
torch.arange(second_last) + i * stride
|
||||
)
|
||||
mask[input_size+i, window_size:(window_size+second_last)] = torch.arange(second_last) + i * stride
|
||||
|
||||
mask[input_size + i, -1] = i // stride + third_start
|
||||
mask[input_size + i, mask[input_size + i] > fourth_start - 1] = fourth_start - 1
|
||||
mask[input_size+i, -1] = i // stride + third_start
|
||||
mask[input_size+i, mask[input_size+i] > fourth_start - 1] = fourth_start - 1
|
||||
# 第三层
|
||||
for i in range(third_length):
|
||||
mask[third_start + i, 0:window_size] = (
|
||||
third_start + i + torch.arange(window_size) - window_size // 2
|
||||
)
|
||||
mask[third_start+i, 0:window_size] = third_start + i + torch.arange(window_size) - window_size // 2
|
||||
# 当window在序列左端时,置为-1
|
||||
mask[third_start + i, mask[third_start + i] < third_start] = -1
|
||||
mask[third_start+i, mask[third_start+i] < third_start] = -1
|
||||
# 当window在序列右端时,置为-1
|
||||
mask[third_start + i, mask[third_start + i] > fourth_start - 1] = -1
|
||||
mask[third_start+i, mask[third_start+i] > fourth_start - 1] = -1
|
||||
|
||||
if i < third_length - 1:
|
||||
mask[third_start + i, window_size : (window_size + stride)] = (
|
||||
input_size + torch.arange(stride) + i * stride
|
||||
)
|
||||
mask[third_start+i, window_size:(window_size+stride)] = input_size + torch.arange(stride) + i * stride
|
||||
else:
|
||||
mask[third_start + i, window_size : (window_size + third_last)] = (
|
||||
input_size + torch.arange(third_last) + i * stride
|
||||
)
|
||||
mask[third_start+i, window_size:(window_size+third_last)] = input_size + torch.arange(third_last) + i * stride
|
||||
|
||||
mask[third_start + i, -1] = i // stride + fourth_start
|
||||
mask[third_start + i, mask[third_start + i] > full_length - 1] = full_length - 1
|
||||
mask[third_start+i, -1] = i // stride + fourth_start
|
||||
mask[third_start+i, mask[third_start+i] > full_length - 1] = full_length - 1
|
||||
# 第四层
|
||||
for i in range(fourth_length):
|
||||
mask[fourth_start + i, 0:window_size] = (
|
||||
fourth_start + i + torch.arange(window_size) - window_size // 2
|
||||
)
|
||||
mask[fourth_start+i, 0:window_size] = fourth_start + i + torch.arange(window_size) - window_size // 2
|
||||
# 当window在序列左端时,置为-1
|
||||
mask[fourth_start + i, mask[fourth_start + i] < fourth_start] = -1
|
||||
mask[fourth_start+i, mask[fourth_start+i] < fourth_start] = -1
|
||||
# 当window在序列右端时,置为-1
|
||||
mask[fourth_start + i, mask[fourth_start + i] > full_length - 1] = -1
|
||||
mask[fourth_start+i, mask[fourth_start+i] > full_length - 1] = -1
|
||||
|
||||
if i < fourth_length - 1:
|
||||
mask[fourth_start + i, window_size : (window_size + stride)] = (
|
||||
third_start + torch.arange(stride) + i * stride
|
||||
)
|
||||
mask[fourth_start+i, window_size:(window_size+stride)] = third_start + torch.arange(stride) + i * stride
|
||||
else:
|
||||
mask[fourth_start + i, window_size : (window_size + fourth_last)] = (
|
||||
third_start + torch.arange(fourth_last) + i * stride
|
||||
)
|
||||
mask[fourth_start+i, window_size:(window_size+fourth_last)] = third_start + torch.arange(fourth_last) + i * stride
|
||||
|
||||
return mask
|
||||
|
||||
@@ -118,8 +98,8 @@ def get_k_q(q_k_mask):
|
||||
for i in range(len(q_k_mask)):
|
||||
for j in range(len(q_k_mask[0])):
|
||||
if q_k_mask[i, j] >= 0:
|
||||
k_q_mask[i, j] = torch.where(q_k_mask[q_k_mask[i, j]] == i)[0]
|
||||
|
||||
k_q_mask[i, j] = torch.where(q_k_mask[q_k_mask[i, j]] ==i )[0]
|
||||
|
||||
return k_q_mask
|
||||
|
||||
|
||||
@@ -201,8 +181,6 @@ def get_mask(input_size, window_size, inner_size, device):
|
||||
|
||||
|
||||
"""PAM"""
|
||||
|
||||
|
||||
class GraphSelfAttention(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super(GraphSelfAttention, self).__init__()
|
||||
@@ -226,16 +204,15 @@ class GraphSelfAttention(nn.Module):
|
||||
self.seq_len = opt.seq_len
|
||||
self.window_size = opt.window_size
|
||||
self.stride_size = opt.stride_size
|
||||
self.q_k_mask = get_q_k(
|
||||
self.seq_len, self.window_size, self.stride_size, opt.device
|
||||
)
|
||||
self.q_k_mask = get_q_k(self.seq_len, self.window_size, self.stride_size, opt.device)
|
||||
self.k_q_mask = get_k_q(self.q_k_mask)
|
||||
|
||||
|
||||
def forward(self, hidden_states):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states
|
||||
bsz, seq_len, _ = hidden_states.size()
|
||||
bsz, seq_len, _ = hidden_states.size()
|
||||
|
||||
q = hidden_states
|
||||
if self.normalize_before:
|
||||
@@ -269,8 +246,6 @@ class GraphSelfAttention(nn.Module):
|
||||
|
||||
|
||||
"""Multi-head self attention"""
|
||||
|
||||
|
||||
class NormalSelfAttention(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super(NormalSelfAttention, self).__init__()
|
||||
@@ -295,12 +270,11 @@ class NormalSelfAttention(nn.Module):
|
||||
self.window_size = opt.window_size
|
||||
self.stride_size = opt.stride_size
|
||||
if opt.mask:
|
||||
self.mask, _ = get_mask(
|
||||
self.seq_len, self.stride_size, self.window_size, opt.device
|
||||
)
|
||||
self.mask, _ = get_mask(self.seq_len, self.stride_size, self.window_size, opt.device)
|
||||
else:
|
||||
self.mask = None
|
||||
|
||||
|
||||
def forward(self, hidden_states):
|
||||
residual = hidden_states
|
||||
|
||||
@@ -342,8 +316,6 @@ class NormalSelfAttention(nn.Module):
|
||||
|
||||
|
||||
"""Prob-sparse attention"""
|
||||
|
||||
|
||||
class ProbSparseAttention(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super(ProbSparseAttention, self).__init__()
|
||||
@@ -367,16 +339,14 @@ class ProbSparseAttention(nn.Module):
|
||||
self.seq_len = opt.seq_len
|
||||
self.factor = opt.factor
|
||||
|
||||
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
|
||||
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
|
||||
# Q [B, H, L, D]
|
||||
B, H, L_K, E = K.shape
|
||||
_, _, L_Q, _ = Q.shape
|
||||
|
||||
# calculate the sampled Q_K
|
||||
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
|
||||
index_sample = torch.randint(
|
||||
L_K, (L_Q, sample_k)
|
||||
) # real U = U_part(factor*ln(L_k))*L_q
|
||||
index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
|
||||
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
|
||||
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
|
||||
|
||||
@@ -385,10 +355,10 @@ class ProbSparseAttention(nn.Module):
|
||||
M_top = M.topk(n_top, sorted=False)[1]
|
||||
|
||||
# use the reduced Q to calculate Q_K
|
||||
Q_reduce = Q[
|
||||
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
|
||||
] # factor*ln(L_q)
|
||||
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
|
||||
Q_reduce = Q[torch.arange(B)[:, None, None],
|
||||
torch.arange(H)[None, :, None],
|
||||
M_top, :] # factor*ln(L_q)
|
||||
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
|
||||
|
||||
return Q_K, M_top
|
||||
|
||||
@@ -402,11 +372,11 @@ class ProbSparseAttention(nn.Module):
|
||||
def _update_context(self, context_in, V, scores, index, L_Q):
|
||||
B, H, L_V, D = V.shape
|
||||
|
||||
attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
|
||||
attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
|
||||
|
||||
context_in[
|
||||
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
|
||||
] = torch.matmul(attn, V).type_as(context_in)
|
||||
context_in[torch.arange(B)[:, None, None],
|
||||
torch.arange(H)[None, :, None],
|
||||
index, :] = torch.matmul(attn, V).type_as(context_in)
|
||||
return context_in
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@@ -431,23 +401,17 @@ class ProbSparseAttention(nn.Module):
|
||||
k = k.float().contiguous()
|
||||
v = v.float().contiguous()
|
||||
|
||||
u = U_part = (
|
||||
self.factor * np.ceil(np.log(seq_len)).astype("int").item()
|
||||
) # c*ln(L_k)
|
||||
u = U_part = self.factor * np.ceil(np.log(seq_len)).astype('int').item() # c*ln(L_k)
|
||||
|
||||
U_part = U_part if U_part < seq_len else seq_len
|
||||
U_part = U_part if U_part<seq_len else seq_len
|
||||
u = u if u < seq_len else seq_len
|
||||
|
||||
scores_top, index = self._prob_QK(q, k, sample_k=U_part, n_top=u)
|
||||
scores_top, index = self._prob_QK(q, k, sample_k=U_part, n_top=u)
|
||||
|
||||
# get the context
|
||||
context = self._get_initial_context(v, seq_len)
|
||||
# update the context with selected top_k queries
|
||||
context = (
|
||||
self._update_context(context, v, scores_top, index, seq_len)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
context = self._update_context(context, v, scores_top, index, seq_len).transpose(1, 2).contiguous()
|
||||
|
||||
context = context.view(bsz, seq_len, self.n_head * self.d_k)
|
||||
|
||||
@@ -461,24 +425,24 @@ class ProbSparseAttention(nn.Module):
|
||||
|
||||
|
||||
def parsing():
|
||||
parser = argparse.ArgumentParser(description="Needed for graph self attention.")
|
||||
parser.add_argument("-d_model", type=int, default=256)
|
||||
parser.add_argument("-d_k", type=int, default=64)
|
||||
parser.add_argument("-normalize_before", type=bool, default=False)
|
||||
parser.add_argument("-n_head", type=int, default=4)
|
||||
parser.add_argument("-dropout", type=float, default=0.1)
|
||||
parser = argparse.ArgumentParser(description='Needed for graph self attention.')
|
||||
parser.add_argument('-d_model', type=int, default=256)
|
||||
parser.add_argument('-d_k', type=int, default=64)
|
||||
parser.add_argument('-normalize_before', type=bool, default=False)
|
||||
parser.add_argument('-n_head', type=int, default=4)
|
||||
parser.add_argument('-dropout', type=float, default=0.1)
|
||||
|
||||
# arguments for Multiformer
|
||||
parser.add_argument("-window_size", type=int, default=3)
|
||||
parser.add_argument("-stride_size", type=int, default=25)
|
||||
parser.add_argument('-window_size', type=int, default=3)
|
||||
parser.add_argument('-stride_size', type=int, default=25)
|
||||
|
||||
# arguments for ProbSparse
|
||||
parser.add_argument("-factor", type=int, default=5)
|
||||
parser.add_argument('-factor', type=int, default=5)
|
||||
|
||||
# arguments for full-attention
|
||||
parser.add_argument("-mask", type=int, default=0)
|
||||
parser.add_argument('-mask', type=int, default=0)
|
||||
|
||||
parser.add_argument("-seq_len", type=int, default=1000)
|
||||
parser.add_argument('-seq_len', type=int, default=1000)
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
@@ -493,14 +457,12 @@ def test_NSA(args, input_len):
|
||||
NSA_Layer = NormalSelfAttention(args).to(args.device)
|
||||
optimizer = optim.Adam(NSA_Layer.parameters(), 1e-4)
|
||||
optimizer.zero_grad()
|
||||
hidden_state = torch.ones(4, input_len, args.d_model, dtype=torch.float32).to(
|
||||
args.device
|
||||
)
|
||||
hidden_state = torch.ones(4, input_len, args.d_model, dtype=torch.float32).to(args.device)
|
||||
fake_gt = torch.zeros(4, input_len, args.d_model).to(args.device)
|
||||
|
||||
# Preload the layer
|
||||
result = NSA_Layer(hidden_state)
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
@@ -511,17 +473,13 @@ def test_NSA(args, input_len):
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(1)
|
||||
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
used_memory += meminfo.used / 1024**3
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
print(
|
||||
"NSA used average time: {} s".format(
|
||||
round((time.time() - start_time) / 1000, 4)
|
||||
)
|
||||
)
|
||||
print('NSA used average time: {} s'.format(round((time.time() - start_time) / 1000, 4)))
|
||||
used_memory = used_memory / 1000
|
||||
print("NSA used average memory: {} GB".format(round(used_memory - init_mem, 4)))
|
||||
print('NSA used average memory: {} GB'.format(round(used_memory-init_mem, 4)))
|
||||
|
||||
|
||||
def test_GSA(args, input_len):
|
||||
@@ -533,14 +491,12 @@ def test_GSA(args, input_len):
|
||||
GSA_Layer = GraphSelfAttention(args).to(args.device)
|
||||
optimizer = optim.Adam(GSA_Layer.parameters(), 1e-4)
|
||||
optimizer.zero_grad()
|
||||
hidden_state = torch.ones(
|
||||
4, input_len, args.d_model, dtype=torch.float32, device=args.device
|
||||
)
|
||||
hidden_state = torch.ones(4, input_len, args.d_model, dtype=torch.float32, device=args.device)
|
||||
fake_gt = torch.zeros(4, input_len, args.d_model, device=args.device)
|
||||
|
||||
# Preload the layer
|
||||
result = GSA_Layer(hidden_state)
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
@@ -552,15 +508,13 @@ def test_GSA(args, input_len):
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(1)
|
||||
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
used_memory += meminfo.used / 1024**3
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
print(
|
||||
"GSA used time:{} s".format(round((time.time() - start_time) / repeat_times, 4))
|
||||
)
|
||||
print('GSA used time:{} s'.format(round((time.time() - start_time) / repeat_times, 4)))
|
||||
used_memory = used_memory / repeat_times
|
||||
print("GSA used average memory: {} GB".format(round(used_memory - init_mem, 4)))
|
||||
print('GSA used average memory: {} GB'.format(round(used_memory-init_mem, 4)))
|
||||
|
||||
|
||||
def test_PSA(args, input_len):
|
||||
@@ -572,14 +526,12 @@ def test_PSA(args, input_len):
|
||||
LSA_Layer = ProbSparseAttention(args).to(args.device)
|
||||
optimizer = optim.Adam(LSA_Layer.parameters(), 1e-4)
|
||||
optimizer.zero_grad()
|
||||
hidden_state = torch.ones(
|
||||
4, input_len, args.d_model, dtype=torch.float32, device=args.device
|
||||
)
|
||||
hidden_state = torch.ones(4, input_len, args.d_model, dtype=torch.float32, device=args.device)
|
||||
fake_gt = torch.zeros(4, input_len, args.d_model, device=args.device)
|
||||
|
||||
# Preload the layer
|
||||
result = LSA_Layer(hidden_state)
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
@@ -591,23 +543,21 @@ def test_PSA(args, input_len):
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(1)
|
||||
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
used_memory += meminfo.used / 1024**3
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss = ((fake_gt - result) ** 2).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
print(
|
||||
"LSA used time:{} s".format(round((time.time() - start_time) / repeat_times, 4))
|
||||
)
|
||||
print('LSA used time:{} s'.format(round((time.time() - start_time) / repeat_times, 4)))
|
||||
used_memory = used_memory / repeat_times
|
||||
print("LSA used average memory: {} GB".format(round(used_memory - init_mem, 4)))
|
||||
print('LSA used average memory: {} GB'.format(round(used_memory-init_mem, 4)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
args = parsing()
|
||||
if torch.cuda.is_available():
|
||||
args.device = torch.device("cuda")
|
||||
args.device = torch.device('cuda')
|
||||
else:
|
||||
args.device = torch.device("cpu")
|
||||
args.device = torch.device('cpu')
|
||||
|
||||
input_size = args.seq_len
|
||||
stride = args.stride_size
|
||||
@@ -617,13 +567,14 @@ if __name__ == "__main__":
|
||||
input_len = input_size + second_length + third_length + fourth_length
|
||||
|
||||
if args.mask:
|
||||
print("sequence length: {}".format(input_len))
|
||||
print('sequence length: {}'.format(input_len))
|
||||
test_NSA(args, input_len)
|
||||
else:
|
||||
print("sequence length: {}".format(input_size))
|
||||
print('sequence length: {}'.format(input_size))
|
||||
test_NSA(args, input_size)
|
||||
|
||||
print("sequence length: {}".format(input_len))
|
||||
print('sequence length: {}'.format(input_len))
|
||||
test_GSA(args, input_len)
|
||||
print("sequence length: {}".format(input_size))
|
||||
print('sequence length: {}'.format(input_size))
|
||||
test_PSA(args, input_size)
|
||||
|
||||
|
||||
@@ -8,112 +8,86 @@ Modified based on Longformer.
|
||||
}
|
||||
"""
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
from functools import lru_cache
|
||||
from typing import Union
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
|
||||
sys.path.append("pyraformer/tvm/python")
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
sys.path.append('pyraformer/tvm/python')
|
||||
|
||||
class GraphMM(torch.autograd.Function):
|
||||
"""Class to encapsulate tvm code for compiling a diagonal_mm function, in addition to calling
|
||||
'''Class to encapsulate tvm code for compiling a diagonal_mm function, in addition to calling
|
||||
this function from PyTorch
|
||||
"""
|
||||
'''
|
||||
|
||||
function_dict = (
|
||||
{}
|
||||
) # save a list of functions, each has a different set of parameters
|
||||
function_dict = {} # save a list of functions, each has a different set of parameters
|
||||
|
||||
@staticmethod
|
||||
def _compile_function(
|
||||
dtype: str, device: str, b0: int = 4, b1: int = 8, b2: int = 8
|
||||
):
|
||||
"""Compiles a tvm function that computes diagonal_mm
|
||||
def _compile_function(dtype: str, device: str, b0: int = 4, b1: int = 8, b2: int = 8):
|
||||
'''Compiles a tvm function that computes diagonal_mm
|
||||
args:
|
||||
dtype: str in ['float64', 'float32', 'float16']
|
||||
device: str in ['cpu' or 'cuda']
|
||||
b0, b1, b2: size of tensor tiles. Very important for good performance
|
||||
"""
|
||||
'''
|
||||
import tvm # import the full tvm library here for compilation. Don't import at the top of the file in case we don't need to compile
|
||||
from tvm.contrib import nvcc
|
||||
|
||||
@tvm.register_func
|
||||
def tvm_callback_cuda_compile(code):
|
||||
"""Use nvcc compiler for better perf."""
|
||||
ptx = nvcc.compile_cuda(
|
||||
code, target="ptx", arch="sm_52"
|
||||
) # use old arch for this to work on old GPUs
|
||||
ptx = nvcc.compile_cuda(code, target="ptx", arch='sm_52') # use old arch for this to work on old GPUs
|
||||
return ptx
|
||||
|
||||
assert dtype in ["float16", "float32", "float64"]
|
||||
assert device in ["cpu", "cuda"]
|
||||
device = None if device == "cpu" else device
|
||||
tgt_host = "llvm"
|
||||
assert dtype in ['float16', 'float32', 'float64']
|
||||
assert device in ['cpu', 'cuda']
|
||||
device = None if device == 'cpu' else device
|
||||
tgt_host="llvm"
|
||||
|
||||
b = tvm.te.var("b") # batch size
|
||||
n = tvm.te.var("n") # sequence length
|
||||
h = tvm.te.var("h") # number of heads
|
||||
m = tvm.te.var("m") # hidden dimension
|
||||
w = tvm.te.var("w") # window size
|
||||
padding = tvm.te.var("padding") # padding
|
||||
transpose_t1 = tvm.te.var("transpose_t1") # t1 should be transposed
|
||||
t1d3 = tvm.te.var("t1d3") # last dimension of t1
|
||||
t3d3 = tvm.te.var("t3d3") # last dimension of t3 (the result tensor)
|
||||
max_attn = tvm.te.var("max_attn")
|
||||
X = tvm.te.placeholder((b, n, h, t1d3), name="X", dtype=dtype) # first tensor
|
||||
Y = tvm.te.placeholder((b, n, h, m), name="Y", dtype=dtype) # second tensor
|
||||
k = tvm.te.reduce_axis((0, t1d3), name="k") # dimension to sum over
|
||||
q_k_mask = tvm.te.placeholder(
|
||||
(n, max_attn), name="q_k", dtype="int"
|
||||
) # dilation per head
|
||||
k_q_mask = tvm.te.placeholder((n, max_attn), name="k_q", dtype="int") #
|
||||
b = tvm.te.var('b') # batch size
|
||||
n = tvm.te.var('n') # sequence length
|
||||
h = tvm.te.var('h') # number of heads
|
||||
m = tvm.te.var('m') # hidden dimension
|
||||
w = tvm.te.var('w') # window size
|
||||
padding = tvm.te.var('padding') # padding
|
||||
transpose_t1 = tvm.te.var('transpose_t1') # t1 should be transposed
|
||||
t1d3 = tvm.te.var('t1d3') # last dimension of t1
|
||||
t3d3 = tvm.te.var('t3d3') # last dimension of t3 (the result tensor)
|
||||
max_attn = tvm.te.var('max_attn')
|
||||
X = tvm.te.placeholder((b, n, h, t1d3), name='X', dtype=dtype) # first tensor
|
||||
Y = tvm.te.placeholder((b, n, h, m), name='Y', dtype=dtype) # second tensor
|
||||
k = tvm.te.reduce_axis((0, t1d3), name='k') # dimension to sum over
|
||||
q_k_mask = tvm.te.placeholder((n, max_attn), name='q_k', dtype='int') # dilation per head
|
||||
k_q_mask = tvm.te.placeholder((n, max_attn), name='k_q', dtype='int') #
|
||||
output_shape = (b, n, h, t3d3) # shape of the result tensor
|
||||
|
||||
algorithm = lambda l, i, q, j: tvm.te.sum(
|
||||
tvm.te.if_then_else(
|
||||
t3d3
|
||||
== m, # if output dimension == m, then t1 is diagonaled (FIXME: This breaks if t3d3 == m == t1d3)
|
||||
t3d3 == m, # if output dimension == m, then t1 is diagonaled (FIXME: This breaks if t3d3 == m == t1d3)
|
||||
tvm.te.if_then_else(
|
||||
transpose_t1 == 0,
|
||||
tvm.te.if_then_else(
|
||||
q_k_mask[i, k] >= 0,
|
||||
q_k_mask[i, k]>=0,
|
||||
X[l, i, q, k] * Y[l, q_k_mask[i, k], q, j], # t1 is diagonaled
|
||||
padding,
|
||||
padding
|
||||
),
|
||||
tvm.te.if_then_else(
|
||||
q_k_mask[i, k] >= 0,
|
||||
X[l, q_k_mask[i, k], q, k_q_mask[i, k]]
|
||||
* Y[
|
||||
l, q_k_mask[i, k], q, j
|
||||
], # # t1 is diagonaled and should be transposed
|
||||
padding,
|
||||
q_k_mask[i, k]>=0,
|
||||
X[l, q_k_mask[i, k], q, k_q_mask[i, k]] * Y[l, q_k_mask[i, k], q, j], # # t1 is diagonaled and should be transposed
|
||||
padding
|
||||
),
|
||||
),
|
||||
tvm.te.if_then_else(
|
||||
q_k_mask[i, j] >= 0,
|
||||
X[l, i, q, k]
|
||||
* Y[
|
||||
l, q_k_mask[i, j], q, k
|
||||
], # t1 is not diagonaled, but the output tensor is going to be
|
||||
padding,
|
||||
),
|
||||
),
|
||||
axis=k,
|
||||
)
|
||||
q_k_mask[i, j]>=0,
|
||||
X[l, i, q, k] * Y[l, q_k_mask[i, j], q, k], # t1 is not diagonaled, but the output tensor is going to be
|
||||
padding
|
||||
)
|
||||
), axis=k)
|
||||
|
||||
Z = tvm.te.compute(
|
||||
output_shape, algorithm, name="Z"
|
||||
) # automatically generate cuda code
|
||||
Z = tvm.te.compute(output_shape, algorithm, name='Z') # automatically generate cuda code
|
||||
s = tvm.te.create_schedule(Z.op)
|
||||
|
||||
print(
|
||||
"Lowering: \n ===================== \n{}".format(
|
||||
tvm.lower(s, [X, Y, q_k_mask, k_q_mask], simple_mode=True)
|
||||
)
|
||||
)
|
||||
print('Lowering: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, q_k_mask, k_q_mask], simple_mode=True)))
|
||||
|
||||
# split long axis into smaller chunks and assing each one to a separate GPU thread/block
|
||||
ko, ki = s[Z].split(Z.op.reduce_axis[0], factor=b0)
|
||||
@@ -133,31 +107,21 @@ class GraphMM(torch.autograd.Function):
|
||||
s[ZF].compute_at(s[Z], s[Z].op.reduce_axis[0])
|
||||
s[Z].set_store_predicate(tx.var.equal(0))
|
||||
|
||||
print(
|
||||
"Lowering with GPU splits: \n ===================== \n{}".format(
|
||||
tvm.lower(s, [X, Y, q_k_mask, k_q_mask], simple_mode=True)
|
||||
)
|
||||
)
|
||||
print('Lowering with GPU splits: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, q_k_mask, k_q_mask], simple_mode=True)))
|
||||
|
||||
# compiling the automatically generated cuda code
|
||||
graph_mm = tvm.build(
|
||||
s,
|
||||
[X, Y, Z, q_k_mask, k_q_mask, max_attn, padding, transpose_t1, t3d3],
|
||||
target=device,
|
||||
target_host=tgt_host,
|
||||
name="graph_mm",
|
||||
)
|
||||
graph_mm = tvm.build(s, [X, Y, Z, q_k_mask, k_q_mask, max_attn, padding, transpose_t1, t3d3], target=device, target_host=tgt_host, name='graph_mm')
|
||||
return graph_mm
|
||||
|
||||
@staticmethod
|
||||
def _get_lib_filename(dtype: str, device: str):
|
||||
base_filename = "lib/lib_hierarchical_mm"
|
||||
return "{}_{}_{}.so".format(base_filename, dtype, device)
|
||||
base_filename = 'lib/lib_hierarchical_mm'
|
||||
return '{}_{}_{}.so'.format(base_filename, dtype, device)
|
||||
|
||||
@staticmethod
|
||||
def _save_compiled_function(f, dtype: str, device: str):
|
||||
if not os.path.exists("lib/"):
|
||||
os.makedirs("lib/")
|
||||
if not os.path.exists('lib/'):
|
||||
os.makedirs('lib/')
|
||||
f.export_library(GraphMM._get_lib_filename(dtype, device))
|
||||
|
||||
@staticmethod
|
||||
@@ -167,63 +131,43 @@ class GraphMM(torch.autograd.Function):
|
||||
|
||||
filename = GraphMM._get_lib_filename(dtype, device)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
potential_dirs = [
|
||||
"../../",
|
||||
"../",
|
||||
"./",
|
||||
f"{current_dir}/",
|
||||
f"{current_dir}/../",
|
||||
]
|
||||
for potential_dir in potential_dirs:
|
||||
filepath = "{}{}".format(potential_dir, filename)
|
||||
potential_dirs = ['../../', '../', './', f'{current_dir}/', f'{current_dir}/../']
|
||||
for potential_dir in potential_dirs:
|
||||
filepath = '{}{}'.format(potential_dir, filename)
|
||||
if os.path.isfile(filepath):
|
||||
print("Loading tvm binary from: {}".format(filepath))
|
||||
print('Loading tvm binary from: {}'.format(filepath))
|
||||
return load(filepath)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_function(dtype: str, device: str):
|
||||
"""Loads the function from the disk or compile it"""
|
||||
'''Loads the function from the disk or compile it'''
|
||||
# A list of arguments that define the function
|
||||
args = (dtype, device)
|
||||
if args not in GraphMM.function_dict:
|
||||
graph_mm = GraphMM._load_compiled_function(
|
||||
dtype, device
|
||||
) # try to load from disk
|
||||
graph_mm = GraphMM._load_compiled_function(dtype, device) # try to load from disk
|
||||
if not graph_mm:
|
||||
print("Tvm binary not found. Compiling ...")
|
||||
print('Tvm binary not found. Compiling ...')
|
||||
graph_mm = GraphMM._compile_function(dtype, device) # compile
|
||||
GraphMM._save_compiled_function(graph_mm, dtype, device) # save to disk
|
||||
# convert the tvm function into a pytorch function
|
||||
from tvm.contrib import dlpack
|
||||
|
||||
graph_mm_pytorch = dlpack.to_pytorch_func(
|
||||
graph_mm
|
||||
) # wrap it as a pytorch function
|
||||
graph_mm_pytorch = dlpack.to_pytorch_func(graph_mm) # wrap it as a pytorch function
|
||||
# save the function into a dictionary to be reused
|
||||
GraphMM.function_dict[
|
||||
args
|
||||
] = graph_mm_pytorch # save it in a dictionary for next time
|
||||
GraphMM.function_dict[args] = graph_mm_pytorch # save it in a dictionary for next time
|
||||
return GraphMM.function_dict[args]
|
||||
|
||||
@staticmethod
|
||||
def _graph_mm(
|
||||
t1: torch.Tensor,
|
||||
t2: torch.Tensor,
|
||||
q_k_mask: torch.Tensor,
|
||||
k_q_mask: torch.Tensor,
|
||||
is_t1_diagonaled: bool = False,
|
||||
transpose_t1: bool = False,
|
||||
padding: int = 0,
|
||||
autoregressive: bool = False,
|
||||
):
|
||||
"""Calls the compiled function after checking the input format. This function is called in three different modes.
|
||||
def _graph_mm(t1: torch.Tensor, t2: torch.Tensor, q_k_mask: torch.Tensor, k_q_mask: torch.Tensor,
|
||||
is_t1_diagonaled: bool = False, transpose_t1: bool = False, padding: int = 0,
|
||||
autoregressive: bool = False):
|
||||
'''Calls the compiled function after checking the input format. This function is called in three different modes.
|
||||
t1 x t2 = r ==> t1 and t2 are not diagonaled, but r is. Useful for query x key = attention_scores
|
||||
t1 x t2 = r ==> t1 is diagonaled, but t2 and r are not. Useful to compute attention_scores x value = context
|
||||
t1 x t2 = r ==> t1 is diagonaled and it should be transposed, but t2 and r are not diagonaled. Useful in some of
|
||||
the calculations in the backward pass.
|
||||
"""
|
||||
dtype = str(t1.dtype).split(".")[1]
|
||||
'''
|
||||
dtype = str(t1.dtype).split('.')[1]
|
||||
device = t1.device.type
|
||||
assert len(t1.shape) == 4
|
||||
assert len(t1.shape) == len(t2.shape)
|
||||
@@ -252,26 +196,14 @@ class GraphMM(torch.autograd.Function):
|
||||
# This function computes diagonal_mm then saves the result in `r`
|
||||
if m == max_attn:
|
||||
# FIXME
|
||||
print(
|
||||
"Error: the hidden dimension {m} shouldn't match number of diagonals {c}"
|
||||
)
|
||||
print('Error: the hidden dimension {m} shouldn\'t match number of diagonals {c}')
|
||||
assert False
|
||||
_graph_mm_function(
|
||||
t1,
|
||||
t2,
|
||||
r,
|
||||
q_k_mask,
|
||||
k_q_mask,
|
||||
max_attn,
|
||||
padding,
|
||||
transpose_t1,
|
||||
m if is_t1_diagonaled else max_attn,
|
||||
)
|
||||
_graph_mm_function(t1, t2, r, q_k_mask, k_q_mask, max_attn, padding, transpose_t1, m if is_t1_diagonaled else max_attn)
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
def _prepare_tensors(t):
|
||||
"""Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch.
|
||||
'''Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch.
|
||||
For a tensor t, if t.size(0) == 1, then the value of t.stride()[0] doesn't matter.
|
||||
TVM expects this value to be the `product(t.size()[1:])` but PyTorch sometimes sets it to `t.stride()[1]`.
|
||||
Here's an example to reproduce this issue:
|
||||
@@ -282,7 +214,7 @@ class GraphMM(torch.autograd.Function):
|
||||
> (1, 1) # expected it to be (10, 1) as above
|
||||
print(torch.randn(10, 2).t().contiguous().stride())
|
||||
> (10, 1) # but gets the expected stride if the first dimension is > 1
|
||||
"""
|
||||
'''
|
||||
assert t.is_contiguous()
|
||||
t_stride = list(t.stride())
|
||||
t_size = list(t.size())
|
||||
@@ -297,17 +229,9 @@ class GraphMM(torch.autograd.Function):
|
||||
min_seq_len = 16 # unexpected output if seq_len < 16
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
t1: torch.Tensor,
|
||||
t2: torch.Tensor,
|
||||
q_k_mask,
|
||||
k_q_mask,
|
||||
is_t1_diagonaled: bool = False,
|
||||
padding: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Compuates diagonal_mm of t1 and t2.
|
||||
args:
|
||||
def forward(ctx, t1: torch.Tensor, t2: torch.Tensor, q_k_mask, k_q_mask, is_t1_diagonaled: bool = False, padding: int = 0) -> torch.Tensor:
|
||||
'''Computes diagonal_mm of t1 and t2.
|
||||
args:
|
||||
t1: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals).
|
||||
t1 can be a regular tensor (e.g. `query_layer`) or a diagonaled one (e.g. `attention_scores`)
|
||||
t2: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size). This is always a non-diagonaled
|
||||
@@ -322,13 +246,9 @@ class GraphMM(torch.autograd.Function):
|
||||
autoregressive: if true, return only the lower triangle
|
||||
returns: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals)
|
||||
if t1 is diagonaled, result is non-diagonaled, and vice versa
|
||||
"""
|
||||
'''
|
||||
seq_len = t1.size(1)
|
||||
assert (
|
||||
seq_len >= GraphMM.min_seq_len
|
||||
), "avoid splitting errors by using seq_len >= {}".format(
|
||||
GraphMM.min_seq_len
|
||||
) # FIXME
|
||||
assert seq_len >= GraphMM.min_seq_len, 'avoid splitting errors by using seq_len >= {}'.format(GraphMM.min_seq_len) # FIXME
|
||||
|
||||
t1 = GraphMM._prepare_tensors(t1)
|
||||
t2 = GraphMM._prepare_tensors(t2)
|
||||
@@ -337,14 +257,7 @@ class GraphMM(torch.autograd.Function):
|
||||
ctx.save_for_backward(t1, t2, q_k_mask, k_q_mask)
|
||||
ctx.is_t1_diagonaled = is_t1_diagonaled
|
||||
# output = t1.mm(t2) # what would have been called if this was a regular matmul
|
||||
output = GraphMM._graph_mm(
|
||||
t1,
|
||||
t2,
|
||||
q_k_mask,
|
||||
k_q_mask,
|
||||
is_t1_diagonaled=is_t1_diagonaled,
|
||||
padding=padding,
|
||||
)
|
||||
output = GraphMM._graph_mm(t1, t2, q_k_mask, k_q_mask, is_t1_diagonaled=is_t1_diagonaled, padding=padding)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@@ -352,35 +265,17 @@ class GraphMM(torch.autograd.Function):
|
||||
t1, t2, q_k_mask, k_q_mask = ctx.saved_tensors
|
||||
is_t1_diagonaled = ctx.is_t1_diagonaled
|
||||
if not grad_output.is_contiguous():
|
||||
grad_output = (
|
||||
grad_output.contiguous()
|
||||
) # tvm requires all input tensors to be contiguous
|
||||
grad_output = grad_output.contiguous() # tvm requires all input tensors to be contiguous
|
||||
grad_output = GraphMM._prepare_tensors(grad_output)
|
||||
# http://cs231n.github.io/optimization-2/
|
||||
# https://pytorch.org/docs/master/notes/extending.html
|
||||
# grad_t1 = grad_output.mm(t2) # what would have been called if this was a regular matmul
|
||||
grad_t1 = GraphMM._graph_mm(
|
||||
grad_output, t2, q_k_mask, k_q_mask, is_t1_diagonaled=not is_t1_diagonaled
|
||||
)
|
||||
grad_t1 = GraphMM._graph_mm(grad_output, t2, q_k_mask, k_q_mask, is_t1_diagonaled=not is_t1_diagonaled)
|
||||
# grad_t2 = grad_output.t().mm(t1) # or `grad_t2 = t1.t().mm(grad_output).t()` because `(AB)^T = B^TA^T`
|
||||
if is_t1_diagonaled:
|
||||
grad_t2 = GraphMM._graph_mm(
|
||||
t1,
|
||||
grad_output,
|
||||
q_k_mask,
|
||||
k_q_mask,
|
||||
is_t1_diagonaled=True,
|
||||
transpose_t1=True,
|
||||
)
|
||||
grad_t2 = GraphMM._graph_mm(t1, grad_output, q_k_mask, k_q_mask, is_t1_diagonaled=True, transpose_t1=True)
|
||||
else:
|
||||
grad_t2 = GraphMM._graph_mm(
|
||||
grad_output,
|
||||
t1,
|
||||
q_k_mask,
|
||||
k_q_mask,
|
||||
is_t1_diagonaled=True,
|
||||
transpose_t1=True,
|
||||
)
|
||||
grad_t2 = GraphMM._graph_mm(grad_output, t1, q_k_mask, k_q_mask, is_t1_diagonaled=True, transpose_t1=True)
|
||||
return grad_t1, grad_t2, None, None, None, None, None
|
||||
|
||||
|
||||
|
||||
+18
-43
@@ -1,28 +1,23 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.modules import loss
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def MAE(pred, true):
|
||||
return np.mean(np.abs(pred - true))
|
||||
|
||||
return np.mean(np.abs(pred-true))
|
||||
|
||||
def MSE(pred, true):
|
||||
return np.mean((pred - true) ** 2)
|
||||
|
||||
return np.mean((pred-true)**2)
|
||||
|
||||
def RMSE(pred, true):
|
||||
return np.sqrt(MSE(pred, true))
|
||||
|
||||
|
||||
def MAPE(pred, true):
|
||||
return np.mean(np.abs((pred - true) / true))
|
||||
|
||||
|
||||
def MSPE(pred, true):
|
||||
return np.mean(np.square((pred - true) / true))
|
||||
|
||||
|
||||
def metric(pred, true):
|
||||
mae = MAE(pred, true)
|
||||
mse = MSE(pred, true)
|
||||
@@ -30,50 +25,32 @@ def metric(pred, true):
|
||||
mape = MAPE(pred, true)
|
||||
mspe = MSPE(pred, true)
|
||||
|
||||
return mae, mse, rmse, mape, mspe
|
||||
return mae,mse,rmse,mape,mspe
|
||||
|
||||
|
||||
class StandardScaler:
|
||||
class StandardScaler():
|
||||
def __init__(self):
|
||||
self.mean = 0.0
|
||||
self.std = 1.0
|
||||
self.mean = 0.
|
||||
self.std = 1.
|
||||
|
||||
def fit(self, data):
|
||||
self.mean = data.mean(0)
|
||||
self.std = data.std(0)
|
||||
|
||||
def transform(self, data):
|
||||
mean = (
|
||||
torch.from_numpy(self.mean).type_as(data).to(data.device)
|
||||
if torch.is_tensor(data)
|
||||
else self.mean
|
||||
)
|
||||
std = (
|
||||
torch.from_numpy(self.std).type_as(data).to(data.device)
|
||||
if torch.is_tensor(data)
|
||||
else self.std
|
||||
)
|
||||
mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean
|
||||
std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std
|
||||
return (data - mean) / std
|
||||
|
||||
def inverse_transform(self, data):
|
||||
mean = (
|
||||
torch.from_numpy(self.mean).type_as(data).to(data.device)
|
||||
if torch.is_tensor(data)
|
||||
else self.mean
|
||||
)
|
||||
std = (
|
||||
torch.from_numpy(self.std).type_as(data).to(data.device)
|
||||
if torch.is_tensor(data)
|
||||
else self.std
|
||||
)
|
||||
mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean
|
||||
std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std
|
||||
return (data * std) + mean
|
||||
|
||||
|
||||
class TopkMSELoss(torch.nn.Module):
|
||||
def __init__(self, topk) -> None:
|
||||
super().__init__()
|
||||
self.topk = topk
|
||||
self.criterion = torch.nn.MSELoss(reduction="none")
|
||||
self.criterion = torch.nn.MSELoss(reduction='none')
|
||||
|
||||
def forward(self, output, label):
|
||||
losses = self.criterion(output, label).mean(2).mean(1)
|
||||
@@ -81,9 +58,8 @@ class TopkMSELoss(torch.nn.Module):
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
class SingleStepLoss(torch.nn.Module):
|
||||
"""Compute top-k log-likelihood and mse."""
|
||||
""" Compute top-k log-likelihood and mse. """
|
||||
|
||||
def __init__(self, ignore_zero):
|
||||
super().__init__()
|
||||
@@ -91,9 +67,9 @@ class SingleStepLoss(torch.nn.Module):
|
||||
|
||||
def forward(self, mu, sigma, labels, topk=0):
|
||||
if self.ignore_zero:
|
||||
indexes = labels != 0
|
||||
indexes = (labels != 0)
|
||||
else:
|
||||
indexes = labels >= 0
|
||||
indexes = (labels >= 0)
|
||||
|
||||
distribution = torch.distributions.normal.Normal(mu[indexes], sigma[indexes])
|
||||
likelihood = -distribution.log_prob(labels[indexes])
|
||||
@@ -107,12 +83,11 @@ class SingleStepLoss(torch.nn.Module):
|
||||
|
||||
return likelihood, se
|
||||
|
||||
|
||||
def AE_loss(mu, labels, ignore_zero):
|
||||
if ignore_zero:
|
||||
indexes = labels != 0
|
||||
indexes = (labels != 0)
|
||||
else:
|
||||
indexes = labels >= 0
|
||||
indexes = (labels >= 0)
|
||||
|
||||
ae = torch.abs(labels[indexes] - mu[indexes])
|
||||
return ae
|
||||
|
||||
Reference in New Issue
Block a user