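Fix for input_size > 1: unpack the (values, indices) result of `min(dim=-1)` for the loss weights, and size the log(scale) static feature by `input_size` instead of 1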

This commit is contained in:
Kashif Rasul
2022-10-17 12:00:28 +02:00
parent 1af9b1122d
commit 7b5cd6052f
12 changed files with 27 additions and 19 deletions
+1 -1
View File
@@ -75,6 +75,6 @@ class InformerLightningModule(pl.LightningModule):
if len(self.model.target_shape) == 0:
loss_weights = future_observed_values
else:
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
return weighted_average(loss_values, weights=loss_weights)
+3 -2
View File
@@ -482,7 +482,7 @@ class InformerModel(nn.Module):
sum(self.embedding_dimension)
+ self.num_feat_dynamic_real
+ self.num_feat_static_real
+ 1 # the log(scale)
+ self.input_size # the log(scale)
)
@property
@@ -591,8 +591,9 @@ class InformerModel(nn.Module):
# embeddings
embedded_cat = self.embedder(feat_static_cat)
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
static_feat = torch.cat(
(embedded_cat, feat_static_real, scale.log()),
(embedded_cat, feat_static_real, log_scale),
dim=1,
)
expanded_static_feat = static_feat.unsqueeze(1).expand(
+1 -1
View File
@@ -74,6 +74,6 @@ class ReformerLightningModule(pl.LightningModule):
if len(self.model.target_shape) == 0:
loss_weights = future_observed_values
else:
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
return weighted_average(loss_values, weights=loss_weights)
+3 -2
View File
@@ -157,7 +157,7 @@ class ReformerModel(nn.Module):
sum(self.embedding_dimension)
+ self.num_feat_dynamic_real
+ self.num_feat_static_real
+ 1 # the log(scale)
+ self.input_size # the log(scale)
)
@property
@@ -266,8 +266,9 @@ class ReformerModel(nn.Module):
# embeddings
embedded_cat = self.embedder(feat_static_cat)
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
static_feat = torch.cat(
(embedded_cat, feat_static_real, scale.log()),
(embedded_cat, feat_static_real, log_scale),
dim=1,
)
expanded_static_feat = static_feat.unsqueeze(1).expand(
+4 -2
View File
@@ -139,6 +139,7 @@
" num_parallel_samples: int = 100,\n",
" ) -> None:\n",
" super().__init__()\n",
" self.input_size = input_size\n",
" self.context_length = context_length\n",
" self.prediction_length = prediction_length\n",
" self.distr_output = distr_output\n",
@@ -190,7 +191,7 @@
" sum(self.embedding_dimension)\n",
" + self.num_feat_dynamic_real\n",
" + self.num_feat_static_real\n",
" + 1 # the log(scale)\n",
" + self.input_size # the log(scale)\n",
" )\n",
" \n",
" @property\n",
@@ -233,8 +234,9 @@
" assert inputs.shape[1] == unroll_length\n",
"\n",
" embedded_cat = self.embedder(feat_static_cat)\n",
" log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()\n",
" static_feat = torch.cat(\n",
" (embedded_cat, feat_static_real, scale.log()),\n",
" (embedded_cat, feat_static_real, log_scale),\n",
" dim=1,\n",
" )\n",
" expanded_static_feat = static_feat.unsqueeze(1).expand(\n",
+1 -1
View File
@@ -74,6 +74,6 @@ class SwitchTransformerLightningModule(pl.LightningModule):
if len(self.model.target_shape) == 0:
loss_weights = future_observed_values
else:
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
return weighted_average(loss_values, weights=loss_weights)
+3 -2
View File
@@ -345,7 +345,7 @@ class SwitchTransformerModel(nn.Module):
sum(self.embedding_dimension)
+ self.num_feat_dynamic_real
+ self.num_feat_static_real
+ 1 # the log(scale)
+ self.input_size # the log(scale)
)
@property
@@ -454,8 +454,9 @@ class SwitchTransformerModel(nn.Module):
# embeddings
embedded_cat = self.embedder(feat_static_cat)
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
static_feat = torch.cat(
(embedded_cat, feat_static_real, scale.log()),
(embedded_cat, feat_static_real, log_scale),
dim=1,
)
expanded_static_feat = static_feat.unsqueeze(1).expand(
+1 -1
View File
@@ -86,6 +86,6 @@ class TFTLightningModule(pl.LightningModule):
if len(self.model.target_shape) == 0:
loss_weights = future_observed_values
else:
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
return weighted_average(loss_values, weights=loss_weights)
+2 -1
View File
@@ -474,8 +474,9 @@ class TFTModel(nn.Module):
# embeddings
embedded_cat = self.embedder(feat_static_cat)
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
static_feat = torch.cat(
(feat_static_real, scale.log()),
(feat_static_real, log_scale),
dim=1,
)
+1 -1
View File
@@ -74,6 +74,6 @@ class TransformerLightningModule(pl.LightningModule):
if len(self.model.target_shape) == 0:
loss_weights = future_observed_values
else:
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
return weighted_average(loss_values, weights=loss_weights)
+3 -2
View File
@@ -93,7 +93,7 @@ class TransformerModel(nn.Module):
sum(self.embedding_dimension)
+ self.num_feat_dynamic_real
+ self.num_feat_static_real
+ 1 # the log(scale)
+ self.input_size # the log(scale)
)
@property
@@ -202,8 +202,9 @@ class TransformerModel(nn.Module):
# embeddings
embedded_cat = self.embedder(feat_static_cat)
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
static_feat = torch.cat(
(embedded_cat, feat_static_real, scale.log()),
(embedded_cat, feat_static_real, log_scale),
dim=1,
)
expanded_static_feat = static_feat.unsqueeze(1).expand(
+4 -3
View File
@@ -238,7 +238,7 @@
" sum(self.embedding_dimension)\n",
" + self.num_feat_dynamic_real\n",
" + self.num_feat_static_real\n",
" + 1 # the log(scale)\n",
" + self.input_size # the log(scale)\n",
" )\n",
"\n",
" @property\n",
@@ -335,8 +335,9 @@
" \n",
" # embeddings\n",
" embedded_cat = self.embedder(feat_static_cat)\n",
" log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()\n",
" static_feat = torch.cat(\n",
" (embedded_cat, feat_static_real, scale.log()),\n",
" (embedded_cat, feat_static_real, log_scale),\n",
" dim=1,\n",
" )\n",
" expanded_static_feat = static_feat.unsqueeze(1).expand(\n",
@@ -549,7 +550,7 @@
" if len(self.model.target_shape) == 0:\n",
" loss_weights = future_observed_values\n",
" else:\n",
" loss_weights = future_observed_values.min(dim=-1, keepdim=False)\n",
" loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)\n",
"\n",
" return weighted_average(loss_values, weights=loss_weights)"
]