This commit is contained in:
wassname
2020-03-01 12:23:48 +08:00
parent 45117ff491
commit e91377ed96
6 changed files with 7537 additions and 15458 deletions
+166 -9933
View File
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+6648 -5232
View File
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+5 -5
View File
@@ -25,18 +25,18 @@ class LatentModelPL(pl.LightningModule):
def training_step(self, batch, batch_idx):
assert all(torch.isfinite(d).all() for d in batch)
context_x, context_y, target_x, target_y = batch
y_pred, losses, extra = = self.forward(context_x, context_y, target_x, target_y)
y_pred, losses, extra = self.forward(context_x, context_y, target_x, target_y)
y_std = extra['dist'].scale
tensorboard_logs = {
"train_loss": losses['loss'],
"train/kl": losses['loss_kl'].mean(),
"train/std": losses['y_std'].mean(),
"train/std": y_std.mean(),
"train/mse": losses['loss_mse'].mean(),
}
assert torch.isfinite(loss)
# print('device', next(self.model.parameters()).device)
return {"loss": loss, "log": tensorboard_logs}
return {"loss": losses['loss'], "log": tensorboard_logs}
def validation_step(self, batch, batch_idx):
assert all(torch.isfinite(d).all() for d in batch)
@@ -48,9 +48,9 @@ class LatentModelPL(pl.LightningModule):
"val_loss": losses['loss'], # This exact key is needed for metrics
"val/kl": losses['loss_kl'].mean(),
"val/mse": losses['loss_mse'].mean(),
"val/std": losses['y_std'].mean(),
"val/std": y_std.mean(),
}
return {"val_loss": loss, "log": tensorboard_logs}
return {"val_loss": losses['loss'], "log": tensorboard_logs}
# def training_end(self, outputs):
# logs = self.agg_logs(outputs)
+7 -5
View File
@@ -56,7 +56,8 @@ class LatentModel(nn.Module):
use_rnn=True, # use RNN/LSTM?
use_lstm_le=False, # use another LSTM in latent encoder instead of MLP
use_lstm_de=False, # use another LSTM in determinstic encoder instead of MLP
use_lstm_d=False, # use another lstm in decoder instead of MLP
use_lstm_d=False, # use another lstm in decoder instead of MLP
context_in_target=True,
**kwargs,
):
@@ -121,6 +122,7 @@ class LatentModel(nn.Module):
)
self._use_deterministic_path = use_deterministic_path
self._use_lvar = use_lvar
self.context_in_target = context_in_target
def forward(self, context_x, context_y, target_x, target_y=None):
@@ -152,18 +154,18 @@ class LatentModel(nn.Module):
if self._use_lvar:
log_p = log_prob_sigma(target_y, dist.loc, log_sigma).mean(-1) # [B, T_target, Y].mean(-1)
if self.hparams["context_in_target"]:
if self.context_in_target:
log_p[:, :context_x.size(1)] /= 100
kl_loss = kl_loss_var(dist_prior.loc, log_var_prior,
dist_post.loc, log_var_post).mean(-1) # [B, R].mean(-1)
else:
log_p = dist.log_prob(target_y).mean(-1)
if self.hparams["context_in_target"]:
if self.context_in_target:
log_p[:, :context_x.size(1)] /= 100 # There's the temptation for it to fit only on context, where it knows the answer, and learn very low uncertainty.
kl_loss = torch.distributions.kl_divergence(
dist_post, dist_prior).mean(-1) # [B, R].mean(-1)
kl_loss = kl_loss[:, None].expand(log_p.shape)
mse_loss = F.mse_loss(dist.loc, target_y, reduce=None)[:, :context_x.size(1)].mean()
mse_loss = F.mse_loss(dist.loc, target_y, reduction='none')[:, :context_x.size(1)].mean()
loss = (kl_loss - log_p).mean()
else:
@@ -173,4 +175,4 @@ class LatentModel(nn.Module):
loss = None
y_pred = dist.rsample() if self.training else dist.loc
return y_pred, dict(loss=loss, loss_p=loss_p.mean(), loss_kl=loss_kl, loss_mse=mse_loss.mean()), dict(log_sigma=log_sigma, dist=dist)
return y_pred, dict(loss=loss, loss_p=-log_p.mean(), loss_kl=kl_loss, loss_mse=mse_loss.mean()), dict(log_sigma=log_sigma, dist=dist)