added input_size

This commit is contained in:
Kashif Rasul
2022-10-17 12:12:47 +02:00
parent a3290aa3e4
commit 6c6ec9461c
2 changed files with 9 additions and 3 deletions
+3
View File
@@ -131,6 +131,7 @@ class PerceiverAREstimator(PyTorchLightningEstimator):
prediction_length: int,
depth: int,
context_length: Optional[int] = None,
input_size: int = 1,
perceive_depth: int = 1,
heads: int = 2,
hidden_size: int = 32,
@@ -163,6 +164,7 @@ class PerceiverAREstimator(PyTorchLightningEstimator):
default_trainer_kwargs.update(trainer_kwargs)
super().__init__(trainer_kwargs=default_trainer_kwargs)
self.input_size = input_size
self.freq = freq
self.context_length = (
context_length if context_length is not None else prediction_length
@@ -341,6 +343,7 @@ class PerceiverAREstimator(PyTorchLightningEstimator):
def create_lightning_module(self) -> PerceiverARLightningModule:
model = PerceiverARModel(
input_size=self.input_size,
freq=self.freq,
depth=self.depth,
context_length=self.context_length,
+6 -3
View File
@@ -236,6 +236,7 @@ class PerceiverARModel(nn.Module):
num_feat_static_cat: int,
cardinality: List[int],
embedding_dimension: Optional[List[int]] = None,
input_size: int = 1,
perceive_depth: int = 1,
heads: int = 2,
perceive_max_heads_process: int = 2,
@@ -250,6 +251,7 @@ class PerceiverARModel(nn.Module):
) -> None:
super().__init__()
self.input_size = input_size
self.context_length = context_length
self.prediction_length = prediction_length
self.distr_output = distr_output
@@ -274,7 +276,7 @@ class PerceiverARModel(nn.Module):
else:
self.scaler = NOPScaler(dim=1, keepdim=True)
dim_head = len(self.lags_seq) + self._number_of_features
dim_head = input_size * len(self.lags_seq) + self._number_of_features
self.perceive_layers = nn.ModuleList([])
for _ in range(perceive_depth):
@@ -315,7 +317,7 @@ class PerceiverARModel(nn.Module):
sum(self.embedding_dimension)
+ self.num_feat_dynamic_real
+ self.num_feat_static_real
+ 1 # the log(scale)
+ self.input_size # the log(scale)
)
@property
@@ -387,8 +389,9 @@ class PerceiverARModel(nn.Module):
)
embedded_cat = self.embedder(feat_static_cat)
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
static_feat = torch.cat(
(embedded_cat, feat_static_real, scale.log()),
(embedded_cat, feat_static_real, log_scale),
dim=1,
)
expanded_static_feat = static_feat.unsqueeze(1).expand(-1, input.shape[1], -1)