mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:46:28 +08:00
fix for input_size>1
This commit is contained in:
@@ -79,6 +79,6 @@ class AutoformerLightningModule(pl.LightningModule):
|
|||||||
if len(self.model.target_shape) == 0:
|
if len(self.model.target_shape) == 0:
|
||||||
loss_weights = future_observed_values
|
loss_weights = future_observed_values
|
||||||
else:
|
else:
|
||||||
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
|
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
|
||||||
|
|
||||||
return weighted_average(loss_values, weights=loss_weights)
|
return weighted_average(loss_values, weights=loss_weights)
|
||||||
|
|||||||
@@ -587,7 +587,7 @@ class AutoformerModel(nn.Module):
|
|||||||
sum(self.embedding_dimension)
|
sum(self.embedding_dimension)
|
||||||
+ self.num_feat_dynamic_real
|
+ self.num_feat_dynamic_real
|
||||||
+ self.num_feat_static_real
|
+ self.num_feat_static_real
|
||||||
+ 1 # the log(scale)
|
+ self.input_size # the log(scale)
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -696,8 +696,9 @@ class AutoformerModel(nn.Module):
|
|||||||
|
|
||||||
# embeddings
|
# embeddings
|
||||||
embedded_cat = self.embedder(feat_static_cat)
|
embedded_cat = self.embedder(feat_static_cat)
|
||||||
|
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
|
||||||
static_feat = torch.cat(
|
static_feat = torch.cat(
|
||||||
(embedded_cat, feat_static_real, scale.log()),
|
(embedded_cat, feat_static_real, log_scale),
|
||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
expanded_static_feat = static_feat.unsqueeze(1).expand(
|
expanded_static_feat = static_feat.unsqueeze(1).expand(
|
||||||
|
|||||||
Reference in New Issue
Block a user