mirror of
https://github.com/wassname/attentive-neural-processes.git
synced 2026-06-27 18:03:39 +08:00
fixes
This commit is contained in:
+166
-9933
File diff suppressed because one or more lines are too long
+220
-279
File diff suppressed because one or more lines are too long
+6648
-5232
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -25,18 +25,18 @@ class LatentModelPL(pl.LightningModule):
|
||||
def training_step(self, batch, batch_idx):
|
||||
assert all(torch.isfinite(d).all() for d in batch)
|
||||
context_x, context_y, target_x, target_y = batch
|
||||
y_pred, losses, extra = = self.forward(context_x, context_y, target_x, target_y)
|
||||
y_pred, losses, extra = self.forward(context_x, context_y, target_x, target_y)
|
||||
y_std = extra['dist'].scale
|
||||
|
||||
tensorboard_logs = {
|
||||
"train_loss": losses['loss'],
|
||||
"train/kl": losses['loss_kl'].mean(),
|
||||
"train/std": losses['y_std'].mean(),
|
||||
"train/std": y_std.mean(),
|
||||
"train/mse": losses['loss_mse'].mean(),
|
||||
}
|
||||
assert torch.isfinite(loss)
|
||||
# print('device', next(self.model.parameters()).device)
|
||||
return {"loss": loss, "log": tensorboard_logs}
|
||||
return {"loss": losses['loss'], "log": tensorboard_logs}
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
assert all(torch.isfinite(d).all() for d in batch)
|
||||
@@ -48,9 +48,9 @@ class LatentModelPL(pl.LightningModule):
|
||||
"val_loss": losses['loss'], # This exact key is needed for metrics
|
||||
"val/kl": losses['loss_kl'].mean(),
|
||||
"val/mse": losses['loss_mse'].mean(),
|
||||
"val/std": losses['y_std'].mean(),
|
||||
"val/std": y_std.mean(),
|
||||
}
|
||||
return {"val_loss": loss, "log": tensorboard_logs}
|
||||
return {"val_loss": losses['loss'], "log": tensorboard_logs}
|
||||
|
||||
# def training_end(self, outputs):
|
||||
# logs = self.agg_logs(outputs)
|
||||
|
||||
+7
-5
@@ -56,7 +56,8 @@ class LatentModel(nn.Module):
|
||||
use_rnn=True, # use RNN/LSTM?
|
||||
use_lstm_le=False, # use another LSTM in latent encoder instead of MLP
|
||||
use_lstm_de=False, # use another LSTM in determinstic encoder instead of MLP
|
||||
use_lstm_d=False, # use another lstm in decoder instead of MLP
|
||||
use_lstm_d=False, # use another lstm in decoder instead of MLP
|
||||
context_in_target=True,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
@@ -121,6 +122,7 @@ class LatentModel(nn.Module):
|
||||
)
|
||||
self._use_deterministic_path = use_deterministic_path
|
||||
self._use_lvar = use_lvar
|
||||
self.context_in_target = context_in_target
|
||||
|
||||
def forward(self, context_x, context_y, target_x, target_y=None):
|
||||
|
||||
@@ -152,18 +154,18 @@ class LatentModel(nn.Module):
|
||||
|
||||
if self._use_lvar:
|
||||
log_p = log_prob_sigma(target_y, dist.loc, log_sigma).mean(-1) # [B, T_target, Y].mean(-1)
|
||||
if self.hparams["context_in_target"]:
|
||||
if self.context_in_target:
|
||||
log_p[:, :context_x.size(1)] /= 100
|
||||
kl_loss = kl_loss_var(dist_prior.loc, log_var_prior,
|
||||
dist_post.loc, log_var_post).mean(-1) # [B, R].mean(-1)
|
||||
else:
|
||||
log_p = dist.log_prob(target_y).mean(-1)
|
||||
if self.hparams["context_in_target"]:
|
||||
if self.context_in_target:
|
||||
log_p[:, :context_x.size(1)] /= 100 # There's the temptation for it to fit only on context, where it knows the answer, and learn very low uncertainty.
|
||||
kl_loss = torch.distributions.kl_divergence(
|
||||
dist_post, dist_prior).mean(-1) # [B, R].mean(-1)
|
||||
kl_loss = kl_loss[:, None].expand(log_p.shape)
|
||||
mse_loss = F.mse_loss(dist.loc, target_y, reduce=None)[:, :context_x.size(1)].mean()
|
||||
mse_loss = F.mse_loss(dist.loc, target_y, reduction='none')[:, :context_x.size(1)].mean()
|
||||
loss = (kl_loss - log_p).mean()
|
||||
|
||||
else:
|
||||
@@ -173,4 +175,4 @@ class LatentModel(nn.Module):
|
||||
loss = None
|
||||
|
||||
y_pred = dist.rsample() if self.training else dist.loc
|
||||
return y_pred, dict(loss=loss, loss_p=loss_p.mean(), loss_kl=loss_kl, loss_mse=mse_loss.mean()), dict(log_sigma=log_sigma, dist=dist)
|
||||
return y_pred, dict(loss=loss, loss_p=-log_p.mean(), loss_kl=kl_loss, loss_mse=mse_loss.mean()), dict(log_sigma=log_sigma, dist=dist)
|
||||
|
||||
Reference in New Issue
Block a user