fixes to losses

This commit is contained in:
wassname
2020-03-14 09:41:46 +08:00
parent a43975bbce
commit a468c499b3
3 changed files with 158 additions and 505 deletions
File diff suppressed because one or more lines are too long
+23 -14
View File
@@ -18,6 +18,7 @@ class LatentModelPL(pl.LightningModule):
self.hparams.update(hparams.__dict__ if hasattr(hparams, '__dict__') else hparams)
self.model = LatentModel(**self.hparams)
self._dfs = None
self.train_logs = []
def forward(self, context_x, context_y, target_x, target_y):
return self.model(context_x, context_y, target_x, target_y)
@@ -27,15 +28,16 @@ class LatentModelPL(pl.LightningModule):
context_x, context_y, target_x, target_y = batch
y_pred, losses, extra = self.forward(context_x, context_y, target_x, target_y)
y_std = extra['dist'].scale
loss = losses['loss']
tensorboard_logs = {
"train_loss": losses['loss'],
"train_loss": loss,
"train/kl": losses['loss_kl'].mean(),
"train/std": losses['y_std'].mean(),
"train/std": y_std.mean(),
"train/mse": losses['loss_mse'].mean(),
}
assert torch.isfinite(loss)
# print('device', next(self.model.parameters()).device)
self.train_logs.append(tensorboard_logs)
return {"loss": loss, "log": tensorboard_logs}
def validation_step(self, batch, batch_idx):
@@ -43,12 +45,13 @@ class LatentModelPL(pl.LightningModule):
context_x, context_y, target_x, target_y = batch
y_pred, losses, extra = self.forward(context_x, context_y, target_x, target_y)
y_std = extra['dist'].scale
loss = losses['loss']
tensorboard_logs = {
"val_loss": losses['loss'], # This exact key is needed for metrics
"val_loss": loss, # This exact key is needed for metrics
"val/kl": losses['loss_kl'].mean(),
"val/mse": losses['loss_mse'].mean(),
"val/std": losses['y_std'].mean(),
"val/std": y_std.mean(),
}
return {"val_loss": loss, "log": tensorboard_logs}
@@ -63,7 +66,12 @@ class LatentModelPL(pl.LightningModule):
self.show_image()
logs = self.agg_logs(outputs)
tensorboard_logs_str = {k: f'{v}' for k, v in logs["log"].items()}
print(f"step val {self.trainer.global_step}, {tensorboard_logs_str}")
# agg and print self.train_logs HACK https://github.com/PyTorchLightning/pytorch-lightning/issues/100
train_logs = self.agg_logs(self.train_logs)
train_logs_str = {k: f"{v}" for k, v in train_logs.items()}
self.train_logs = []
print(f"step val {self.trainer.global_step}, {tensorboard_logs_str} {train_logs}")
return logs
def show_image(self):
@@ -82,14 +90,15 @@ class LatentModelPL(pl.LightningModule):
if isinstance(outputs, dict):
outputs = [outputs]
aggs = {}
for j in outputs[0]:
if isinstance(outputs[0][j], dict):
# Take mean of sub dicts
keys = outputs[0][j].keys()
aggs[j] = {k: torch.stack([x[j][k] for x in outputs if k in x[j]]).mean() for k in keys}
else:
# Take mean of numbers
aggs[j] = torch.stack([x[j] for x in outputs if j in x]).mean()
if len(outputs)>0:
for j in outputs[0]:
if isinstance(outputs[0][j], dict):
# Take mean of sub dicts
keys = outputs[0][j].keys()
aggs[j] = {k: torch.stack([x[j][k] for x in outputs if k in x[j]]).mean() for k in keys}
else:
# Take mean of numbers
aggs[j] = torch.stack([x[j] for x in outputs if j in x]).mean()
return aggs
# # Log hparams with metric, doesn't work
+13 -10
View File
@@ -57,12 +57,14 @@ class LatentModel(nn.Module):
use_lstm_le=False, # use another LSTM in latent encoder instead of MLP
use_lstm_de=False, # use another LSTM in determinstic encoder instead of MLP
use_lstm_d=False, # use another lstm in decoder instead of MLP
context_in_target=False,
**kwargs,
):
super(LatentModel, self).__init__()
self._use_rnn = use_rnn
self.context_in_target = context_in_target
if self._use_rnn:
self._lstm = nn.LSTM(
@@ -152,25 +154,26 @@ class LatentModel(nn.Module):
if self._use_lvar:
log_p = log_prob_sigma(target_y, dist.loc, log_sigma).mean(-1) # [B, T_target, Y].mean(-1)
if self.hparams["context_in_target"]:
if self.context_in_target:
log_p[:, :context_x.size(1)] /= 100
kl_loss = kl_loss_var(dist_prior.loc, log_var_prior,
loss_kl = kl_loss_var(dist_prior.loc, log_var_prior,
dist_post.loc, log_var_post).mean(-1) # [B, R].mean(-1)
else:
log_p = dist.log_prob(target_y).mean(-1)
if self.hparams["context_in_target"]:
if self.context_in_target:
log_p[:, :context_x.size(1)] /= 100 # There's the temptation for it to fit only on context, where it knows the answer, and learn very low uncertainty.
kl_loss = torch.distributions.kl_divergence(
loss_kl = torch.distributions.kl_divergence(
dist_post, dist_prior).mean(-1) # [B, R].mean(-1)
kl_loss = kl_loss[:, None].expand(log_p.shape)
mse_loss = F.mse_loss(dist.loc, target_y, reduce=None)[:, :context_x.size(1)].mean()
loss = (kl_loss - log_p).mean()
loss_kl = loss_kl[:, None].expand(log_p.shape)
mse_loss = F.mse_loss(dist.loc, target_y, reduction='none')[:,:context_x.size(1)].mean()
loss_p = -log_p.mean()
loss = (loss_kl - log_p).mean()
else:
log_p = None
loss_p = None
mse_loss = None
kl_loss = None
loss_kl = None
loss = None
y_pred = dist.rsample() if self.training else dist.loc
return y_pred, dict(loss=loss, loss_p=loss_p.mean(), loss_kl=loss_kl, loss_mse=mse_loss.mean()), dict(log_sigma=log_sigma, dist=dist)
return y_pred, dict(loss=loss, loss_p=loss_p, loss_kl=loss_kl, loss_mse=mse_loss), dict(log_sigma=log_sigma, dist=dist)