diff --git a/informer/lightning_module.py b/informer/lightning_module.py index dcbc667..b5b3a69 100644 --- a/informer/lightning_module.py +++ b/informer/lightning_module.py @@ -75,6 +75,6 @@ class InformerLightningModule(pl.LightningModule): if len(self.model.target_shape) == 0: loss_weights = future_observed_values else: - loss_weights = future_observed_values.min(dim=-1, keepdim=False) + loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False) return weighted_average(loss_values, weights=loss_weights) diff --git a/informer/module.py b/informer/module.py index 4890d64..d83f502 100644 --- a/informer/module.py +++ b/informer/module.py @@ -482,7 +482,7 @@ class InformerModel(nn.Module): sum(self.embedding_dimension) + self.num_feat_dynamic_real + self.num_feat_static_real - + 1 # the log(scale) + + self.input_size # the log(scale) ) @property @@ -591,8 +591,9 @@ class InformerModel(nn.Module): # embeddings embedded_cat = self.embedder(feat_static_cat) + log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log() static_feat = torch.cat( - (embedded_cat, feat_static_real, scale.log()), + (embedded_cat, feat_static_real, log_scale), dim=1, ) expanded_static_feat = static_feat.unsqueeze(1).expand( diff --git a/reformer/lightning_module.py b/reformer/lightning_module.py index 63f82b5..ed147d6 100644 --- a/reformer/lightning_module.py +++ b/reformer/lightning_module.py @@ -74,6 +74,6 @@ class ReformerLightningModule(pl.LightningModule): if len(self.model.target_shape) == 0: loss_weights = future_observed_values else: - loss_weights = future_observed_values.min(dim=-1, keepdim=False) + loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False) return weighted_average(loss_values, weights=loss_weights) diff --git a/reformer/module.py b/reformer/module.py index dbad019..19116d1 100644 --- a/reformer/module.py +++ b/reformer/module.py @@ -157,7 +157,7 @@ class ReformerModel(nn.Module): sum(self.embedding_dimension) + self.num_feat_dynamic_real + self.num_feat_static_real - + 1 # the log(scale) + + self.input_size # the log(scale) ) @property @@ -266,8 +266,9 @@ class ReformerModel(nn.Module): # embeddings embedded_cat = self.embedder(feat_static_cat) + log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log() static_feat = torch.cat( - (embedded_cat, feat_static_real, scale.log()), + (embedded_cat, feat_static_real, log_scale), dim=1, ) expanded_static_feat = static_feat.unsqueeze(1).expand( diff --git a/s4/s4.ipynb b/s4/s4.ipynb index 77c6e5d..91b5a8c 100644 --- a/s4/s4.ipynb +++ b/s4/s4.ipynb @@ -139,6 +139,7 @@ " num_parallel_samples: int = 100,\n", " ) -> None:\n", " super().__init__()\n", + " self.input_size = input_size\n", " self.context_length = context_length\n", " self.prediction_length = prediction_length\n", " self.distr_output = distr_output\n", @@ -190,7 +191,7 @@ " sum(self.embedding_dimension)\n", " + self.num_feat_dynamic_real\n", " + self.num_feat_static_real\n", - " + 1 # the log(scale)\n", + " + self.input_size # the log(scale)\n", " )\n", " \n", " @property\n", @@ -233,8 +234,9 @@ " assert inputs.shape[1] == unroll_length\n", "\n", " embedded_cat = self.embedder(feat_static_cat)\n", + " log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()\n", " static_feat = torch.cat(\n", - " (embedded_cat, feat_static_real, scale.log()),\n", + " (embedded_cat, feat_static_real, log_scale),\n", " dim=1,\n", " )\n", " expanded_static_feat = static_feat.unsqueeze(1).expand(\n", diff --git a/switch/lightning_module.py b/switch/lightning_module.py index 211db08..293c5f2 100644 --- a/switch/lightning_module.py +++ b/switch/lightning_module.py @@ -74,6 +74,6 @@ class SwitchTransformerLightningModule(pl.LightningModule): if len(self.model.target_shape) == 0: loss_weights = future_observed_values else: - loss_weights = future_observed_values.min(dim=-1, keepdim=False) + loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False) return weighted_average(loss_values, weights=loss_weights) diff --git a/switch/module.py b/switch/module.py index fc58dde..edb4ab3 100644 --- a/switch/module.py +++ b/switch/module.py @@ -345,7 +345,7 @@ class SwitchTransformerModel(nn.Module): sum(self.embedding_dimension) + self.num_feat_dynamic_real + self.num_feat_static_real - + 1 # the log(scale) + + self.input_size # the log(scale) ) @property @@ -454,8 +454,9 @@ class SwitchTransformerModel(nn.Module): # embeddings embedded_cat = self.embedder(feat_static_cat) + log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log() static_feat = torch.cat( - (embedded_cat, feat_static_real, scale.log()), + (embedded_cat, feat_static_real, log_scale), dim=1, ) expanded_static_feat = static_feat.unsqueeze(1).expand( diff --git a/tft/lightning_module.py b/tft/lightning_module.py index 747c300..90663af 100644 --- a/tft/lightning_module.py +++ b/tft/lightning_module.py @@ -86,6 +86,6 @@ class TFTLightningModule(pl.LightningModule): if len(self.model.target_shape) == 0: loss_weights = future_observed_values else: - loss_weights = future_observed_values.min(dim=-1, keepdim=False) + loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False) return weighted_average(loss_values, weights=loss_weights) diff --git a/tft/module.py b/tft/module.py index cfbc111..afe9543 100644 --- a/tft/module.py +++ b/tft/module.py @@ -474,8 +474,9 @@ class TFTModel(nn.Module): # embeddings embedded_cat = self.embedder(feat_static_cat) + log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log() static_feat = torch.cat( - (feat_static_real, scale.log()), + (feat_static_real, log_scale), dim=1, ) diff --git a/transformer/lightning_module.py b/transformer/lightning_module.py index afaee88..68c28f3 100644 --- a/transformer/lightning_module.py +++ b/transformer/lightning_module.py @@ -74,6 +74,6 @@ class TransformerLightningModule(pl.LightningModule): if len(self.model.target_shape) == 0: loss_weights = future_observed_values else: - loss_weights = future_observed_values.min(dim=-1, keepdim=False) + loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False) return weighted_average(loss_values, weights=loss_weights) diff --git a/transformer/module.py b/transformer/module.py index a5cb972..3568dc5 100644 --- a/transformer/module.py +++ b/transformer/module.py @@ -93,7 +93,7 @@ class TransformerModel(nn.Module): sum(self.embedding_dimension) + self.num_feat_dynamic_real + self.num_feat_static_real - + 1 # the log(scale) + + self.input_size # the log(scale) ) @property @@ -202,8 +202,9 @@ class TransformerModel(nn.Module): # embeddings embedded_cat = self.embedder(feat_static_cat) + log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log() static_feat = torch.cat( - (embedded_cat, feat_static_real, scale.log()), + (embedded_cat, feat_static_real, log_scale), dim=1, ) expanded_static_feat = static_feat.unsqueeze(1).expand( diff --git a/xformers/xformers.ipynb b/xformers/xformers.ipynb index 31e8256..a62e4f7 100644 --- a/xformers/xformers.ipynb +++ b/xformers/xformers.ipynb @@ -238,7 +238,7 @@ " sum(self.embedding_dimension)\n", " + self.num_feat_dynamic_real\n", " + self.num_feat_static_real\n", - " + 1 # the log(scale)\n", + " + self.input_size # the log(scale)\n", " )\n", "\n", " @property\n", @@ -335,8 +335,9 @@ " \n", " # embeddings\n", " embedded_cat = self.embedder(feat_static_cat)\n", + " log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()\n", " static_feat = torch.cat(\n", - " (embedded_cat, feat_static_real, scale.log()),\n", + " (embedded_cat, feat_static_real, log_scale),\n", " dim=1,\n", " )\n", " expanded_static_feat = static_feat.unsqueeze(1).expand(\n", @@ -549,7 +550,7 @@ " if len(self.model.target_shape) == 0:\n", " loss_weights = future_observed_values\n", " else:\n", - " loss_weights = future_observed_values.min(dim=-1, keepdim=False)\n", + " loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)\n", "\n", " return weighted_average(loss_values, weights=loss_weights)" ]