diff --git a/xformers/estimator.py b/xformers/estimator.py index f212356..ced643b 100644 --- a/xformers/estimator.py +++ b/xformers/estimator.py @@ -55,28 +55,26 @@ TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [ # - + class XformerEstimator(PyTorchLightningEstimator): @validated() def __init__( self, freq: str, prediction_length: int, - # Xformer arguments nhead: int, num_encoder_layers: int, num_decoder_layers: int, hidden_layer_multiplier: int = 1, - attention_args = {"name": "scaled_dot_product"}, + attention_args={"name": "scaled_dot_product"}, input_size: int = 1, activation: str = "gelu", residual_norm_style: str = "pre", dropout: float = 0.1, - use_rotary_embeddings = False, - reversible = False, - + use_rotary_embeddings=False, + reversible=False, context_length: Optional[int] = None, - num_feat_dynamic_real: int = 0, num_feat_static_cat: int = 0, num_feat_static_real: int = 0, @@ -97,7 +95,7 @@ class XformerEstimator(PyTorchLightningEstimator): **trainer_kwargs, } super().__init__(trainer_kwargs=trainer_kwargs) - + self.freq = freq self.context_length = ( context_length if context_length is not None else prediction_length @@ -105,7 +103,7 @@ class XformerEstimator(PyTorchLightningEstimator): self.prediction_length = prediction_length self.distr_output = distr_output self.loss = loss - + self.input_size = input_size self.nhead = nhead self.num_encoder_layers = num_encoder_layers @@ -117,7 +115,7 @@ class XformerEstimator(PyTorchLightningEstimator): self.reversible = reversible self.hidden_layer_multiplier = hidden_layer_multiplier self.residual_norm_style = residual_norm_style - + self.num_feat_dynamic_real = num_feat_dynamic_real self.num_feat_static_cat = num_feat_static_cat self.num_feat_static_real = num_feat_static_real @@ -140,10 +138,8 @@ class XformerEstimator(PyTorchLightningEstimator): self.train_sampler = ExpectedNumInstanceSampler( num_instances=1.0, min_future=prediction_length ) - self.validation_sampler = ValidationSplitSampler( - min_future=prediction_length - ) - + self.validation_sampler = ValidationSplitSampler(min_future=prediction_length) + def create_transformation(self) -> Transformation: remove_field_names = [] if self.num_feat_static_real == 0: @@ -159,11 +155,7 @@ class XformerEstimator(PyTorchLightningEstimator): else [] ) + ( - [ - SetField( - output_field=FieldName.FEAT_STATIC_REAL, value=[0.0] - ) - ] + [SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])] if not self.num_feat_static_real > 0 else [] ) @@ -211,9 +203,7 @@ class XformerEstimator(PyTorchLightningEstimator): ] ) - def _create_instance_splitter( - self, module: XformerLightningModule, mode: str - ): + def _create_instance_splitter(self, module: XformerLightningModule, mode: str): assert mode in ["training", "validation", "test"] instance_sampler = { @@ -284,14 +274,14 @@ class XformerEstimator(PyTorchLightningEstimator): batch_size=self.batch_size, **kwargs, ) - + def create_predictor( self, transformation: Transformation, module: XformerLightningModule, ) -> PyTorchPredictor: prediction_splitter = self._create_instance_splitter(module, "test") - + return PyTorchPredictor( input_transform=transformation + prediction_splitter, input_names=PREDICTION_INPUT_NAMES, @@ -306,12 +296,13 @@ class XformerEstimator(PyTorchLightningEstimator): freq=self.freq, context_length=self.context_length, prediction_length=self.prediction_length, - num_feat_dynamic_real=1 + self.num_feat_dynamic_real + len(self.time_features), + num_feat_dynamic_real=1 + + self.num_feat_dynamic_real + + len(self.time_features), num_feat_static_real=max(1, self.num_feat_static_real), num_feat_static_cat=max(1, self.num_feat_static_cat), cardinality=self.cardinality, embedding_dimension=self.embedding_dimension, - # xformer arguments nhead=self.nhead, num_encoder_layers=self.num_encoder_layers, @@ -323,7 +314,6 @@ class XformerEstimator(PyTorchLightningEstimator): use_rotary_embeddings=self.use_rotary_embeddings, reversible=self.reversible, residual_norm_style=self.residual_norm_style, - # univariate input input_size=self.input_size, distr_output=self.distr_output, @@ -331,5 +321,5 @@ class XformerEstimator(PyTorchLightningEstimator): scaling=self.scaling, num_parallel_samples=self.num_parallel_samples, ) - + return XformerLightningModule(model=model, loss=self.loss) diff --git a/xformers/lightning_module.py b/xformers/lightning_module.py index 077147c..d686a11 100644 --- a/xformers/lightning_module.py +++ b/xformers/lightning_module.py @@ -19,7 +19,7 @@ class XformerLightningModule(pl.LightningModule): self.loss = loss self.lr = lr self.weight_decay = weight_decay - + def training_step(self, batch, batch_idx: int): """Execute training step""" train_loss = self(batch) @@ -36,9 +36,7 @@ class XformerLightningModule(pl.LightningModule): """Execute validation step""" with torch.no_grad(): val_loss = self(batch) - self.log( - "val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True - ) + self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True) return val_loss def configure_optimizers(self): @@ -58,7 +56,7 @@ class XformerLightningModule(pl.LightningModule): future_target = batch["future_target"] past_observed_values = batch["past_observed_values"] future_observed_values = batch["future_observed_values"] - + transformer_inputs, scale, _ = self.model.create_network_inputs( feat_static_cat, feat_static_real, @@ -72,7 +70,7 @@ class XformerLightningModule(pl.LightningModule): distr = self.model.output_distribution(params, scale) loss_values = self.loss(distr, future_target) - + if len(self.model.target_shape) == 0: loss_weights = future_observed_values else: diff --git a/xformers/module.py b/xformers/module.py index bca546a..86dfe5c 100644 --- a/xformers/module.py +++ b/xformers/module.py @@ -14,6 +14,7 @@ from xformers.factory.model_factory import xFormer, xFormerConfig # - + class XformerModel(nn.Module): @validated() def __init__( @@ -25,7 +26,6 @@ class XformerModel(nn.Module): num_feat_static_real: int, num_feat_static_cat: int, cardinality: List[int], - # xformer arguments nhead: int, num_encoder_layers: int, @@ -37,7 +37,6 @@ class XformerModel(nn.Module): reversible: bool = False, hidden_layer_multiplier: int = 2, use_rotary_embeddings: bool = False, - # univariate input input_size: int = 1, embedding_dimension: Optional[List[int]] = None, @@ -47,9 +46,9 @@ class XformerModel(nn.Module): num_parallel_samples: int = 1, ) -> None: super().__init__() - + self.input_size = input_size - + self.target_shape = distr_output.event_shape self.num_feat_dynamic_real = num_feat_dynamic_real self.num_feat_static_cat = num_feat_static_cat @@ -70,22 +69,21 @@ class XformerModel(nn.Module): self.scaler = MeanScaler(dim=1, keepdim=True) else: self.scaler = NOPScaler(dim=1, keepdim=True) - + # total feature size d_model = self.input_size * len(self.lags_seq) + self._number_of_features - + self.context_length = context_length self.prediction_length = prediction_length self.distr_output = distr_output self.param_proj = distr_output.get_args_proj(d_model) - + attention_args["dropout"] = dropout attention_args["causal"] = False attention_args["seq_len"] = self.context_length attention_args["num_rules"] = nhead - attention_args["attention_query_mask"] = (torch.rand((context_length, 1)) < 0.5) - - + attention_args["attention_query_mask"] = torch.rand((context_length, 1)) < 0.5 + xformer_config = [ # A list of the encoder blocks which constitute the Transformer. # Note that a sequence of different encoder blocks can be used @@ -117,21 +115,23 @@ class XformerModel(nn.Module): config = xFormerConfig(xformer_config) # xformer encoder self.encoder = xFormer.from_config(config) - + # causal vanilla transformer decoder decoder_layer = nn.TransformerDecoderLayer( - d_model, - nhead, - dim_feedforward=d_model*hidden_layer_multiplier, + d_model, + nhead, + dim_feedforward=d_model * hidden_layer_multiplier, dropout=dropout, - activation=activation, - layer_norm_eps=1e-5, - batch_first=True, + activation=activation, + layer_norm_eps=1e-5, + batch_first=True, norm_first=False, ) decoder_norm = nn.LayerNorm(d_model, eps=1e-5) - self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) - + self.decoder = nn.TransformerDecoder( + decoder_layer, num_decoder_layers, decoder_norm + ) + # causal decoder tgt mask for training self.register_buffer( "tgt_mask", @@ -150,12 +150,9 @@ class XformerModel(nn.Module): @property def _past_length(self) -> int: return self.context_length + max(self.lags_seq) - + def get_lagged_subsequences( - self, - sequence: torch.Tensor, - subsequences_length: int, - shift: int = 0 + self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0 ) -> torch.Tensor: """ Returns lagged subsequences of a given sequence. @@ -189,18 +186,17 @@ class XformerModel(nn.Module): end_index = -lag_index if lag_index > 0 else None lagged_values.append(sequence[:, begin_index:end_index, ...]) return torch.stack(lagged_values, dim=-1) - - + def create_network_inputs( - self, - feat_static_cat: torch.Tensor, + self, + feat_static_cat: torch.Tensor, feat_static_real: torch.Tensor, past_time_feat: torch.Tensor, past_target: torch.Tensor, past_observed_values: torch.Tensor, future_time_feat: Optional[torch.Tensor] = None, future_target: Optional[torch.Tensor] = None, - ): + ): # time feature time_feat = ( past_time_feat[:, self._past_length - self.context_length :, ...] @@ -216,7 +212,7 @@ class XformerModel(nn.Module): # target context = past_target[:, -self.context_length :] - observed_context = past_observed_values[:, -self.context_length :] + observed_context = past_observed_values[:, -self.context_length :] # weights = torch.linspace(0.0001, 1, steps=observed_context.size(-1), device=observed_context.device) _, scale = self.scaler(context, observed_context) @@ -232,13 +228,13 @@ class XformerModel(nn.Module): else self._past_length ) assert inputs.shape[1] == inputs_length - + subsequences_length = ( self.context_length if future_time_feat is None or future_target is None else self.context_length + self.prediction_length ) - + # embeddings embedded_cat = self.embedder(feat_static_cat) log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log() @@ -249,11 +245,11 @@ class XformerModel(nn.Module): expanded_static_feat = static_feat.unsqueeze(1).expand( -1, time_feat.shape[1], -1 ) - + features = torch.cat((expanded_static_feat, time_feat), dim=-1) - - #self._check_shapes(prior_input, inputs, features) - #sequence = torch.cat((prior_input, inputs), dim=1) + + # self._check_shapes(prior_input, inputs, features) + # sequence = torch.cat((prior_input, inputs), dim=1) lagged_sequence = self.get_lagged_subsequences( sequence=inputs, @@ -269,16 +265,16 @@ class XformerModel(nn.Module): transformer_inputs = reshaped_lagged_sequence else: transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1) - + return transformer_inputs, scale, static_feat - + def output_params(self, transformer_inputs): - enc_input = transformer_inputs[:, :self.context_length, ...] - dec_input = transformer_inputs[:, self.context_length:, ...] - + enc_input = transformer_inputs[:, : self.context_length, ...] + dec_input = transformer_inputs[:, self.context_length :, ...] + enc_out = self.encoder(src=enc_input) dec_output = self.decoder(dec_input, enc_out, tgt_mask=self.tgt_mask) - + return self.param_proj(dec_output) @torch.jit.ignore @@ -289,7 +285,7 @@ class XformerModel(nn.Module): if trailing_n is not None: sliced_params = [p[:, -trailing_n:] for p in params] return self.distr_output.distribution(sliced_params, scale=scale) - + # for prediction def forward( self, @@ -303,7 +299,7 @@ class XformerModel(nn.Module): ) -> torch.Tensor: if num_parallel_samples is None: num_parallel_samples = self.num_parallel_samples - + encoder_inputs, scale, static_feat = self.create_network_inputs( feat_static_cat, feat_static_real, @@ -312,12 +308,12 @@ class XformerModel(nn.Module): past_observed_values, future_time_feat, ) - + enc_out = self.encoder(src=encoder_inputs) - + params = self.param_proj(enc_out) distr = self.output_distribution(params, trailing_n=1) - + repeated_scale = scale.repeat_interleave( repeats=self.num_parallel_samples, dim=0 ) @@ -325,9 +321,7 @@ class XformerModel(nn.Module): repeats=self.num_parallel_samples, dim=0 ).unsqueeze(dim=1) repeated_past_target = ( - past_target.repeat_interleave( - repeats=self.num_parallel_samples, dim=0 - ) + past_target.repeat_interleave(repeats=self.num_parallel_samples, dim=0) / repeated_scale ) repeated_time_feat = future_time_feat.repeat_interleave( @@ -338,43 +332,36 @@ class XformerModel(nn.Module): ) future_samples = [] - + for k in range(self.prediction_length): next_features = torch.cat( (repeated_static_feat, repeated_time_feat[:, k : k + 1]), dim=-1, ) - + lagged_sequence = self.get_lagged_subsequences( sequence=repeated_past_target, subsequences_length=1, - shift=1, + shift=1, ) lags_shape = lagged_sequence.shape reshaped_lagged_sequence = lagged_sequence.reshape( lags_shape[0], lags_shape[1], -1 ) - + decoder_input = torch.cat((reshaped_lagged_sequence, next_features), dim=-1) output = self.decoder(decoder_input, repeated_enc_out) - + params = self.param_proj(output) distr = self.output_distribution(params) next_sample = distr.sample() - - repeated_past_target = torch.cat( - (repeated_past_target, next_sample), dim=1 - ) + + repeated_past_target = torch.cat((repeated_past_target, next_sample), dim=1) future_samples.append(next_sample) - unscaled_future_samples = ( - torch.cat(future_samples, dim=1) * repeated_scale - ) + unscaled_future_samples = torch.cat(future_samples, dim=1) * repeated_scale return unscaled_future_samples.reshape( - (-1, self.num_parallel_samples, self.prediction_length) - + self.target_shape, + (-1, self.num_parallel_samples, self.prediction_length) + self.target_shape, ) - -