mirror of
https://github.com/wassname/attentive-neural-processes.git
synced 2026-06-27 18:03:39 +08:00
fixes to losses
This commit is contained in:
+122
-481
File diff suppressed because one or more lines are too long
+23
-14
@@ -18,6 +18,7 @@ class LatentModelPL(pl.LightningModule):
|
||||
self.hparams.update(hparams.__dict__ if hasattr(hparams, '__dict__') else hparams)
|
||||
self.model = LatentModel(**self.hparams)
|
||||
self._dfs = None
|
||||
self.train_logs = []
|
||||
|
||||
def forward(self, context_x, context_y, target_x, target_y):
|
||||
return self.model(context_x, context_y, target_x, target_y)
|
||||
@@ -27,15 +28,16 @@ class LatentModelPL(pl.LightningModule):
|
||||
context_x, context_y, target_x, target_y = batch
|
||||
y_pred, losses, extra = self.forward(context_x, context_y, target_x, target_y)
|
||||
y_std = extra['dist'].scale
|
||||
loss = losses['loss']
|
||||
|
||||
tensorboard_logs = {
|
||||
"train_loss": losses['loss'],
|
||||
"train_loss": loss,
|
||||
"train/kl": losses['loss_kl'].mean(),
|
||||
"train/std": losses['y_std'].mean(),
|
||||
"train/std": y_std.mean(),
|
||||
"train/mse": losses['loss_mse'].mean(),
|
||||
}
|
||||
assert torch.isfinite(loss)
|
||||
# print('device', next(self.model.parameters()).device)
|
||||
self.train_logs.append(tensorboard_logs)
|
||||
return {"loss": loss, "log": tensorboard_logs}
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
@@ -43,12 +45,13 @@ class LatentModelPL(pl.LightningModule):
|
||||
context_x, context_y, target_x, target_y = batch
|
||||
y_pred, losses, extra = self.forward(context_x, context_y, target_x, target_y)
|
||||
y_std = extra['dist'].scale
|
||||
loss = losses['loss']
|
||||
|
||||
tensorboard_logs = {
|
||||
"val_loss": losses['loss'], # This exact key is needed for metrics
|
||||
"val_loss": loss, # This exact key is needed for metrics
|
||||
"val/kl": losses['loss_kl'].mean(),
|
||||
"val/mse": losses['loss_mse'].mean(),
|
||||
"val/std": losses['y_std'].mean(),
|
||||
"val/std": y_std.mean(),
|
||||
}
|
||||
return {"val_loss": loss, "log": tensorboard_logs}
|
||||
|
||||
@@ -63,7 +66,12 @@ class LatentModelPL(pl.LightningModule):
|
||||
self.show_image()
|
||||
logs = self.agg_logs(outputs)
|
||||
tensorboard_logs_str = {k: f'{v}' for k, v in logs["log"].items()}
|
||||
print(f"step val {self.trainer.global_step}, {tensorboard_logs_str}")
|
||||
|
||||
# agg and print self.train_logs HACK https://github.com/PyTorchLightning/pytorch-lightning/issues/100
|
||||
train_logs = self.agg_logs(self.train_logs)
|
||||
train_logs_str = {k: f"{v}" for k, v in train_logs.items()}
|
||||
self.train_logs = []
|
||||
print(f"step val {self.trainer.global_step}, {tensorboard_logs_str} {train_logs}")
|
||||
return logs
|
||||
|
||||
def show_image(self):
|
||||
@@ -82,14 +90,15 @@ class LatentModelPL(pl.LightningModule):
|
||||
if isinstance(outputs, dict):
|
||||
outputs = [outputs]
|
||||
aggs = {}
|
||||
for j in outputs[0]:
|
||||
if isinstance(outputs[0][j], dict):
|
||||
# Take mean of sub dicts
|
||||
keys = outputs[0][j].keys()
|
||||
aggs[j] = {k: torch.stack([x[j][k] for x in outputs if k in x[j]]).mean() for k in keys}
|
||||
else:
|
||||
# Take mean of numbers
|
||||
aggs[j] = torch.stack([x[j] for x in outputs if j in x]).mean()
|
||||
if len(outputs)>0:
|
||||
for j in outputs[0]:
|
||||
if isinstance(outputs[0][j], dict):
|
||||
# Take mean of sub dicts
|
||||
keys = outputs[0][j].keys()
|
||||
aggs[j] = {k: torch.stack([x[j][k] for x in outputs if k in x[j]]).mean() for k in keys}
|
||||
else:
|
||||
# Take mean of numbers
|
||||
aggs[j] = torch.stack([x[j] for x in outputs if j in x]).mean()
|
||||
return aggs
|
||||
|
||||
# # Log hparams with metric, doesn't work
|
||||
|
||||
+13
-10
@@ -57,12 +57,14 @@ class LatentModel(nn.Module):
|
||||
use_lstm_le=False, # use another LSTM in latent encoder instead of MLP
|
||||
use_lstm_de=False, # use another LSTM in determinstic encoder instead of MLP
|
||||
use_lstm_d=False, # use another lstm in decoder instead of MLP
|
||||
context_in_target=False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super(LatentModel, self).__init__()
|
||||
|
||||
self._use_rnn = use_rnn
|
||||
self.context_in_target = context_in_target
|
||||
|
||||
if self._use_rnn:
|
||||
self._lstm = nn.LSTM(
|
||||
@@ -152,25 +154,26 @@ class LatentModel(nn.Module):
|
||||
|
||||
if self._use_lvar:
|
||||
log_p = log_prob_sigma(target_y, dist.loc, log_sigma).mean(-1) # [B, T_target, Y].mean(-1)
|
||||
if self.hparams["context_in_target"]:
|
||||
if self.context_in_target:
|
||||
log_p[:, :context_x.size(1)] /= 100
|
||||
kl_loss = kl_loss_var(dist_prior.loc, log_var_prior,
|
||||
loss_kl = kl_loss_var(dist_prior.loc, log_var_prior,
|
||||
dist_post.loc, log_var_post).mean(-1) # [B, R].mean(-1)
|
||||
else:
|
||||
log_p = dist.log_prob(target_y).mean(-1)
|
||||
if self.hparams["context_in_target"]:
|
||||
if self.context_in_target:
|
||||
log_p[:, :context_x.size(1)] /= 100 # There's the temptation for it to fit only on context, where it knows the answer, and learn very low uncertainty.
|
||||
kl_loss = torch.distributions.kl_divergence(
|
||||
loss_kl = torch.distributions.kl_divergence(
|
||||
dist_post, dist_prior).mean(-1) # [B, R].mean(-1)
|
||||
kl_loss = kl_loss[:, None].expand(log_p.shape)
|
||||
mse_loss = F.mse_loss(dist.loc, target_y, reduce=None)[:, :context_x.size(1)].mean()
|
||||
loss = (kl_loss - log_p).mean()
|
||||
loss_kl = loss_kl[:, None].expand(log_p.shape)
|
||||
mse_loss = F.mse_loss(dist.loc, target_y, reduction='none')[:,:context_x.size(1)].mean()
|
||||
loss_p = -log_p.mean()
|
||||
loss = (loss_kl - log_p).mean()
|
||||
|
||||
else:
|
||||
log_p = None
|
||||
loss_p = None
|
||||
mse_loss = None
|
||||
kl_loss = None
|
||||
loss_kl = None
|
||||
loss = None
|
||||
|
||||
y_pred = dist.rsample() if self.training else dist.loc
|
||||
return y_pred, dict(loss=loss, loss_p=loss_p.mean(), loss_kl=loss_kl, loss_mse=mse_loss.mean()), dict(log_sigma=log_sigma, dist=dist)
|
||||
return y_pred, dict(loss=loss, loss_p=loss_p, loss_kl=loss_kl, loss_mse=mse_loss), dict(log_sigma=log_sigma, dist=dist)
|
||||
|
||||
Reference in New Issue
Block a user