mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:46:28 +08:00
format
This commit is contained in:
+17
-27
@@ -55,28 +55,26 @@ TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
|
||||
|
||||
# -
|
||||
|
||||
|
||||
class XformerEstimator(PyTorchLightningEstimator):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
freq: str,
|
||||
prediction_length: int,
|
||||
|
||||
# Xformer arguments
|
||||
nhead: int,
|
||||
num_encoder_layers: int,
|
||||
num_decoder_layers: int,
|
||||
hidden_layer_multiplier: int = 1,
|
||||
attention_args = {"name": "scaled_dot_product"},
|
||||
attention_args={"name": "scaled_dot_product"},
|
||||
input_size: int = 1,
|
||||
activation: str = "gelu",
|
||||
residual_norm_style: str = "pre",
|
||||
dropout: float = 0.1,
|
||||
use_rotary_embeddings = False,
|
||||
reversible = False,
|
||||
|
||||
use_rotary_embeddings=False,
|
||||
reversible=False,
|
||||
context_length: Optional[int] = None,
|
||||
|
||||
num_feat_dynamic_real: int = 0,
|
||||
num_feat_static_cat: int = 0,
|
||||
num_feat_static_real: int = 0,
|
||||
@@ -97,7 +95,7 @@ class XformerEstimator(PyTorchLightningEstimator):
|
||||
**trainer_kwargs,
|
||||
}
|
||||
super().__init__(trainer_kwargs=trainer_kwargs)
|
||||
|
||||
|
||||
self.freq = freq
|
||||
self.context_length = (
|
||||
context_length if context_length is not None else prediction_length
|
||||
@@ -105,7 +103,7 @@ class XformerEstimator(PyTorchLightningEstimator):
|
||||
self.prediction_length = prediction_length
|
||||
self.distr_output = distr_output
|
||||
self.loss = loss
|
||||
|
||||
|
||||
self.input_size = input_size
|
||||
self.nhead = nhead
|
||||
self.num_encoder_layers = num_encoder_layers
|
||||
@@ -117,7 +115,7 @@ class XformerEstimator(PyTorchLightningEstimator):
|
||||
self.reversible = reversible
|
||||
self.hidden_layer_multiplier = hidden_layer_multiplier
|
||||
self.residual_norm_style = residual_norm_style
|
||||
|
||||
|
||||
self.num_feat_dynamic_real = num_feat_dynamic_real
|
||||
self.num_feat_static_cat = num_feat_static_cat
|
||||
self.num_feat_static_real = num_feat_static_real
|
||||
@@ -140,10 +138,8 @@ class XformerEstimator(PyTorchLightningEstimator):
|
||||
self.train_sampler = ExpectedNumInstanceSampler(
|
||||
num_instances=1.0, min_future=prediction_length
|
||||
)
|
||||
self.validation_sampler = ValidationSplitSampler(
|
||||
min_future=prediction_length
|
||||
)
|
||||
|
||||
self.validation_sampler = ValidationSplitSampler(min_future=prediction_length)
|
||||
|
||||
def create_transformation(self) -> Transformation:
|
||||
remove_field_names = []
|
||||
if self.num_feat_static_real == 0:
|
||||
@@ -159,11 +155,7 @@ class XformerEstimator(PyTorchLightningEstimator):
|
||||
else []
|
||||
)
|
||||
+ (
|
||||
[
|
||||
SetField(
|
||||
output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]
|
||||
)
|
||||
]
|
||||
[SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])]
|
||||
if not self.num_feat_static_real > 0
|
||||
else []
|
||||
)
|
||||
@@ -211,9 +203,7 @@ class XformerEstimator(PyTorchLightningEstimator):
|
||||
]
|
||||
)
|
||||
|
||||
def _create_instance_splitter(
|
||||
self, module: XformerLightningModule, mode: str
|
||||
):
|
||||
def _create_instance_splitter(self, module: XformerLightningModule, mode: str):
|
||||
assert mode in ["training", "validation", "test"]
|
||||
|
||||
instance_sampler = {
|
||||
@@ -284,14 +274,14 @@ class XformerEstimator(PyTorchLightningEstimator):
|
||||
batch_size=self.batch_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def create_predictor(
|
||||
self,
|
||||
transformation: Transformation,
|
||||
module: XformerLightningModule,
|
||||
) -> PyTorchPredictor:
|
||||
prediction_splitter = self._create_instance_splitter(module, "test")
|
||||
|
||||
|
||||
return PyTorchPredictor(
|
||||
input_transform=transformation + prediction_splitter,
|
||||
input_names=PREDICTION_INPUT_NAMES,
|
||||
@@ -306,12 +296,13 @@ class XformerEstimator(PyTorchLightningEstimator):
|
||||
freq=self.freq,
|
||||
context_length=self.context_length,
|
||||
prediction_length=self.prediction_length,
|
||||
num_feat_dynamic_real=1 + self.num_feat_dynamic_real + len(self.time_features),
|
||||
num_feat_dynamic_real=1
|
||||
+ self.num_feat_dynamic_real
|
||||
+ len(self.time_features),
|
||||
num_feat_static_real=max(1, self.num_feat_static_real),
|
||||
num_feat_static_cat=max(1, self.num_feat_static_cat),
|
||||
cardinality=self.cardinality,
|
||||
embedding_dimension=self.embedding_dimension,
|
||||
|
||||
# xformer arguments
|
||||
nhead=self.nhead,
|
||||
num_encoder_layers=self.num_encoder_layers,
|
||||
@@ -323,7 +314,6 @@ class XformerEstimator(PyTorchLightningEstimator):
|
||||
use_rotary_embeddings=self.use_rotary_embeddings,
|
||||
reversible=self.reversible,
|
||||
residual_norm_style=self.residual_norm_style,
|
||||
|
||||
# univariate input
|
||||
input_size=self.input_size,
|
||||
distr_output=self.distr_output,
|
||||
@@ -331,5 +321,5 @@ class XformerEstimator(PyTorchLightningEstimator):
|
||||
scaling=self.scaling,
|
||||
num_parallel_samples=self.num_parallel_samples,
|
||||
)
|
||||
|
||||
|
||||
return XformerLightningModule(model=model, loss=self.loss)
|
||||
|
||||
@@ -19,7 +19,7 @@ class XformerLightningModule(pl.LightningModule):
|
||||
self.loss = loss
|
||||
self.lr = lr
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
|
||||
def training_step(self, batch, batch_idx: int):
|
||||
"""Execute training step"""
|
||||
train_loss = self(batch)
|
||||
@@ -36,9 +36,7 @@ class XformerLightningModule(pl.LightningModule):
|
||||
"""Execute validation step"""
|
||||
with torch.no_grad():
|
||||
val_loss = self(batch)
|
||||
self.log(
|
||||
"val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True
|
||||
)
|
||||
self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
|
||||
return val_loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
@@ -58,7 +56,7 @@ class XformerLightningModule(pl.LightningModule):
|
||||
future_target = batch["future_target"]
|
||||
past_observed_values = batch["past_observed_values"]
|
||||
future_observed_values = batch["future_observed_values"]
|
||||
|
||||
|
||||
transformer_inputs, scale, _ = self.model.create_network_inputs(
|
||||
feat_static_cat,
|
||||
feat_static_real,
|
||||
@@ -72,7 +70,7 @@ class XformerLightningModule(pl.LightningModule):
|
||||
distr = self.model.output_distribution(params, scale)
|
||||
|
||||
loss_values = self.loss(distr, future_target)
|
||||
|
||||
|
||||
if len(self.model.target_shape) == 0:
|
||||
loss_weights = future_observed_values
|
||||
else:
|
||||
|
||||
+53
-66
@@ -14,6 +14,7 @@ from xformers.factory.model_factory import xFormer, xFormerConfig
|
||||
|
||||
# -
|
||||
|
||||
|
||||
class XformerModel(nn.Module):
|
||||
@validated()
|
||||
def __init__(
|
||||
@@ -25,7 +26,6 @@ class XformerModel(nn.Module):
|
||||
num_feat_static_real: int,
|
||||
num_feat_static_cat: int,
|
||||
cardinality: List[int],
|
||||
|
||||
# xformer arguments
|
||||
nhead: int,
|
||||
num_encoder_layers: int,
|
||||
@@ -37,7 +37,6 @@ class XformerModel(nn.Module):
|
||||
reversible: bool = False,
|
||||
hidden_layer_multiplier: int = 2,
|
||||
use_rotary_embeddings: bool = False,
|
||||
|
||||
# univariate input
|
||||
input_size: int = 1,
|
||||
embedding_dimension: Optional[List[int]] = None,
|
||||
@@ -47,9 +46,9 @@ class XformerModel(nn.Module):
|
||||
num_parallel_samples: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.input_size = input_size
|
||||
|
||||
|
||||
self.target_shape = distr_output.event_shape
|
||||
self.num_feat_dynamic_real = num_feat_dynamic_real
|
||||
self.num_feat_static_cat = num_feat_static_cat
|
||||
@@ -70,22 +69,21 @@ class XformerModel(nn.Module):
|
||||
self.scaler = MeanScaler(dim=1, keepdim=True)
|
||||
else:
|
||||
self.scaler = NOPScaler(dim=1, keepdim=True)
|
||||
|
||||
|
||||
# total feature size
|
||||
d_model = self.input_size * len(self.lags_seq) + self._number_of_features
|
||||
|
||||
|
||||
self.context_length = context_length
|
||||
self.prediction_length = prediction_length
|
||||
self.distr_output = distr_output
|
||||
self.param_proj = distr_output.get_args_proj(d_model)
|
||||
|
||||
|
||||
attention_args["dropout"] = dropout
|
||||
attention_args["causal"] = False
|
||||
attention_args["seq_len"] = self.context_length
|
||||
attention_args["num_rules"] = nhead
|
||||
attention_args["attention_query_mask"] = (torch.rand((context_length, 1)) < 0.5)
|
||||
|
||||
|
||||
attention_args["attention_query_mask"] = torch.rand((context_length, 1)) < 0.5
|
||||
|
||||
xformer_config = [
|
||||
# A list of the encoder blocks which constitute the Transformer.
|
||||
# Note that a sequence of different encoder blocks can be used
|
||||
@@ -117,21 +115,23 @@ class XformerModel(nn.Module):
|
||||
config = xFormerConfig(xformer_config)
|
||||
# xformer encoder
|
||||
self.encoder = xFormer.from_config(config)
|
||||
|
||||
|
||||
# causal vanilla transformer decoder
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=d_model*hidden_layer_multiplier,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=d_model * hidden_layer_multiplier,
|
||||
dropout=dropout,
|
||||
activation=activation,
|
||||
layer_norm_eps=1e-5,
|
||||
batch_first=True,
|
||||
activation=activation,
|
||||
layer_norm_eps=1e-5,
|
||||
batch_first=True,
|
||||
norm_first=False,
|
||||
)
|
||||
decoder_norm = nn.LayerNorm(d_model, eps=1e-5)
|
||||
self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
|
||||
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
decoder_layer, num_decoder_layers, decoder_norm
|
||||
)
|
||||
|
||||
# causal decoder tgt mask for training
|
||||
self.register_buffer(
|
||||
"tgt_mask",
|
||||
@@ -150,12 +150,9 @@ class XformerModel(nn.Module):
|
||||
@property
|
||||
def _past_length(self) -> int:
|
||||
return self.context_length + max(self.lags_seq)
|
||||
|
||||
|
||||
def get_lagged_subsequences(
|
||||
self,
|
||||
sequence: torch.Tensor,
|
||||
subsequences_length: int,
|
||||
shift: int = 0
|
||||
self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns lagged subsequences of a given sequence.
|
||||
@@ -189,18 +186,17 @@ class XformerModel(nn.Module):
|
||||
end_index = -lag_index if lag_index > 0 else None
|
||||
lagged_values.append(sequence[:, begin_index:end_index, ...])
|
||||
return torch.stack(lagged_values, dim=-1)
|
||||
|
||||
|
||||
|
||||
def create_network_inputs(
|
||||
self,
|
||||
feat_static_cat: torch.Tensor,
|
||||
self,
|
||||
feat_static_cat: torch.Tensor,
|
||||
feat_static_real: torch.Tensor,
|
||||
past_time_feat: torch.Tensor,
|
||||
past_target: torch.Tensor,
|
||||
past_observed_values: torch.Tensor,
|
||||
future_time_feat: Optional[torch.Tensor] = None,
|
||||
future_target: Optional[torch.Tensor] = None,
|
||||
):
|
||||
):
|
||||
# time feature
|
||||
time_feat = (
|
||||
past_time_feat[:, self._past_length - self.context_length :, ...]
|
||||
@@ -216,7 +212,7 @@ class XformerModel(nn.Module):
|
||||
|
||||
# target
|
||||
context = past_target[:, -self.context_length :]
|
||||
observed_context = past_observed_values[:, -self.context_length :]
|
||||
observed_context = past_observed_values[:, -self.context_length :]
|
||||
# weights = torch.linspace(0.0001, 1, steps=observed_context.size(-1), device=observed_context.device)
|
||||
_, scale = self.scaler(context, observed_context)
|
||||
|
||||
@@ -232,13 +228,13 @@ class XformerModel(nn.Module):
|
||||
else self._past_length
|
||||
)
|
||||
assert inputs.shape[1] == inputs_length
|
||||
|
||||
|
||||
subsequences_length = (
|
||||
self.context_length
|
||||
if future_time_feat is None or future_target is None
|
||||
else self.context_length + self.prediction_length
|
||||
)
|
||||
|
||||
|
||||
# embeddings
|
||||
embedded_cat = self.embedder(feat_static_cat)
|
||||
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
|
||||
@@ -249,11 +245,11 @@ class XformerModel(nn.Module):
|
||||
expanded_static_feat = static_feat.unsqueeze(1).expand(
|
||||
-1, time_feat.shape[1], -1
|
||||
)
|
||||
|
||||
|
||||
features = torch.cat((expanded_static_feat, time_feat), dim=-1)
|
||||
|
||||
#self._check_shapes(prior_input, inputs, features)
|
||||
#sequence = torch.cat((prior_input, inputs), dim=1)
|
||||
|
||||
# self._check_shapes(prior_input, inputs, features)
|
||||
# sequence = torch.cat((prior_input, inputs), dim=1)
|
||||
|
||||
lagged_sequence = self.get_lagged_subsequences(
|
||||
sequence=inputs,
|
||||
@@ -269,16 +265,16 @@ class XformerModel(nn.Module):
|
||||
transformer_inputs = reshaped_lagged_sequence
|
||||
else:
|
||||
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
|
||||
|
||||
|
||||
return transformer_inputs, scale, static_feat
|
||||
|
||||
|
||||
def output_params(self, transformer_inputs):
|
||||
enc_input = transformer_inputs[:, :self.context_length, ...]
|
||||
dec_input = transformer_inputs[:, self.context_length:, ...]
|
||||
|
||||
enc_input = transformer_inputs[:, : self.context_length, ...]
|
||||
dec_input = transformer_inputs[:, self.context_length :, ...]
|
||||
|
||||
enc_out = self.encoder(src=enc_input)
|
||||
dec_output = self.decoder(dec_input, enc_out, tgt_mask=self.tgt_mask)
|
||||
|
||||
|
||||
return self.param_proj(dec_output)
|
||||
|
||||
@torch.jit.ignore
|
||||
@@ -289,7 +285,7 @@ class XformerModel(nn.Module):
|
||||
if trailing_n is not None:
|
||||
sliced_params = [p[:, -trailing_n:] for p in params]
|
||||
return self.distr_output.distribution(sliced_params, scale=scale)
|
||||
|
||||
|
||||
# for prediction
|
||||
def forward(
|
||||
self,
|
||||
@@ -303,7 +299,7 @@ class XformerModel(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
if num_parallel_samples is None:
|
||||
num_parallel_samples = self.num_parallel_samples
|
||||
|
||||
|
||||
encoder_inputs, scale, static_feat = self.create_network_inputs(
|
||||
feat_static_cat,
|
||||
feat_static_real,
|
||||
@@ -312,12 +308,12 @@ class XformerModel(nn.Module):
|
||||
past_observed_values,
|
||||
future_time_feat,
|
||||
)
|
||||
|
||||
|
||||
enc_out = self.encoder(src=encoder_inputs)
|
||||
|
||||
|
||||
params = self.param_proj(enc_out)
|
||||
distr = self.output_distribution(params, trailing_n=1)
|
||||
|
||||
|
||||
repeated_scale = scale.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
@@ -325,9 +321,7 @@ class XformerModel(nn.Module):
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
).unsqueeze(dim=1)
|
||||
repeated_past_target = (
|
||||
past_target.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
past_target.repeat_interleave(repeats=self.num_parallel_samples, dim=0)
|
||||
/ repeated_scale
|
||||
)
|
||||
repeated_time_feat = future_time_feat.repeat_interleave(
|
||||
@@ -338,43 +332,36 @@ class XformerModel(nn.Module):
|
||||
)
|
||||
|
||||
future_samples = []
|
||||
|
||||
|
||||
for k in range(self.prediction_length):
|
||||
next_features = torch.cat(
|
||||
(repeated_static_feat, repeated_time_feat[:, k : k + 1]),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
lagged_sequence = self.get_lagged_subsequences(
|
||||
sequence=repeated_past_target,
|
||||
subsequences_length=1,
|
||||
shift=1,
|
||||
shift=1,
|
||||
)
|
||||
|
||||
lags_shape = lagged_sequence.shape
|
||||
reshaped_lagged_sequence = lagged_sequence.reshape(
|
||||
lags_shape[0], lags_shape[1], -1
|
||||
)
|
||||
|
||||
|
||||
decoder_input = torch.cat((reshaped_lagged_sequence, next_features), dim=-1)
|
||||
|
||||
output = self.decoder(decoder_input, repeated_enc_out)
|
||||
|
||||
|
||||
params = self.param_proj(output)
|
||||
distr = self.output_distribution(params)
|
||||
next_sample = distr.sample()
|
||||
|
||||
repeated_past_target = torch.cat(
|
||||
(repeated_past_target, next_sample), dim=1
|
||||
)
|
||||
|
||||
repeated_past_target = torch.cat((repeated_past_target, next_sample), dim=1)
|
||||
future_samples.append(next_sample)
|
||||
|
||||
unscaled_future_samples = (
|
||||
torch.cat(future_samples, dim=1) * repeated_scale
|
||||
)
|
||||
unscaled_future_samples = torch.cat(future_samples, dim=1) * repeated_scale
|
||||
return unscaled_future_samples.reshape(
|
||||
(-1, self.num_parallel_samples, self.prediction_length)
|
||||
+ self.target_shape,
|
||||
(-1, self.num_parallel_samples, self.prediction_length) + self.target_shape,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user