mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 15:16:27 +08:00
fix for input_size > 1
This commit is contained in:
@@ -75,6 +75,6 @@ class InformerLightningModule(pl.LightningModule):
|
||||
if len(self.model.target_shape) == 0:
|
||||
loss_weights = future_observed_values
|
||||
else:
|
||||
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
|
||||
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
|
||||
|
||||
return weighted_average(loss_values, weights=loss_weights)
|
||||
|
||||
+3
-2
@@ -482,7 +482,7 @@ class InformerModel(nn.Module):
|
||||
sum(self.embedding_dimension)
|
||||
+ self.num_feat_dynamic_real
|
||||
+ self.num_feat_static_real
|
||||
+ 1 # the log(scale)
|
||||
+ self.input_size # the log(scale)
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -591,8 +591,9 @@ class InformerModel(nn.Module):
|
||||
|
||||
# embeddings
|
||||
embedded_cat = self.embedder(feat_static_cat)
|
||||
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
|
||||
static_feat = torch.cat(
|
||||
(embedded_cat, feat_static_real, scale.log()),
|
||||
(embedded_cat, feat_static_real, log_scale),
|
||||
dim=1,
|
||||
)
|
||||
expanded_static_feat = static_feat.unsqueeze(1).expand(
|
||||
|
||||
@@ -74,6 +74,6 @@ class ReformerLightningModule(pl.LightningModule):
|
||||
if len(self.model.target_shape) == 0:
|
||||
loss_weights = future_observed_values
|
||||
else:
|
||||
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
|
||||
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
|
||||
|
||||
return weighted_average(loss_values, weights=loss_weights)
|
||||
|
||||
+3
-2
@@ -157,7 +157,7 @@ class ReformerModel(nn.Module):
|
||||
sum(self.embedding_dimension)
|
||||
+ self.num_feat_dynamic_real
|
||||
+ self.num_feat_static_real
|
||||
+ 1 # the log(scale)
|
||||
+ self.input_size # the log(scale)
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -266,8 +266,9 @@ class ReformerModel(nn.Module):
|
||||
|
||||
# embeddings
|
||||
embedded_cat = self.embedder(feat_static_cat)
|
||||
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
|
||||
static_feat = torch.cat(
|
||||
(embedded_cat, feat_static_real, scale.log()),
|
||||
(embedded_cat, feat_static_real, log_scale),
|
||||
dim=1,
|
||||
)
|
||||
expanded_static_feat = static_feat.unsqueeze(1).expand(
|
||||
|
||||
+4
-2
@@ -139,6 +139,7 @@
|
||||
" num_parallel_samples: int = 100,\n",
|
||||
" ) -> None:\n",
|
||||
" super().__init__()\n",
|
||||
" self.input_size = input_size\n",
|
||||
" self.context_length = context_length\n",
|
||||
" self.prediction_length = prediction_length\n",
|
||||
" self.distr_output = distr_output\n",
|
||||
@@ -190,7 +191,7 @@
|
||||
" sum(self.embedding_dimension)\n",
|
||||
" + self.num_feat_dynamic_real\n",
|
||||
" + self.num_feat_static_real\n",
|
||||
" + 1 # the log(scale)\n",
|
||||
" + self.input_size # the log(scale)\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" @property\n",
|
||||
@@ -233,8 +234,9 @@
|
||||
" assert inputs.shape[1] == unroll_length\n",
|
||||
"\n",
|
||||
" embedded_cat = self.embedder(feat_static_cat)\n",
|
||||
" log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()\n",
|
||||
" static_feat = torch.cat(\n",
|
||||
" (embedded_cat, feat_static_real, scale.log()),\n",
|
||||
" (embedded_cat, feat_static_real, log_scale),\n",
|
||||
" dim=1,\n",
|
||||
" )\n",
|
||||
" expanded_static_feat = static_feat.unsqueeze(1).expand(\n",
|
||||
|
||||
@@ -74,6 +74,6 @@ class SwitchTransformerLightningModule(pl.LightningModule):
|
||||
if len(self.model.target_shape) == 0:
|
||||
loss_weights = future_observed_values
|
||||
else:
|
||||
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
|
||||
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
|
||||
|
||||
return weighted_average(loss_values, weights=loss_weights)
|
||||
|
||||
+3
-2
@@ -345,7 +345,7 @@ class SwitchTransformerModel(nn.Module):
|
||||
sum(self.embedding_dimension)
|
||||
+ self.num_feat_dynamic_real
|
||||
+ self.num_feat_static_real
|
||||
+ 1 # the log(scale)
|
||||
+ self.input_size # the log(scale)
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -454,8 +454,9 @@ class SwitchTransformerModel(nn.Module):
|
||||
|
||||
# embeddings
|
||||
embedded_cat = self.embedder(feat_static_cat)
|
||||
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
|
||||
static_feat = torch.cat(
|
||||
(embedded_cat, feat_static_real, scale.log()),
|
||||
(embedded_cat, feat_static_real, log_scale),
|
||||
dim=1,
|
||||
)
|
||||
expanded_static_feat = static_feat.unsqueeze(1).expand(
|
||||
|
||||
@@ -86,6 +86,6 @@ class TFTLightningModule(pl.LightningModule):
|
||||
if len(self.model.target_shape) == 0:
|
||||
loss_weights = future_observed_values
|
||||
else:
|
||||
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
|
||||
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
|
||||
|
||||
return weighted_average(loss_values, weights=loss_weights)
|
||||
|
||||
+2
-1
@@ -474,8 +474,9 @@ class TFTModel(nn.Module):
|
||||
|
||||
# embeddings
|
||||
embedded_cat = self.embedder(feat_static_cat)
|
||||
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
|
||||
static_feat = torch.cat(
|
||||
(feat_static_real, scale.log()),
|
||||
(feat_static_real, log_scale),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
|
||||
@@ -74,6 +74,6 @@ class TransformerLightningModule(pl.LightningModule):
|
||||
if len(self.model.target_shape) == 0:
|
||||
loss_weights = future_observed_values
|
||||
else:
|
||||
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
|
||||
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
|
||||
|
||||
return weighted_average(loss_values, weights=loss_weights)
|
||||
|
||||
@@ -93,7 +93,7 @@ class TransformerModel(nn.Module):
|
||||
sum(self.embedding_dimension)
|
||||
+ self.num_feat_dynamic_real
|
||||
+ self.num_feat_static_real
|
||||
+ 1 # the log(scale)
|
||||
+ self.input_size # the log(scale)
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -202,8 +202,9 @@ class TransformerModel(nn.Module):
|
||||
|
||||
# embeddings
|
||||
embedded_cat = self.embedder(feat_static_cat)
|
||||
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
|
||||
static_feat = torch.cat(
|
||||
(embedded_cat, feat_static_real, scale.log()),
|
||||
(embedded_cat, feat_static_real, log_scale),
|
||||
dim=1,
|
||||
)
|
||||
expanded_static_feat = static_feat.unsqueeze(1).expand(
|
||||
|
||||
@@ -238,7 +238,7 @@
|
||||
" sum(self.embedding_dimension)\n",
|
||||
" + self.num_feat_dynamic_real\n",
|
||||
" + self.num_feat_static_real\n",
|
||||
" + 1 # the log(scale)\n",
|
||||
" + self.input_size # the log(scale)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
@@ -335,8 +335,9 @@
|
||||
" \n",
|
||||
" # embeddings\n",
|
||||
" embedded_cat = self.embedder(feat_static_cat)\n",
|
||||
" log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()\n",
|
||||
" static_feat = torch.cat(\n",
|
||||
" (embedded_cat, feat_static_real, scale.log()),\n",
|
||||
" (embedded_cat, feat_static_real, log_scale),\n",
|
||||
" dim=1,\n",
|
||||
" )\n",
|
||||
" expanded_static_feat = static_feat.unsqueeze(1).expand(\n",
|
||||
@@ -549,7 +550,7 @@
|
||||
" if len(self.model.target_shape) == 0:\n",
|
||||
" loss_weights = future_observed_values\n",
|
||||
" else:\n",
|
||||
" loss_weights = future_observed_values.min(dim=-1, keepdim=False)\n",
|
||||
" loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)\n",
|
||||
"\n",
|
||||
" return weighted_average(loss_values, weights=loss_weights)"
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user