This commit is contained in:
wassname
2020-04-11 18:56:20 +08:00
parent 7ff5df3942
commit c906140440
5 changed files with 261 additions and 579 deletions
+2
View File
@@ -143,6 +143,7 @@ class LSTMSeq2Seq_PL(PL_Seq2Seq):
"learning_rate": 0.001,
"lstm_layers": 4,
'bidirectional': False
}
@staticmethod
@@ -165,5 +166,6 @@ class LSTMSeq2Seq_PL(PL_Seq2Seq):
"input_size_decoder": 17,
"context_in_target": False,
"output_size": 1,
'min_std': 0.005,
}
return trial
+60 -19
View File
@@ -36,7 +36,7 @@ class LSTMNet(nn.Module):
self._min_std = _min_std
self.lstm1 = nn.LSTM(
x_dim=self.hparams.x_dim,
input_size=self.hparams.x_dim+self.hparams.y_dim,
hidden_size=self.hparams.hidden_size,
batch_first=True,
num_layers=self.hparams.lstm_layers,
@@ -48,29 +48,70 @@ class LSTMNet(nn.Module):
)
self.mean = nn.Linear(self.hidden_out_size, 1)
self.std = nn.Linear(self.hidden_out_size, 1)
self._use_lvar = 0
def forward(self, context_x, context_y, target_x, target_y=None):
loss_scale = 1
device = next(self.parameters()).device
x = torch.cat([context_x, context_y], -1).detach()
target_y_fake = (
torch.ones(context_y.shape[0], target_x.shape[1], context_y.shape[2]).float().to(device) * self.hparams.nan_value
)
loss_scale = 1
context = torch.cat([context_x, context_y], -1).detach()
target = torch.cat([target_x, target_y_fake], -1).detach()
x = torch.cat([context, target * 1], 1).detach()
outputs, (h_out, _) = self.lstm1(x)
# outputs: [B, T, num_direction * H]
y_pred = self.mean(outputs).squeeze(2)
log_sigma = self.std(outputs).squeeze(2)
loss = None
if target_y is not None:
loss = (
F.mse_loss(
y_pred * loss_scale, y[:, -steps:, :] * loss_scale, reduction="none"
)
/ loss_scale
steps = context_y.shape[1]
mean = self.mean(outputs)[:, steps:, :]#.squeeze(2)
log_sigma = self.std(outputs)[:, steps:,:] #.squeeze(2)
if self._use_lvar:
log_sigma = torch.clamp(
log_sigma, math.log(self._min_std), -math.log(self._min_std)
)
sigma = torch.exp(log_sigma)
else:
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
y_dist = torch.distributions.Normal(mean, sigma)
assert torch.isfinite(loss)
# Loss
loss_mse = loss_p = None
if target_y is not None:
loss_mse = F.mse_loss(mean, target_y, reduction="none")
if self._use_lvar:
loss_p = -log_prob_sigma(target_y, mean, log_sigma)
else:
loss_p = -y_dist.log_prob(target_y).mean(-1)
if self.hparams["context_in_target"]:
loss_p[: context_x.size(1)] /= 100
loss_mse[: context_x.size(1)] /= 100
# # Don't catch loss on context window
# mean = mean[:, self.hparams.num_context:]
# log_sigma = log_sigma[:, self.hparams.num_context:]
return y_pred, dict(loss=loss), dict()
# Weight loss nearer to prediction time?
weight = (torch.arange(loss_p.shape[1]) + 1).float().to(device)[None, :]
loss_p = loss_p / torch.sqrt(weight) # We want to weight nearer stuff more
y_pred = y_dist.rsample if self.training else y_dist.loc
return (
y_pred,
dict(loss_p=loss_p.mean(), loss_mse=loss_mse.mean()),
dict(log_sigma=log_sigma, dist=y_dist),
)
# loss = None
# if target_y is not None:
# loss = (
# F.mse_loss(
# y_pred * loss_scale, y[:, -steps:, :] * loss_scale, reduction="none"
# )
# / loss_scale
# )
# assert torch.isfinite(loss)
# return y_pred, dict(loss=loss), dict()
class LSTM_PL_STD(PL_Seq2Seq):
@@ -79,7 +120,7 @@ class LSTM_PL_STD(PL_Seq2Seq):
DEFAULT_ARGS = {
"bidirectional": False,
"hidden_size_power": 4,
"hidden_size_power": 5,
"learning_rate": 0.001,
"lstm_dropout": 0.39,
"lstm_layers": 4,
@@ -100,12 +141,12 @@ class LSTM_PL_STD(PL_Seq2Seq):
"grad_clip": 40,
"max_nb_epochs": 200,
"num_workers": 4,
# "num_extra_target": 24 * 4,
"vis_i": "670",
# "num_context": 24 * 4,
"x_dim": 18,
"y_dim": 1,
"context_in_target": False,
# "output_size": 1,
"patience": 3,
'min_std': 0.005,
'nan_value': -99.9
}
return trial
+102 -60
View File
@@ -18,6 +18,7 @@ class NetTransformer(nn.Module):
super().__init__()
hparams = hparams_power(hparams)
self.hparams = hparams
self._min_std = hparams.min_std
hidden_out_size = self.hparams.hidden_out_size
enc_x_dim = self.hparams.x_dim + self.hparams.y_dim
@@ -50,11 +51,13 @@ class NetTransformer(nn.Module):
# norm=decoder_norm
# )
self.mean = nn.Linear(hidden_out_size, self.hparams.y_dim)
self.std = nn.Linear(hidden_out_size, self.hparams.y_dim)
self._use_lvar = 0
def forward(self, context_x, context_y, target_x, target_y=None):
device = next(self.parameters()).device
target_y_fake = (
torch.ones(context_y.shape).float().to(device) * self.hparams.nan_value
torch.ones(context_y.shape[0], target_x.shape[1], context_y.shape[2]).float().to(device) * self.hparams.nan_value
)
context = torch.cat([context_x, context_y], -1).detach()
target = torch.cat([target_x, target_y_fake], -1).detach()
@@ -76,69 +79,107 @@ class NetTransformer(nn.Module):
# print(outputs.shape, 'outputs')
# Seems to help a little, especially with extrapolating out of bounds
steps = target_y.shape[1]
mean = self.mean(outputs)
mean_target = mean[:, -steps:, :]
mean_context = mean[:, :-steps, :]
steps = context_y.shape[1]
mean = self.mean(outputs)[:, steps:, :]
log_sigma = self.std(outputs)[:, steps:, :]
# mean_target = mean[:, -steps:, :]
# mean_context = mean[:, :-steps, :]
loss = None
if self._use_lvar:
log_sigma = torch.clamp(
log_sigma, math.log(self._min_std), -math.log(self._min_std)
)
sigma = torch.exp(log_sigma)
else:
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
y_dist = torch.distributions.Normal(mean, sigma)
# Loss
loss_mse = loss_p = None
if target_y is not None:
y = torch.cat([context_y, target_y], 1)
y_mask = torch.isfinite(y) & (y != self.hparams.nan_value)
y[~y_mask] = 0
y = y.detach()
loss_mse = F.mse_loss(mean, target_y, reduction="none")
if self._use_lvar:
loss_p = -log_prob_sigma(target_y, mean, log_sigma)
else:
loss_p = -y_dist.log_prob(target_y).mean(-1)
if self.hparams["context_in_target"]:
loss_p[: context_x.size(1)] /= 100
loss_mse[: context_x.size(1)] /= 100
# # Don't catch loss on context window
# mean = mean[:, self.hparams.num_context:]
# log_sigma = log_sigma[:, self.hparams.num_context:]
loss_scale = 100
# loss = F.mse_loss(mean * loss_scale, y * loss_scale, reduction='none') / loss_scale
# Weight loss nearer to prediction time?
weight = (torch.arange(loss_p.shape[1]) + 1).float().to(device)[None, :]
loss_p = loss_p / torch.sqrt(weight) # We want to weight nearer stuff more
loss_target = (
F.mse_loss(
mean_target * loss_scale,
y[:, -steps:, :] * loss_scale,
reduction="none",
)
/ loss_scale
)
loss_context = (
F.mse_loss(
mean_context * loss_scale,
y[:, :-steps, :] * loss_scale,
reduction="none",
)
/ loss_scale
)
y_pred = y_dist.rsample if self.training else y_dist.loc
return (
y_pred,
dict(loss_p=loss_p.mean(), loss_mse=loss_mse.mean()),
dict(log_sigma=log_sigma, dist=y_dist),
)
# mean_target = mean[:, -steps:, :]
# mean_context = mean[:, :-steps, :]
y_mask_target = y_mask[:, -steps:, :].detach()
y_mask_context = y_mask[:, :-steps, :].detach()
# loss_target = loss[:, -steps:, :]
# loss_context = loss[:, :-steps, :]
# print(0, loss_context.sum(), loss_target.sum())
# loss = None
# if target_y is not None:
# y = torch.cat([context_y, target_y], 1)
# y_mask = torch.isfinite(y) & (y != self.hparams.nan_value)
# y[~y_mask] = 0
# y = y.detach()
weight = (
(torch.arange(loss_target.shape[1]) + 0.5)
.float()
.to(device)[None, :, None]
)
# weight /= weight.sum()
# print(1.0, loss_context.sum(), loss_target.sum())
loss_target = loss_target / torch.sqrt(
weight
) # We want to weight nearer stuff more
# print(1.5, loss_context.sum(), y_mask_context.sum(), loss_target.sum(), y_mask_target.sum(), (loss_context * y_mask_context).sum())
loss_context = (loss_context * y_mask_context.float()).sum() / (
y_mask_context.sum() + 1.0
)
loss_target = (loss_target * y_mask_target.float()).sum() / (
y_mask_target.sum() + 1.0
) # Mean over unmasked ones
# print(2, loss_context.sum(), loss_target.sum())
# loss_scale = 100
# # loss = F.mse_loss(mean * loss_scale, y * loss_scale, reduction='none') / loss_scale
# Perhaps predicting the past, as a secondary loss will help
loss = loss_context / 100.0 + loss_target
# loss_target = (
# F.mse_loss(
# mean_target * loss_scale,
# y[:, -steps:, :] * loss_scale,
# reduction="none",
# )
# / loss_scale
# )
# loss_context = (
# F.mse_loss(
# mean_context * loss_scale,
# y[:, :-steps, :] * loss_scale,
# reduction="none",
# )
# / loss_scale
# )
assert torch.isfinite(loss)
# y_mask_target = y_mask[:, -steps:, :].detach()
# y_mask_context = y_mask[:, :-steps, :].detach()
# # loss_target = loss[:, -steps:, :]
# # loss_context = loss[:, :-steps, :]
# # print(0, loss_context.sum(), loss_target.sum())
return mean_target, dict(loss=loss), dict()
# weight = (
# (torch.arange(loss_target.shape[1]) + 0.5)
# .float()
# .to(device)[None, :, None]
# )
# # weight /= weight.sum()
# # print(1.0, loss_context.sum(), loss_target.sum())
# loss_target = loss_target / torch.sqrt(
# weight
# ) # We want to weight nearer stuff more
# # print(1.5, loss_context.sum(), y_mask_context.sum(), loss_target.sum(), y_mask_target.sum(), (loss_context * y_mask_context).sum())
# loss_context = (loss_context * y_mask_context.float()).sum() / (
# y_mask_context.sum() + 1.0
# )
# loss_target = (loss_target * y_mask_target.float()).sum() / (
# y_mask_target.sum() + 1.0
# ) # Mean over unmasked ones
# # print(2, loss_context.sum(), loss_target.sum())
# # Perhaps predicting the past, as a secondary loss will help
# loss = loss_context / 100.0 + loss_target
# assert torch.isfinite(loss)
# return mean_target, dict(loss=loss), dict()
class PL_Transformer(PL_Seq2Seq):
@@ -146,10 +187,10 @@ class PL_Transformer(PL_Seq2Seq):
super().__init__(hparams, MODEL_CLS=MODEL_CLS, **kwargs)
DEFAULT_ARGS = {
"attention_dropout": 0.4151003234623061,
"hidden_out_size_power": 2.0,
"hidden_size_power": 2.0,
"learning_rate": 0.0026738884132767185,
"attention_dropout": 0.4,
"hidden_out_size_power": 7.0,
"hidden_size_power": 7.0,
"learning_rate": 0.003,
"nhead_power": 1.0,
"nlayers": 2,
}
@@ -186,9 +227,10 @@ class PL_Transformer(PL_Seq2Seq):
"vis_i": 670,
"x_dim": 6,
"y_dim": 1,
"nan_value": -99.9,
"context_in_target": False,
"patience": 3,
'min_std': 0.005,
"nan_value": -99.9,
}
[trial.set_user_attr(k, v) for k, v in user_attrs_default.items()]
[trial.set_user_attr(k, v) for k, v in user_attrs.items()]
@@ -36,11 +36,11 @@ from ..utils import hparams_power
class TransformerSeq2SeqNet(nn.Module):
def __init__(self, hparams, _min_std=0.05):
def __init__(self, hparams):
super().__init__()
hparams = hparams_power(hparams)
self.hparams = hparams
self._min_std = _min_std
self._min_std = hparams.min_std
hidden_out_size = self.hparams.hidden_out_size
self.enc_norm = BatchNormSequence(self.hparams.input_size)
@@ -187,5 +187,6 @@ class TransformerSeq2Seq_PL(PL_Seq2Seq):
"context_in_target": False,
"output_size": 1,
"patience": 3,
'min_std': 0.005,
}
return trial
+94 -498
View File
File diff suppressed because one or more lines are too long