mirror of
https://github.com/wassname/attentive-neural-processes.git
synced 2026-06-27 18:03:39 +08:00
misc
This commit is contained in:
@@ -143,6 +143,7 @@ class LSTMSeq2Seq_PL(PL_Seq2Seq):
|
||||
"learning_rate": 0.001,
|
||||
"lstm_layers": 4,
|
||||
'bidirectional': False
|
||||
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -165,5 +166,6 @@ class LSTMSeq2Seq_PL(PL_Seq2Seq):
|
||||
"input_size_decoder": 17,
|
||||
"context_in_target": False,
|
||||
"output_size": 1,
|
||||
'min_std': 0.005,
|
||||
}
|
||||
return trial
|
||||
|
||||
@@ -36,7 +36,7 @@ class LSTMNet(nn.Module):
|
||||
self._min_std = _min_std
|
||||
|
||||
self.lstm1 = nn.LSTM(
|
||||
x_dim=self.hparams.x_dim,
|
||||
input_size=self.hparams.x_dim+self.hparams.y_dim,
|
||||
hidden_size=self.hparams.hidden_size,
|
||||
batch_first=True,
|
||||
num_layers=self.hparams.lstm_layers,
|
||||
@@ -48,29 +48,70 @@ class LSTMNet(nn.Module):
|
||||
)
|
||||
self.mean = nn.Linear(self.hidden_out_size, 1)
|
||||
self.std = nn.Linear(self.hidden_out_size, 1)
|
||||
self._use_lvar = 0
|
||||
|
||||
def forward(self, context_x, context_y, target_x, target_y=None):
|
||||
loss_scale = 1
|
||||
device = next(self.parameters()).device
|
||||
x = torch.cat([context_x, context_y], -1).detach()
|
||||
target_y_fake = (
|
||||
torch.ones(context_y.shape[0], target_x.shape[1], context_y.shape[2]).float().to(device) * self.hparams.nan_value
|
||||
)
|
||||
loss_scale = 1
|
||||
context = torch.cat([context_x, context_y], -1).detach()
|
||||
target = torch.cat([target_x, target_y_fake], -1).detach()
|
||||
x = torch.cat([context, target * 1], 1).detach()
|
||||
|
||||
outputs, (h_out, _) = self.lstm1(x)
|
||||
# outputs: [B, T, num_direction * H]
|
||||
y_pred = self.mean(outputs).squeeze(2)
|
||||
log_sigma = self.std(outputs).squeeze(2)
|
||||
|
||||
loss = None
|
||||
if target_y is not None:
|
||||
loss = (
|
||||
F.mse_loss(
|
||||
y_pred * loss_scale, y[:, -steps:, :] * loss_scale, reduction="none"
|
||||
)
|
||||
/ loss_scale
|
||||
steps = context_y.shape[1]
|
||||
mean = self.mean(outputs)[:, steps:, :]#.squeeze(2)
|
||||
log_sigma = self.std(outputs)[:, steps:,:] #.squeeze(2)
|
||||
|
||||
if self._use_lvar:
|
||||
log_sigma = torch.clamp(
|
||||
log_sigma, math.log(self._min_std), -math.log(self._min_std)
|
||||
)
|
||||
sigma = torch.exp(log_sigma)
|
||||
else:
|
||||
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
|
||||
y_dist = torch.distributions.Normal(mean, sigma)
|
||||
|
||||
assert torch.isfinite(loss)
|
||||
# Loss
|
||||
loss_mse = loss_p = None
|
||||
if target_y is not None:
|
||||
loss_mse = F.mse_loss(mean, target_y, reduction="none")
|
||||
if self._use_lvar:
|
||||
loss_p = -log_prob_sigma(target_y, mean, log_sigma)
|
||||
else:
|
||||
loss_p = -y_dist.log_prob(target_y).mean(-1)
|
||||
if self.hparams["context_in_target"]:
|
||||
loss_p[: context_x.size(1)] /= 100
|
||||
loss_mse[: context_x.size(1)] /= 100
|
||||
# # Don't catch loss on context window
|
||||
# mean = mean[:, self.hparams.num_context:]
|
||||
# log_sigma = log_sigma[:, self.hparams.num_context:]
|
||||
|
||||
return y_pred, dict(loss=loss), dict()
|
||||
# Weight loss nearer to prediction time?
|
||||
weight = (torch.arange(loss_p.shape[1]) + 1).float().to(device)[None, :]
|
||||
loss_p = loss_p / torch.sqrt(weight) # We want to weight nearer stuff more
|
||||
|
||||
y_pred = y_dist.rsample if self.training else y_dist.loc
|
||||
return (
|
||||
y_pred,
|
||||
dict(loss_p=loss_p.mean(), loss_mse=loss_mse.mean()),
|
||||
dict(log_sigma=log_sigma, dist=y_dist),
|
||||
)
|
||||
# loss = None
|
||||
# if target_y is not None:
|
||||
# loss = (
|
||||
# F.mse_loss(
|
||||
# y_pred * loss_scale, y[:, -steps:, :] * loss_scale, reduction="none"
|
||||
# )
|
||||
# / loss_scale
|
||||
# )
|
||||
|
||||
# assert torch.isfinite(loss)
|
||||
|
||||
# return y_pred, dict(loss=loss), dict()
|
||||
|
||||
|
||||
class LSTM_PL_STD(PL_Seq2Seq):
|
||||
@@ -79,7 +120,7 @@ class LSTM_PL_STD(PL_Seq2Seq):
|
||||
|
||||
DEFAULT_ARGS = {
|
||||
"bidirectional": False,
|
||||
"hidden_size_power": 4,
|
||||
"hidden_size_power": 5,
|
||||
"learning_rate": 0.001,
|
||||
"lstm_dropout": 0.39,
|
||||
"lstm_layers": 4,
|
||||
@@ -100,12 +141,12 @@ class LSTM_PL_STD(PL_Seq2Seq):
|
||||
"grad_clip": 40,
|
||||
"max_nb_epochs": 200,
|
||||
"num_workers": 4,
|
||||
# "num_extra_target": 24 * 4,
|
||||
"vis_i": "670",
|
||||
# "num_context": 24 * 4,
|
||||
"x_dim": 18,
|
||||
"y_dim": 1,
|
||||
"context_in_target": False,
|
||||
# "output_size": 1,
|
||||
"patience": 3,
|
||||
'min_std': 0.005,
|
||||
'nan_value': -99.9
|
||||
}
|
||||
return trial
|
||||
|
||||
@@ -18,6 +18,7 @@ class NetTransformer(nn.Module):
|
||||
super().__init__()
|
||||
hparams = hparams_power(hparams)
|
||||
self.hparams = hparams
|
||||
self._min_std = hparams.min_std
|
||||
|
||||
hidden_out_size = self.hparams.hidden_out_size
|
||||
enc_x_dim = self.hparams.x_dim + self.hparams.y_dim
|
||||
@@ -50,11 +51,13 @@ class NetTransformer(nn.Module):
|
||||
# norm=decoder_norm
|
||||
# )
|
||||
self.mean = nn.Linear(hidden_out_size, self.hparams.y_dim)
|
||||
self.std = nn.Linear(hidden_out_size, self.hparams.y_dim)
|
||||
self._use_lvar = 0
|
||||
|
||||
def forward(self, context_x, context_y, target_x, target_y=None):
|
||||
device = next(self.parameters()).device
|
||||
target_y_fake = (
|
||||
torch.ones(context_y.shape).float().to(device) * self.hparams.nan_value
|
||||
torch.ones(context_y.shape[0], target_x.shape[1], context_y.shape[2]).float().to(device) * self.hparams.nan_value
|
||||
)
|
||||
context = torch.cat([context_x, context_y], -1).detach()
|
||||
target = torch.cat([target_x, target_y_fake], -1).detach()
|
||||
@@ -76,69 +79,107 @@ class NetTransformer(nn.Module):
|
||||
# print(outputs.shape, 'outputs')
|
||||
|
||||
# Seems to help a little, especially with extrapolating out of bounds
|
||||
steps = target_y.shape[1]
|
||||
mean = self.mean(outputs)
|
||||
mean_target = mean[:, -steps:, :]
|
||||
mean_context = mean[:, :-steps, :]
|
||||
steps = context_y.shape[1]
|
||||
mean = self.mean(outputs)[:, steps:, :]
|
||||
log_sigma = self.std(outputs)[:, steps:, :]
|
||||
# mean_target = mean[:, -steps:, :]
|
||||
# mean_context = mean[:, :-steps, :]
|
||||
|
||||
loss = None
|
||||
if self._use_lvar:
|
||||
log_sigma = torch.clamp(
|
||||
log_sigma, math.log(self._min_std), -math.log(self._min_std)
|
||||
)
|
||||
sigma = torch.exp(log_sigma)
|
||||
else:
|
||||
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
|
||||
y_dist = torch.distributions.Normal(mean, sigma)
|
||||
|
||||
# Loss
|
||||
loss_mse = loss_p = None
|
||||
if target_y is not None:
|
||||
y = torch.cat([context_y, target_y], 1)
|
||||
y_mask = torch.isfinite(y) & (y != self.hparams.nan_value)
|
||||
y[~y_mask] = 0
|
||||
y = y.detach()
|
||||
loss_mse = F.mse_loss(mean, target_y, reduction="none")
|
||||
if self._use_lvar:
|
||||
loss_p = -log_prob_sigma(target_y, mean, log_sigma)
|
||||
else:
|
||||
loss_p = -y_dist.log_prob(target_y).mean(-1)
|
||||
if self.hparams["context_in_target"]:
|
||||
loss_p[: context_x.size(1)] /= 100
|
||||
loss_mse[: context_x.size(1)] /= 100
|
||||
# # Don't catch loss on context window
|
||||
# mean = mean[:, self.hparams.num_context:]
|
||||
# log_sigma = log_sigma[:, self.hparams.num_context:]
|
||||
|
||||
loss_scale = 100
|
||||
# loss = F.mse_loss(mean * loss_scale, y * loss_scale, reduction='none') / loss_scale
|
||||
# Weight loss nearer to prediction time?
|
||||
weight = (torch.arange(loss_p.shape[1]) + 1).float().to(device)[None, :]
|
||||
loss_p = loss_p / torch.sqrt(weight) # We want to weight nearer stuff more
|
||||
|
||||
loss_target = (
|
||||
F.mse_loss(
|
||||
mean_target * loss_scale,
|
||||
y[:, -steps:, :] * loss_scale,
|
||||
reduction="none",
|
||||
)
|
||||
/ loss_scale
|
||||
)
|
||||
loss_context = (
|
||||
F.mse_loss(
|
||||
mean_context * loss_scale,
|
||||
y[:, :-steps, :] * loss_scale,
|
||||
reduction="none",
|
||||
)
|
||||
/ loss_scale
|
||||
)
|
||||
y_pred = y_dist.rsample if self.training else y_dist.loc
|
||||
return (
|
||||
y_pred,
|
||||
dict(loss_p=loss_p.mean(), loss_mse=loss_mse.mean()),
|
||||
dict(log_sigma=log_sigma, dist=y_dist),
|
||||
)
|
||||
# mean_target = mean[:, -steps:, :]
|
||||
# mean_context = mean[:, :-steps, :]
|
||||
|
||||
y_mask_target = y_mask[:, -steps:, :].detach()
|
||||
y_mask_context = y_mask[:, :-steps, :].detach()
|
||||
# loss_target = loss[:, -steps:, :]
|
||||
# loss_context = loss[:, :-steps, :]
|
||||
# print(0, loss_context.sum(), loss_target.sum())
|
||||
# loss = None
|
||||
# if target_y is not None:
|
||||
# y = torch.cat([context_y, target_y], 1)
|
||||
# y_mask = torch.isfinite(y) & (y != self.hparams.nan_value)
|
||||
# y[~y_mask] = 0
|
||||
# y = y.detach()
|
||||
|
||||
weight = (
|
||||
(torch.arange(loss_target.shape[1]) + 0.5)
|
||||
.float()
|
||||
.to(device)[None, :, None]
|
||||
)
|
||||
# weight /= weight.sum()
|
||||
# print(1.0, loss_context.sum(), loss_target.sum())
|
||||
loss_target = loss_target / torch.sqrt(
|
||||
weight
|
||||
) # We want to weight nearer stuff more
|
||||
# print(1.5, loss_context.sum(), y_mask_context.sum(), loss_target.sum(), y_mask_target.sum(), (loss_context * y_mask_context).sum())
|
||||
loss_context = (loss_context * y_mask_context.float()).sum() / (
|
||||
y_mask_context.sum() + 1.0
|
||||
)
|
||||
loss_target = (loss_target * y_mask_target.float()).sum() / (
|
||||
y_mask_target.sum() + 1.0
|
||||
) # Mean over unmasked ones
|
||||
# print(2, loss_context.sum(), loss_target.sum())
|
||||
# loss_scale = 100
|
||||
# # loss = F.mse_loss(mean * loss_scale, y * loss_scale, reduction='none') / loss_scale
|
||||
|
||||
# Perhaps predicting the past, as a secondary loss will help
|
||||
loss = loss_context / 100.0 + loss_target
|
||||
# loss_target = (
|
||||
# F.mse_loss(
|
||||
# mean_target * loss_scale,
|
||||
# y[:, -steps:, :] * loss_scale,
|
||||
# reduction="none",
|
||||
# )
|
||||
# / loss_scale
|
||||
# )
|
||||
# loss_context = (
|
||||
# F.mse_loss(
|
||||
# mean_context * loss_scale,
|
||||
# y[:, :-steps, :] * loss_scale,
|
||||
# reduction="none",
|
||||
# )
|
||||
# / loss_scale
|
||||
# )
|
||||
|
||||
assert torch.isfinite(loss)
|
||||
# y_mask_target = y_mask[:, -steps:, :].detach()
|
||||
# y_mask_context = y_mask[:, :-steps, :].detach()
|
||||
# # loss_target = loss[:, -steps:, :]
|
||||
# # loss_context = loss[:, :-steps, :]
|
||||
# # print(0, loss_context.sum(), loss_target.sum())
|
||||
|
||||
return mean_target, dict(loss=loss), dict()
|
||||
# weight = (
|
||||
# (torch.arange(loss_target.shape[1]) + 0.5)
|
||||
# .float()
|
||||
# .to(device)[None, :, None]
|
||||
# )
|
||||
# # weight /= weight.sum()
|
||||
# # print(1.0, loss_context.sum(), loss_target.sum())
|
||||
# loss_target = loss_target / torch.sqrt(
|
||||
# weight
|
||||
# ) # We want to weight nearer stuff more
|
||||
# # print(1.5, loss_context.sum(), y_mask_context.sum(), loss_target.sum(), y_mask_target.sum(), (loss_context * y_mask_context).sum())
|
||||
# loss_context = (loss_context * y_mask_context.float()).sum() / (
|
||||
# y_mask_context.sum() + 1.0
|
||||
# )
|
||||
# loss_target = (loss_target * y_mask_target.float()).sum() / (
|
||||
# y_mask_target.sum() + 1.0
|
||||
# ) # Mean over unmasked ones
|
||||
# # print(2, loss_context.sum(), loss_target.sum())
|
||||
|
||||
# # Perhaps predicting the past, as a secondary loss will help
|
||||
# loss = loss_context / 100.0 + loss_target
|
||||
|
||||
# assert torch.isfinite(loss)
|
||||
|
||||
# return mean_target, dict(loss=loss), dict()
|
||||
|
||||
|
||||
class PL_Transformer(PL_Seq2Seq):
|
||||
@@ -146,10 +187,10 @@ class PL_Transformer(PL_Seq2Seq):
|
||||
super().__init__(hparams, MODEL_CLS=MODEL_CLS, **kwargs)
|
||||
|
||||
DEFAULT_ARGS = {
|
||||
"attention_dropout": 0.4151003234623061,
|
||||
"hidden_out_size_power": 2.0,
|
||||
"hidden_size_power": 2.0,
|
||||
"learning_rate": 0.0026738884132767185,
|
||||
"attention_dropout": 0.4,
|
||||
"hidden_out_size_power": 7.0,
|
||||
"hidden_size_power": 7.0,
|
||||
"learning_rate": 0.003,
|
||||
"nhead_power": 1.0,
|
||||
"nlayers": 2,
|
||||
}
|
||||
@@ -186,9 +227,10 @@ class PL_Transformer(PL_Seq2Seq):
|
||||
"vis_i": 670,
|
||||
"x_dim": 6,
|
||||
"y_dim": 1,
|
||||
"nan_value": -99.9,
|
||||
"context_in_target": False,
|
||||
"patience": 3,
|
||||
'min_std': 0.005,
|
||||
"nan_value": -99.9,
|
||||
}
|
||||
[trial.set_user_attr(k, v) for k, v in user_attrs_default.items()]
|
||||
[trial.set_user_attr(k, v) for k, v in user_attrs.items()]
|
||||
|
||||
@@ -36,11 +36,11 @@ from ..utils import hparams_power
|
||||
|
||||
|
||||
class TransformerSeq2SeqNet(nn.Module):
|
||||
def __init__(self, hparams, _min_std=0.05):
|
||||
def __init__(self, hparams):
|
||||
super().__init__()
|
||||
hparams = hparams_power(hparams)
|
||||
self.hparams = hparams
|
||||
self._min_std = _min_std
|
||||
self._min_std = hparams.min_std
|
||||
|
||||
hidden_out_size = self.hparams.hidden_out_size
|
||||
self.enc_norm = BatchNormSequence(self.hparams.input_size)
|
||||
@@ -187,5 +187,6 @@ class TransformerSeq2Seq_PL(PL_Seq2Seq):
|
||||
"context_in_target": False,
|
||||
"output_size": 1,
|
||||
"patience": 3,
|
||||
'min_std': 0.005,
|
||||
}
|
||||
return trial
|
||||
|
||||
+94
-498
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user