mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:46:28 +08:00
cleaned up(still not working)
This commit is contained in:
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
+89
-33
@@ -35,6 +35,7 @@ from module import PyraformerLRModel
|
||||
from torch.utils.data import DataLoader
|
||||
from tools import SingleStepLoss as LossFactory
|
||||
from torch.utils.data.sampler import RandomSampler
|
||||
|
||||
PREDICTION_INPUT_NAMES = [
|
||||
"feat_static_cat",
|
||||
"feat_static_real",
|
||||
@@ -56,30 +57,33 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
self,
|
||||
freq: str,
|
||||
prediction_length: int,
|
||||
#Train parameters
|
||||
# Train parameters
|
||||
inner_batch: int = 8,
|
||||
lr: float = 1e-5,
|
||||
visualize_fre: int = 2000,
|
||||
pretrain: bool = True,
|
||||
hard_sample_mining:bool=True,
|
||||
hard_sample_mining: bool = True,
|
||||
covariate_size: int = 3,
|
||||
|
||||
# Model parameters
|
||||
num_seq: int = 370,#
|
||||
decoder: str = 'FC',# selection: [FC, attention]
|
||||
# Model parameters
|
||||
num_seq: int = 370, #
|
||||
decoder: str = "FC", # selection: [FC, attention]
|
||||
context_length: Optional[int] = None,
|
||||
input_size: int = 1,
|
||||
dropout: float = 0.1,
|
||||
d_model: int = 512,
|
||||
d_inner_hid: int = 512,
|
||||
d_k: int = 128,
|
||||
d_v:int = 128,
|
||||
d_v: int = 128,
|
||||
d_bottleneck: int = 128,
|
||||
num_heads: int = 4,
|
||||
n_layer: int = 4,
|
||||
enc_in: int = 1, # depends on dataset used
|
||||
CSCM: str = "Bottleneck_Construct", # [Bottleneck_Construct, Conv_Construct, MaxPooling_Construct, AvgPooling_Construct]
|
||||
embed_type: str = "CustomEmbedding", #[DataEmbedding, CustomEmbedding]
|
||||
truncate: bool = False,
|
||||
# loss: DistributionLoss = LossFactory,
|
||||
ignore_zero: bool = True,
|
||||
single_step: bool = True,#if False, Multistep=True
|
||||
|
||||
single_step: bool = True, # if False, Multistep=True
|
||||
inner_size: int = 3,
|
||||
use_tvm: bool = False,
|
||||
num_feat_dynamic_real: int = 0,
|
||||
@@ -98,7 +102,7 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
trainer_kwargs: Optional[Dict[str, Any]] = dict(),
|
||||
train_sampler: Optional[InstanceSampler] = None,
|
||||
validation_sampler: Optional[InstanceSampler] = None,
|
||||
window_size: int = [4, 4, 4]
|
||||
window_size: int = [4, 4, 4],
|
||||
) -> None:
|
||||
trainer_kwargs = {
|
||||
"max_epochs": 10,
|
||||
@@ -121,15 +125,25 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
self.d_inner_hid = d_inner_hid
|
||||
self.d_k = d_k
|
||||
self.d_v = d_v
|
||||
self.d_bottleneck = d_bottleneck
|
||||
self.num_heads = num_heads
|
||||
self.n_layer = n_layer
|
||||
self.single_step = single_step
|
||||
self.ignore_zero = ignore_zero
|
||||
self.loss = LossFactory(self.ignore_zero) if self.single_step==True else torch.nn.MSELoss(reduction='none')
|
||||
self.decoder = decoder
|
||||
self.enc_in = enc_in
|
||||
self.CSCM = CSCM
|
||||
self.embed_type = embed_type
|
||||
self.truncate = truncate
|
||||
self.loss = (
|
||||
LossFactory(self.ignore_zero)
|
||||
if self.single_step == True
|
||||
else torch.nn.MSELoss(reduction="none")
|
||||
)
|
||||
self.batch_size = batch_size
|
||||
self.distr_output = distr_output
|
||||
|
||||
self.window_size = window_size#[4,4,4]#window_size
|
||||
self.window_size = window_size # [4,4,4]#window_size
|
||||
self.inner_size = inner_size
|
||||
self.use_tvm = use_tvm
|
||||
self.prediction_length = prediction_length
|
||||
@@ -319,25 +333,67 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
def create_lightning_module(self) -> PyraformerLightningModule:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if self.single_step:
|
||||
model = PyraformerSSModel(freq= self.freq, covariate_size = self.covariate_size,
|
||||
num_seq=self.num_seq, input_size = self.input_size, dropout = self.dropout, d_model = self.d_model,
|
||||
d_inner_hid = self.d_inner_hid, d_k = self.d_k, d_v = self.d_v,
|
||||
num_heads = self.num_heads, n_layer = self.n_layer, loss = self.loss,
|
||||
window_size = self.window_size, inner_size = self.inner_size,
|
||||
use_tvm = self.use_tvm, prediction_length = self.prediction_length,context_length = self.context_length, lags_seq = self.lags_seq, num_feat_dynamic_real= self.num_feat_dynamic_real,
|
||||
num_feat_static_cat = self.num_feat_static_cat,
|
||||
num_feat_static_real = self.num_feat_static_real,
|
||||
cardinality = self.cardinality,
|
||||
embedding_dimension = self.embedding_dimension,
|
||||
distr_output=self.distr_output,
|
||||
scaling=self.scaling,num_parallel_samples=self.num_parallel_samples, device=device)
|
||||
# else:
|
||||
# model = PyraformerLRModel(freq= self.freq, covariate_size = self.covariate_size,
|
||||
# num_seq=self.num_seq, input_size = self.input_size, dropout = self.dropout, d_model = self.d_model,
|
||||
# d_inner_hid = self.d_inner_hid, d_k = self.d_k, d_v = self.d_v,
|
||||
# num_heads = self.num_heads, n_layer = self.n_layer, loss = self.loss,
|
||||
# window_size = self.window_size, inner_size = self.inner_size,
|
||||
# use_tvm = self.use_tvm, prediction_length = self.prediction_length,context_length = self.context_length, lags_seq = self.lags_seq, device=device)
|
||||
model = PyraformerSSModel(
|
||||
freq=self.freq,
|
||||
covariate_size=self.covariate_size,
|
||||
num_seq=self.num_seq,
|
||||
input_size=self.input_size,
|
||||
dropout=self.dropout,
|
||||
d_model = self.d_model,
|
||||
d_inner_hid=self.d_inner_hid,
|
||||
d_k=self.d_k,
|
||||
d_v=self.d_v,
|
||||
num_heads=self.num_heads,
|
||||
n_layer=self.n_layer,
|
||||
loss=self.loss,
|
||||
window_size=self.window_size,
|
||||
inner_size=self.inner_size,
|
||||
use_tvm=self.use_tvm,
|
||||
prediction_length=self.prediction_length,
|
||||
context_length=self.context_length,
|
||||
lags_seq=self.lags_seq,
|
||||
num_feat_dynamic_real=self.num_feat_dynamic_real,
|
||||
num_feat_static_cat=self.num_feat_static_cat,
|
||||
num_feat_static_real=self.num_feat_static_real,
|
||||
cardinality=self.cardinality,
|
||||
embedding_dimension=self.embedding_dimension,
|
||||
distr_output=self.distr_output,
|
||||
scaling=self.scaling,
|
||||
num_parallel_samples=self.num_parallel_samples,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
model = PyraformerLRModel(
|
||||
predict_step=self.prediction_length,
|
||||
d_model=self.d_model,
|
||||
input_size=self.input_size,
|
||||
decoder=self.decoder,
|
||||
window_size=self.window_size,
|
||||
truncate=self.truncate,
|
||||
d_inner_hid=self.d_inner_hid,
|
||||
d_k=self.d_k,
|
||||
d_v=self.d_v,
|
||||
dropout=self.dropout,
|
||||
enc_in=self.enc_in,
|
||||
covariate_size=self.covariate_size,
|
||||
seq_num=self.num_seq,
|
||||
CSCM=self.CSCM,
|
||||
d_bottleneck=self.d_bottleneck,
|
||||
num_head=self.num_heads,
|
||||
n_layer=self.n_layer,
|
||||
inner_size=self.inner_size,
|
||||
use_tvm=self.use_tvm,
|
||||
prediction_length=self.prediction_length,
|
||||
context_length=self.context_length,
|
||||
lags_seq=self.lags_seq,
|
||||
num_feat_dynamic_real=self.num_feat_dynamic_real,
|
||||
num_feat_static_cat=self.num_feat_static_cat,
|
||||
num_feat_static_real=self.num_feat_static_real,
|
||||
cardinality=self.cardinality,
|
||||
embedding_dimension=self.embedding_dimension,
|
||||
num_parallel_samples=self.num_parallel_samples,
|
||||
embed_type = self.embed_type,
|
||||
distr_output= self.distr_output,
|
||||
device=device,
|
||||
)
|
||||
return PyraformerLightningModule(model=model, loss=self.loss)
|
||||
|
||||
|
||||
|
||||
@@ -6,77 +6,80 @@ from module import PyraformerSSModel
|
||||
from module import PyraformerLRModel
|
||||
from tools import SingleStepLoss as LossFactory
|
||||
from tools import AE_loss
|
||||
|
||||
# from module import PyraformerModel
|
||||
|
||||
|
||||
class PyraformerLightningModule(pl.LightningModule):
|
||||
def __init__(self, model: PyraformerSSModel, loss: DistributionLoss = LossFactory, lr: float = 1e-5, weight_decay: float = 1e-8,) -> None:
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lr = lr
|
||||
self.weight_decay = weight_decay
|
||||
def __init__(
|
||||
self,
|
||||
model: PyraformerSSModel,
|
||||
loss: DistributionLoss = LossFactory,
|
||||
lr: float = 1e-5,
|
||||
weight_decay: float = 1e-8,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lr = lr
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
def training_step(self, batch, batch_idx: int):
|
||||
def training_step(self, batch, batch_idx: int):
|
||||
|
||||
"""Execute training step"""
|
||||
train_loss = self(batch)
|
||||
self.log(
|
||||
"train_loss",
|
||||
train_loss,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
prog_bar=True,
|
||||
)
|
||||
return train_loss
|
||||
"""Execute training step"""
|
||||
train_loss = self(batch)
|
||||
self.log(
|
||||
"train_loss",
|
||||
train_loss,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
prog_bar=True,
|
||||
)
|
||||
return train_loss
|
||||
|
||||
def validation_step(self, batch, batch_idx: int):
|
||||
"""Execute validation step"""
|
||||
with torch.inference_mode():
|
||||
val_loss = self(batch)
|
||||
self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
|
||||
return val_loss
|
||||
def validation_step(self, batch, batch_idx: int):
|
||||
"""Execute validation step"""
|
||||
with torch.inference_mode():
|
||||
val_loss = self(batch)
|
||||
self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
|
||||
return val_loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Returns the optimizer to use"""
|
||||
return torch.optim.Adam(
|
||||
self.model.parameters(),
|
||||
lr=self.lr,
|
||||
weight_decay=self.weight_decay,
|
||||
)
|
||||
def configure_optimizers(self):
|
||||
"""Returns the optimizer to use"""
|
||||
return torch.optim.Adam(
|
||||
self.model.parameters(),
|
||||
lr=self.lr,
|
||||
weight_decay=self.weight_decay,
|
||||
)
|
||||
|
||||
def forward(self, batch):
|
||||
feat_static_cat = batch["feat_static_cat"]
|
||||
feat_static_real = batch["feat_static_real"]
|
||||
past_time_feat = batch["past_time_feat"]
|
||||
past_target = batch["past_target"]
|
||||
future_time_feat = batch["future_time_feat"]
|
||||
future_target = batch["future_target"]
|
||||
past_observed_values = batch["past_observed_values"]
|
||||
future_observed_values = batch["future_observed_values"]
|
||||
|
||||
Pyraformer_inputs, scale, _ = self.model.create_network_inputs(
|
||||
feat_static_cat,
|
||||
feat_static_real,
|
||||
past_time_feat,
|
||||
past_target,
|
||||
past_observed_values,
|
||||
future_time_feat,
|
||||
future_target,
|
||||
)
|
||||
params = self.model.output_params(Pyraformer_inputs)
|
||||
distr = self.model.output_distribution(params, scale)
|
||||
|
||||
loss_values = self.loss(distr, future_target)
|
||||
|
||||
if len(self.model.target_shape) == 0:
|
||||
loss_weights = future_observed_values
|
||||
else:
|
||||
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
|
||||
|
||||
return weighted_average(loss_values, weights=loss_weights)
|
||||
def forward(self, batch):
|
||||
feat_static_cat = batch["feat_static_cat"]
|
||||
feat_static_real = batch["feat_static_real"]
|
||||
past_time_feat = batch["past_time_feat"]
|
||||
past_target = batch["past_target"]
|
||||
future_time_feat = batch["future_time_feat"]
|
||||
future_target = batch["future_target"]
|
||||
past_observed_values = batch["past_observed_values"]
|
||||
future_observed_values = batch["future_observed_values"]
|
||||
|
||||
Pyraformer_inputs, scale, _ = self.model.create_network_inputs(
|
||||
feat_static_cat,
|
||||
feat_static_real,
|
||||
past_time_feat,
|
||||
past_target,
|
||||
past_observed_values,
|
||||
future_time_feat,
|
||||
future_target,
|
||||
)
|
||||
params = self.model.output_params(Pyraformer_inputs)
|
||||
distr = self.model.output_distribution(params, scale)
|
||||
|
||||
loss_values = self.loss(distr, future_target)
|
||||
|
||||
if len(self.model.target_shape) == 0:
|
||||
loss_weights = future_observed_values
|
||||
else:
|
||||
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
|
||||
|
||||
return weighted_average(loss_values, weights=loss_weights)
|
||||
|
||||
+457
-81
@@ -8,19 +8,43 @@ from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
|
||||
from pyraformer.Layers import EncoderLayer, Predictor, Decoder
|
||||
from pyraformer.Layers import Bottleneck_Construct, Conv_Construct, MaxPooling_Construct, AvgPooling_Construct
|
||||
from pyraformer.Layers import get_mask, refer_points, get_k_q, get_q_k, get_subsequent_mask
|
||||
from pyraformer.Layers import (
|
||||
Bottleneck_Construct,
|
||||
Conv_Construct,
|
||||
MaxPooling_Construct,
|
||||
AvgPooling_Construct,
|
||||
)
|
||||
from pyraformer.Layers import (
|
||||
get_mask,
|
||||
refer_points,
|
||||
get_k_q,
|
||||
get_q_k,
|
||||
get_subsequent_mask,
|
||||
)
|
||||
from pyraformer.embed import SingleStepEmbedding, DataEmbedding, CustomEmbedding
|
||||
|
||||
|
||||
class EncoderSS(nn.Module):
|
||||
@validated()
|
||||
def __init__(self, covariate_size,
|
||||
num_seq, input_size ,dropout , d_model,
|
||||
d_inner_hid, d_k, d_v ,
|
||||
num_heads , n_layer, loss ,
|
||||
window_size , inner_size,
|
||||
use_tvm, prediction_length, device):
|
||||
def __init__(
|
||||
self,
|
||||
covariate_size,
|
||||
num_seq,
|
||||
input_size,
|
||||
dropout,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
d_k,
|
||||
d_v,
|
||||
num_heads,
|
||||
n_layer,
|
||||
loss,
|
||||
window_size,
|
||||
inner_size,
|
||||
use_tvm,
|
||||
prediction_length,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.d_model = d_model
|
||||
@@ -30,21 +54,49 @@ class EncoderSS(nn.Module):
|
||||
self.indexes = refer_points(self.all_size, window_size, device)
|
||||
|
||||
if use_tvm:
|
||||
|
||||
assert len(set(self.window_size)) == 1, "Only constant window size is supported."
|
||||
|
||||
assert (
|
||||
len(set(self.window_size)) == 1
|
||||
), "Only constant window size is supported."
|
||||
q_k_mask = get_q_k(input_size, inner_size, window_size[0], device)
|
||||
k_q_mask = get_k_q(q_k_mask)
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(d_model, d_inner_hid, num_heads, d_k, d_v, dropout=dropout, \
|
||||
normalize_before=False, use_tvm=True, q_k_mask=q_k_mask, k_q_mask=k_q_mask) for i in range(n_layer)
|
||||
])
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
EncoderLayer(
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_heads,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=False,
|
||||
use_tvm=True,
|
||||
q_k_mask=q_k_mask,
|
||||
k_q_mask=k_q_mask,
|
||||
)
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(d_model, d_inner_hid, num_heads, d_k, d_v, dropout=dropout, \
|
||||
normalize_before=False) for i in range(n_layer)
|
||||
])
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
EncoderLayer(
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_heads,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=False,
|
||||
)
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
|
||||
self.embedding = SingleStepEmbedding(
|
||||
covariate_size, num_seq, d_model, input_size, device
|
||||
)
|
||||
|
||||
self.embedding = SingleStepEmbedding(covariate_size, num_seq, d_model, input_size, device)
|
||||
self.conv_layers = Bottleneck_Construct(d_model, window_size, d_k)
|
||||
|
||||
def forward(self, sequence):
|
||||
@@ -57,21 +109,38 @@ class EncoderSS(nn.Module):
|
||||
for i in range(len(self.layers)):
|
||||
seq_enc, _ = self.layers[i](seq_enc, mask)
|
||||
|
||||
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
|
||||
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(
|
||||
seq_enc.device
|
||||
)
|
||||
indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
|
||||
all_enc = torch.gather(seq_enc, 1, indexes)
|
||||
all_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
|
||||
|
||||
return all_enc
|
||||
|
||||
|
||||
class PyraformerSSModel(nn.Module):
|
||||
@validated()
|
||||
def __init__(self, freq, covariate_size,
|
||||
num_seq, input_size ,dropout, d_model,
|
||||
d_inner_hid, d_k, d_v,
|
||||
num_heads , n_layer, loss ,
|
||||
window_size , inner_size,
|
||||
use_tvm, prediction_length,context_length,lags_seq,
|
||||
def __init__(
|
||||
self,
|
||||
freq,
|
||||
covariate_size,
|
||||
num_seq,
|
||||
input_size,
|
||||
dropout,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
d_k,
|
||||
d_v,
|
||||
num_heads,
|
||||
n_layer,
|
||||
loss,
|
||||
window_size,
|
||||
inner_size,
|
||||
use_tvm,
|
||||
prediction_length,
|
||||
context_length,
|
||||
lags_seq,
|
||||
num_feat_dynamic_real,
|
||||
num_feat_static_cat,
|
||||
num_feat_static_real,
|
||||
@@ -79,25 +148,41 @@ class PyraformerSSModel(nn.Module):
|
||||
embedding_dimension,
|
||||
distr_output,
|
||||
# loss: DistributionLoss = NegativeLogLikelihood(),
|
||||
scaling,num_parallel_samples,device):
|
||||
|
||||
scaling,
|
||||
num_parallel_samples,
|
||||
device,
|
||||
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.context_length = context_length
|
||||
self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
|
||||
self.encoder = EncoderSS(covariate_size,
|
||||
num_seq, input_size ,dropout , d_model,
|
||||
d_inner_hid, d_k, d_v ,
|
||||
num_heads , n_layer, loss ,
|
||||
window_size , inner_size,
|
||||
use_tvm, prediction_length, device)
|
||||
self.encoder = EncoderSS(
|
||||
covariate_size,
|
||||
num_seq,
|
||||
input_size,
|
||||
dropout,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
d_k,
|
||||
d_v,
|
||||
num_heads,
|
||||
n_layer,
|
||||
loss,
|
||||
window_size,
|
||||
inner_size,
|
||||
use_tvm,
|
||||
prediction_length,
|
||||
device,
|
||||
)
|
||||
|
||||
# convert hidden vectors into two scalar
|
||||
# convert hidden vectors into two scalar
|
||||
self.mean_hidden = Predictor(4 * d_model, 1)
|
||||
self.var_hidden = Predictor(4 * d_model, 1)
|
||||
|
||||
|
||||
self.softplus = nn.Softplus()
|
||||
|
||||
self.distr_output = distr_output
|
||||
|
||||
def forward(self, data):
|
||||
enc_output = self.encoder(data)
|
||||
|
||||
@@ -114,10 +199,11 @@ class PyraformerSSModel(nn.Module):
|
||||
sample_mu = mu[:, -1] * v
|
||||
sample_sigma = sigma[:, -1] * v
|
||||
return sample_mu, sample_sigma
|
||||
|
||||
|
||||
@property
|
||||
def _past_length(self) -> int:
|
||||
return self.context_length + max(self.lags_seq)
|
||||
|
||||
@property
|
||||
def _number_of_features(self) -> int:
|
||||
return (
|
||||
@@ -126,6 +212,7 @@ class PyraformerSSModel(nn.Module):
|
||||
+ self.num_feat_static_real
|
||||
+ 1 # the log(scale)
|
||||
)
|
||||
|
||||
def get_lagged_subsequences(
|
||||
self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
|
||||
) -> torch.Tensor:
|
||||
@@ -178,7 +265,7 @@ class PyraformerSSModel(nn.Module):
|
||||
assert (
|
||||
features is None or features.shape[2] == self._number_of_features
|
||||
), f"{features.shape[2]}, expected {self._number_of_features}"
|
||||
|
||||
|
||||
def create_network_inputs(
|
||||
self,
|
||||
feat_static_cat: torch.Tensor,
|
||||
@@ -274,70 +361,159 @@ class PyraformerSSModel(nn.Module):
|
||||
if trailing_n is not None:
|
||||
sliced_params = [p[:, -trailing_n:] for p in params]
|
||||
return self.distr_output.distribution(sliced_params, scale=scale)
|
||||
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
@validated()
|
||||
def __init__(self, model, window_size, truncate, input_size, inner_size,decoder,d_model, d_k,d_v,d_inner_hid, dropout,n_layer ,enc_in,covariate_size,seq_num,CSCM,d_bottleneck,num_head,use_tvm,device):
|
||||
def __init__(
|
||||
self,
|
||||
# model,
|
||||
window_size,
|
||||
truncate,
|
||||
input_size,
|
||||
inner_size,
|
||||
decoder,
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
d_inner_hid,
|
||||
dropout,
|
||||
n_layer,
|
||||
enc_in,
|
||||
covariate_size,
|
||||
seq_num,
|
||||
CSCM,
|
||||
d_bottleneck,
|
||||
num_head,
|
||||
use_tvm,
|
||||
embed_type,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.d_model = d_model
|
||||
self.model_type = model
|
||||
# self.model_type = model
|
||||
self.window_size = window_size
|
||||
self.truncate = truncate
|
||||
if decoder == 'attention':
|
||||
self.mask, self.all_size = get_mask(input_size, window_size, inner_size, device)
|
||||
if decoder == "attention":
|
||||
self.mask, self.all_size = get_mask(
|
||||
input_size, window_size, inner_size, device
|
||||
)
|
||||
else:
|
||||
self.mask, self.all_size = get_mask(input_size+1, window_size, inner_size, device)
|
||||
self.mask, self.all_size = get_mask(
|
||||
input_size + 1, window_size, inner_size, device
|
||||
)
|
||||
self.decoder_type = decoder
|
||||
if decoder == 'FC':
|
||||
if decoder == "FC":
|
||||
self.indexes = refer_points(self.all_size, window_size, device)
|
||||
|
||||
if use_tvm:
|
||||
assert len(set(self.window_size)) == 1, "Only constant window size is supported."
|
||||
padding = 1 if decoder == 'FC' else 0
|
||||
q_k_mask = get_q_k(input_size + padding, inner_size, window_size[0],device)
|
||||
assert (
|
||||
len(set(self.window_size)) == 1
|
||||
), "Only constant window size is supported."
|
||||
padding = 1 if decoder == "FC" else 0
|
||||
q_k_mask = get_q_k(input_size + padding, inner_size, window_size[0], device)
|
||||
k_q_mask = get_k_q(q_k_mask)
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(d_model, d_inner_hid, num_head, d_k, d_v, dropout=dropout, \
|
||||
normalize_before=False, use_tvm=True, q_k_mask=q_k_mask, k_q_mask=k_q_mask) for i in range(n_layer)
|
||||
])
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
EncoderLayer(
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_head,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=False,
|
||||
use_tvm=True,
|
||||
q_k_mask=q_k_mask,
|
||||
k_q_mask=k_q_mask,
|
||||
)
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(d_model, d_inner_hid, num_head, d_k, d_v, dropout=dropout, \
|
||||
normalize_before=False) for i in range(n_layer)
|
||||
])
|
||||
|
||||
if opt.embed_type == 'CustomEmbedding':
|
||||
self.enc_embedding = CustomEmbedding(enc_in, d_model, covariate_size, seq_num, dropout)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
EncoderLayer(
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_head,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
normalize_before=False,
|
||||
)
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
|
||||
if embed_type == "CustomEmbedding":
|
||||
self.enc_embedding = CustomEmbedding(
|
||||
enc_in, d_model, covariate_size, seq_num, dropout
|
||||
)
|
||||
else:
|
||||
self.enc_embedding = DataEmbedding(enc_in, d_model, dropout)
|
||||
|
||||
|
||||
self.conv_layers = eval(CSCM)(d_model, window_size, d_bottleneck)
|
||||
|
||||
|
||||
|
||||
def forward(self, x_enc, x_mark_enc):
|
||||
seq_enc = self.enc_embedding(x_enc, x_mark_enc)
|
||||
|
||||
|
||||
mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)
|
||||
seq_enc = self.conv_layers(seq_enc)
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
seq_enc, _ = self.layers[i](seq_enc, mask)
|
||||
|
||||
if self.decoder_type == 'FC':
|
||||
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
|
||||
|
||||
if self.decoder_type == "FC":
|
||||
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(
|
||||
seq_enc.device
|
||||
)
|
||||
indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
|
||||
all_enc = torch.gather(seq_enc, 1, indexes)
|
||||
seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
|
||||
elif self.decoder_type == 'attention' and self.truncate:
|
||||
seq_enc = seq_enc[:, :self.all_size[0]]
|
||||
|
||||
elif self.decoder_type == "attention" and self.truncate:
|
||||
seq_enc = seq_enc[:, : self.all_size[0]]
|
||||
|
||||
return seq_enc
|
||||
|
||||
|
||||
|
||||
class PyraformerLRModel(nn.Module):
|
||||
@validated()
|
||||
def __init__(self, predict_step, d_model, input_size, decoder, window_size, truncate, model,d_inner_hid,d_k,d_v,dropout,enc_in,covariate_size,seq_num,CSCM,d_bottleneck,num_head,use_tvm,device):
|
||||
def __init__(
|
||||
self,
|
||||
predict_step,
|
||||
d_model,
|
||||
input_size,
|
||||
decoder,
|
||||
window_size,
|
||||
truncate,
|
||||
d_inner_hid,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout,
|
||||
enc_in,
|
||||
covariate_size,
|
||||
seq_num,
|
||||
CSCM,
|
||||
d_bottleneck,
|
||||
num_head,
|
||||
n_layer,
|
||||
inner_size,
|
||||
use_tvm,
|
||||
prediction_length,
|
||||
context_length,
|
||||
lags_seq,
|
||||
num_feat_dynamic_real,
|
||||
num_feat_static_cat,
|
||||
num_feat_static_real,
|
||||
cardinality,
|
||||
embedding_dimension,
|
||||
num_parallel_samples,
|
||||
embed_type,
|
||||
distr_output,
|
||||
device
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.predict_step = predict_step
|
||||
@@ -345,13 +521,50 @@ class PyraformerLRModel(nn.Module):
|
||||
self.input_size = input_size
|
||||
self.decoder_type = decoder
|
||||
self.channels = enc_in
|
||||
|
||||
self.encoder = Encoder(model, window_size, truncate, input_size, inner_size,decoder,d_model, d_k,d_v,d_inner_hid, dropout,n_layer, enc_in,covariate_size,seq_num,CSCM,d_bottleneck,num_head,use_tvm,device)
|
||||
if decoder == 'attention':
|
||||
self.distr_output = distr_output
|
||||
self.context_length = context_length
|
||||
self.lags_seq = lags_seq
|
||||
|
||||
self.encoder = Encoder(
|
||||
# model,
|
||||
window_size,
|
||||
truncate,
|
||||
input_size,
|
||||
inner_size,
|
||||
decoder,
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
d_inner_hid,
|
||||
dropout,
|
||||
n_layer,
|
||||
enc_in,
|
||||
covariate_size,
|
||||
seq_num,
|
||||
CSCM,
|
||||
d_bottleneck,
|
||||
num_head,
|
||||
use_tvm,
|
||||
embed_type,
|
||||
device,
|
||||
)
|
||||
if decoder == "attention":
|
||||
mask = get_subsequent_mask(input_size, window_size, predict_step, truncate)
|
||||
self.decoder = Decoder(model,d_model,d_inner_hid,num_head,d_k,d_v,dropout,enc_in,covariate_size,seq_num, mask)
|
||||
self.decoder = Decoder(
|
||||
# model,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
num_head,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout,
|
||||
enc_in,
|
||||
covariate_size,
|
||||
seq_num,
|
||||
mask,
|
||||
)
|
||||
self.predictor = Predictor(d_model, enc_in)
|
||||
elif opt.decoder == 'FC':
|
||||
elif decoder == "FC":
|
||||
self.predictor = Predictor(4 * d_model, predict_step * enc_in)
|
||||
|
||||
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, pretrain):
|
||||
@@ -364,18 +577,181 @@ class PyraformerLRModel(nn.Module):
|
||||
type_prediction: batch*seq_len*num_classes (not normalized);
|
||||
time_prediction: batch*seq_len.
|
||||
"""
|
||||
if self.decoder_type == 'attention':
|
||||
if self.decoder_type == "attention":
|
||||
enc_output = self.encoder(x_enc, x_mark_enc)
|
||||
dec_enc = self.decoder(x_dec, x_mark_dec, enc_output)
|
||||
|
||||
if pretrain:
|
||||
dec_enc = torch.cat([enc_output[:, :self.input_size], dec_enc], dim=1)
|
||||
dec_enc = torch.cat([enc_output[:, : self.input_size], dec_enc], dim=1)
|
||||
pred = self.predictor(dec_enc)
|
||||
else:
|
||||
pred = self.predictor(dec_enc)
|
||||
elif self.decoder_type == 'FC':
|
||||
elif self.decoder_type == "FC":
|
||||
enc_output = self.encoder(x_enc, x_mark_enc)[:, -1, :]
|
||||
pred = self.predictor(enc_output).view(enc_output.size(0), self.predict_step, -1)
|
||||
pred = self.predictor(enc_output).view(
|
||||
enc_output.size(0), self.predict_step, -1
|
||||
)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def _past_length(self) -> int:
|
||||
return self.predict_step #+ max(0,self.lags_seq)
|
||||
|
||||
@property
|
||||
def _number_of_features(self) -> int:
|
||||
return (
|
||||
sum(self.embedding_dimension)
|
||||
+ self.num_feat_dynamic_real
|
||||
+ self.num_feat_static_real
|
||||
+ 1 # the log(scale)
|
||||
)
|
||||
|
||||
def get_lagged_subsequences(
|
||||
self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns lagged subsequences of a given sequence.
|
||||
Parameters
|
||||
----------
|
||||
sequence : Tensor
|
||||
the sequence from which lagged subsequences should be extracted.
|
||||
Shape: (N, T, C).
|
||||
subsequences_length : int
|
||||
length of the subsequences to be extracted.
|
||||
shift: int
|
||||
shift the lags by this amount back.
|
||||
Returns
|
||||
--------
|
||||
lagged : Tensor
|
||||
a tensor of shape (N, S, C, I), where S = subsequences_length and
|
||||
I = len(indices), containing lagged subsequences. Specifically,
|
||||
lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
|
||||
"""
|
||||
sequence_length = sequence.shape[1]
|
||||
indices = [lag - shift for lag in self.lags_seq]
|
||||
|
||||
assert max(indices) + subsequences_length <= sequence_length, (
|
||||
f"lags cannot go further than history length, found lag {max(indices)} "
|
||||
f"while history length is only {sequence_length}"
|
||||
)
|
||||
|
||||
lagged_values = []
|
||||
for lag_index in indices:
|
||||
begin_index = -lag_index - subsequences_length
|
||||
end_index = -lag_index if lag_index > 0 else None
|
||||
lagged_values.append(sequence[:, begin_index:end_index, ...])
|
||||
return torch.stack(lagged_values, dim=-1)
|
||||
|
||||
def _check_shapes(
|
||||
self,
|
||||
prior_input: torch.Tensor,
|
||||
inputs: torch.Tensor,
|
||||
features: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
assert len(prior_input.shape) == len(inputs.shape)
|
||||
assert (
|
||||
len(prior_input.shape) == 2 and self.input_size == 1
|
||||
) or prior_input.shape[2] == self.input_size
|
||||
assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
|
||||
-1
|
||||
] == self.input_size
|
||||
assert (
|
||||
features is None or features.shape[2] == self._number_of_features
|
||||
), f"{features.shape[2]}, expected {self._number_of_features}"
|
||||
|
||||
def create_network_inputs(
|
||||
self,
|
||||
feat_static_cat: torch.Tensor,
|
||||
feat_static_real: torch.Tensor,
|
||||
past_time_feat: torch.Tensor,
|
||||
past_target: torch.Tensor,
|
||||
past_observed_values: torch.Tensor,
|
||||
future_time_feat: Optional[torch.Tensor] = None,
|
||||
future_target: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# time feature
|
||||
time_feat = (
|
||||
torch.cat(
|
||||
(
|
||||
past_time_feat[:, self._past_length - self.context_length :, ...],
|
||||
future_time_feat,
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
if future_target is not None
|
||||
else past_time_feat[:, self._past_length - self.context_length :, ...]
|
||||
)
|
||||
|
||||
# target
|
||||
context = past_target[:, -self.context_length :]
|
||||
observed_context = past_observed_values[:, -self.context_length :]
|
||||
_, scale = self.scaler(context, observed_context)
|
||||
|
||||
inputs = (
|
||||
torch.cat((past_target, future_target), dim=1) / scale
|
||||
if future_target is not None
|
||||
else past_target / scale
|
||||
)
|
||||
|
||||
inputs_length = (
|
||||
self._past_length + self.prediction_length
|
||||
if future_target is not None
|
||||
else self._past_length
|
||||
)
|
||||
assert inputs.shape[1] == inputs_length
|
||||
|
||||
subsequences_length = (
|
||||
self.context_length + self.prediction_length
|
||||
if future_target is not None
|
||||
else self.context_length
|
||||
)
|
||||
|
||||
# embeddings
|
||||
embedded_cat = self.embedder(feat_static_cat)
|
||||
static_feat = torch.cat(
|
||||
(embedded_cat, feat_static_real, scale.log()),
|
||||
dim=1,
|
||||
)
|
||||
expanded_static_feat = static_feat.unsqueeze(1).expand(
|
||||
-1, time_feat.shape[1], -1
|
||||
)
|
||||
|
||||
features = torch.cat((expanded_static_feat, time_feat), dim=-1)
|
||||
|
||||
# self._check_shapes(prior_input, inputs, features)
|
||||
|
||||
# sequence = torch.cat((prior_input, inputs), dim=1)
|
||||
lagged_sequence = self.get_lagged_subsequences(
|
||||
sequence=inputs,
|
||||
subsequences_length=subsequences_length,
|
||||
)
|
||||
|
||||
lags_shape = lagged_sequence.shape
|
||||
reshaped_lagged_sequence = lagged_sequence.reshape(
|
||||
lags_shape[0], lags_shape[1], -1
|
||||
)
|
||||
|
||||
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
|
||||
|
||||
return transformer_inputs, scale, static_feat
|
||||
|
||||
def output_params(self, transformer_inputs):
|
||||
enc_input = transformer_inputs[:, : self.context_length, ...]
|
||||
dec_input = transformer_inputs[:, self.context_length :, ...]
|
||||
|
||||
enc_out = self.transformer.encoder(enc_input)
|
||||
dec_output = self.transformer.decoder(
|
||||
dec_input, enc_out, tgt_mask=self.tgt_mask
|
||||
)
|
||||
|
||||
return self.param_proj(dec_output)
|
||||
|
||||
@torch.jit.ignore
|
||||
def output_distribution(
|
||||
self, params, scale=None, trailing_n=None
|
||||
) -> torch.distributions.Distribution:
|
||||
sliced_params = params
|
||||
if trailing_n is not None:
|
||||
sliced_params = [p[:, -trailing_n:] for p in params]
|
||||
return self.distr_output.distribution(sliced_params, scale=scale)
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -378,10 +378,10 @@ class Predictor(nn.Module):
|
||||
class Decoder(nn.Module):
|
||||
""" A encoder model with self attention mechanism. """
|
||||
|
||||
def __init__(self, model,d_model,d_inner_hid,num_head,d_k,d_v,dropout,enc_in,covariate_size,seq_num, mask):
|
||||
def __init__(self,d_model,d_inner_hid,num_head,d_k,d_v,dropout,enc_in,covariate_size,seq_num, mask):
|
||||
super().__init__()
|
||||
|
||||
self.model_type = model
|
||||
# self.model_type = model
|
||||
self.mask = mask
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
|
||||
Reference in New Issue
Block a user