diff --git a/perceiverar/estimator.py b/perceiverar/estimator.py index 144512d..9f3d93e 100644 --- a/perceiverar/estimator.py +++ b/perceiverar/estimator.py @@ -131,6 +131,7 @@ class PerceiverAREstimator(PyTorchLightningEstimator): prediction_length: int, depth: int, context_length: Optional[int] = None, + input_size: int = 1, perceive_depth: int = 1, heads: int = 2, hidden_size: int = 32, @@ -163,6 +164,7 @@ class PerceiverAREstimator(PyTorchLightningEstimator): default_trainer_kwargs.update(trainer_kwargs) super().__init__(trainer_kwargs=default_trainer_kwargs) + self.input_size = input_size self.freq = freq self.context_length = ( context_length if context_length is not None else prediction_length @@ -341,6 +343,7 @@ class PerceiverAREstimator(PyTorchLightningEstimator): def create_lightning_module(self) -> PerceiverARLightningModule: model = PerceiverARModel( + input_size=self.input_size, freq=self.freq, depth=self.depth, context_length=self.context_length, diff --git a/perceiverar/module.py b/perceiverar/module.py index 5267d8f..4c87805 100644 --- a/perceiverar/module.py +++ b/perceiverar/module.py @@ -236,6 +236,7 @@ class PerceiverARModel(nn.Module): num_feat_static_cat: int, cardinality: List[int], embedding_dimension: Optional[List[int]] = None, + input_size: int = 1, perceive_depth: int = 1, heads: int = 2, perceive_max_heads_process: int = 2, @@ -250,6 +251,7 @@ class PerceiverARModel(nn.Module): ) -> None: super().__init__() + self.input_size = input_size self.context_length = context_length self.prediction_length = prediction_length self.distr_output = distr_output @@ -274,7 +276,7 @@ class PerceiverARModel(nn.Module): else: self.scaler = NOPScaler(dim=1, keepdim=True) - dim_head = len(self.lags_seq) + self._number_of_features + dim_head = input_size * len(self.lags_seq) + self._number_of_features self.perceive_layers = nn.ModuleList([]) for _ in range(perceive_depth): @@ -315,7 +317,7 @@ class PerceiverARModel(nn.Module): sum(self.embedding_dimension) + self.num_feat_dynamic_real + self.num_feat_static_real - + 1 # the log(scale) + + self.input_size # the log(scale) ) @property @@ -387,8 +389,9 @@ class PerceiverARModel(nn.Module): ) embedded_cat = self.embedder(feat_static_cat) + log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log() static_feat = torch.cat( - (embedded_cat, feat_static_real, scale.log()), + (embedded_cat, feat_static_real, log_scale), dim=1, ) expanded_static_feat = static_feat.unsqueeze(1).expand(-1, input.shape[1], -1)