ops we should sample latent space

This commit is contained in:
wassname
2020-04-25 10:18:43 +08:00
parent b37bf7f7ac
commit fae53ab515
@@ -303,7 +303,10 @@ class NeuralProcess(nn.Module):
self._use_deterministic_path = use_deterministic_path
self._use_lvar = use_lvar
def forward(self, context_x, context_y, target_x, target_y=None):
def forward(self, context_x, context_y, target_x, target_y=None, sample_latent=None):
if sample_latent is None:
sample_latent = self.training
device = next(self.parameters()).device
# https://stackoverflow.com/a/46772183/221742
@@ -318,7 +321,6 @@ class NeuralProcess(nn.Module):
context_x, _ = self._lstm_x(context_x)
context_y, _ = self._lstm_y(context_y)
dist_prior, log_var_prior = self._latent_encoder(context_x, context_y)
if target_y is not None:
@@ -326,9 +328,9 @@ class NeuralProcess(nn.Module):
if self._use_rnn:
target_y2, _ = self._lstm_y(target_y2)
dist_post, log_var_post = self._latent_encoder(target_x, target_y2)
z = dist_post.loc
z = dist_post.rsample() if sample_latent else dist_post.loc
else:
z = dist_prior.loc
z = dist_prior.rsample() if sample_latent else dist_prior.loc
num_targets = target_x.size(1)
z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, T_target, H]