mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 15:16:27 +08:00
added input_size
This commit is contained in:
@@ -131,6 +131,7 @@ class PerceiverAREstimator(PyTorchLightningEstimator):
|
||||
prediction_length: int,
|
||||
depth: int,
|
||||
context_length: Optional[int] = None,
|
||||
input_size: int = 1,
|
||||
perceive_depth: int = 1,
|
||||
heads: int = 2,
|
||||
hidden_size: int = 32,
|
||||
@@ -163,6 +164,7 @@ class PerceiverAREstimator(PyTorchLightningEstimator):
|
||||
default_trainer_kwargs.update(trainer_kwargs)
|
||||
super().__init__(trainer_kwargs=default_trainer_kwargs)
|
||||
|
||||
self.input_size = input_size
|
||||
self.freq = freq
|
||||
self.context_length = (
|
||||
context_length if context_length is not None else prediction_length
|
||||
@@ -341,6 +343,7 @@ class PerceiverAREstimator(PyTorchLightningEstimator):
|
||||
|
||||
def create_lightning_module(self) -> PerceiverARLightningModule:
|
||||
model = PerceiverARModel(
|
||||
input_size=self.input_size,
|
||||
freq=self.freq,
|
||||
depth=self.depth,
|
||||
context_length=self.context_length,
|
||||
|
||||
@@ -236,6 +236,7 @@ class PerceiverARModel(nn.Module):
|
||||
num_feat_static_cat: int,
|
||||
cardinality: List[int],
|
||||
embedding_dimension: Optional[List[int]] = None,
|
||||
input_size: int = 1,
|
||||
perceive_depth: int = 1,
|
||||
heads: int = 2,
|
||||
perceive_max_heads_process: int = 2,
|
||||
@@ -250,6 +251,7 @@ class PerceiverARModel(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.context_length = context_length
|
||||
self.prediction_length = prediction_length
|
||||
self.distr_output = distr_output
|
||||
@@ -274,7 +276,7 @@ class PerceiverARModel(nn.Module):
|
||||
else:
|
||||
self.scaler = NOPScaler(dim=1, keepdim=True)
|
||||
|
||||
dim_head = len(self.lags_seq) + self._number_of_features
|
||||
dim_head = input_size * len(self.lags_seq) + self._number_of_features
|
||||
|
||||
self.perceive_layers = nn.ModuleList([])
|
||||
for _ in range(perceive_depth):
|
||||
@@ -315,7 +317,7 @@ class PerceiverARModel(nn.Module):
|
||||
sum(self.embedding_dimension)
|
||||
+ self.num_feat_dynamic_real
|
||||
+ self.num_feat_static_real
|
||||
+ 1 # the log(scale)
|
||||
+ self.input_size # the log(scale)
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -387,8 +389,9 @@ class PerceiverARModel(nn.Module):
|
||||
)
|
||||
|
||||
embedded_cat = self.embedder(feat_static_cat)
|
||||
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
|
||||
static_feat = torch.cat(
|
||||
(embedded_cat, feat_static_real, scale.log()),
|
||||
(embedded_cat, feat_static_real, log_scale),
|
||||
dim=1,
|
||||
)
|
||||
expanded_static_feat = static_feat.unsqueeze(1).expand(-1, input.shape[1], -1)
|
||||
|
||||
Reference in New Issue
Block a user