mirror of
https://github.com/wassname/attentive-neural-processes.git
synced 2026-06-27 18:03:39 +08:00
ops we should sample latent space
This commit is contained in:
@@ -303,7 +303,10 @@ class NeuralProcess(nn.Module):
|
||||
self._use_deterministic_path = use_deterministic_path
|
||||
self._use_lvar = use_lvar
|
||||
|
||||
def forward(self, context_x, context_y, target_x, target_y=None):
|
||||
def forward(self, context_x, context_y, target_x, target_y=None, sample_latent=None):
|
||||
if sample_latent is None:
|
||||
sample_latent = self.training
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# https://stackoverflow.com/a/46772183/221742
|
||||
@@ -318,7 +321,6 @@ class NeuralProcess(nn.Module):
|
||||
context_x, _ = self._lstm_x(context_x)
|
||||
context_y, _ = self._lstm_y(context_y)
|
||||
|
||||
|
||||
dist_prior, log_var_prior = self._latent_encoder(context_x, context_y)
|
||||
|
||||
if target_y is not None:
|
||||
@@ -326,9 +328,9 @@ class NeuralProcess(nn.Module):
|
||||
if self._use_rnn:
|
||||
target_y2, _ = self._lstm_y(target_y2)
|
||||
dist_post, log_var_post = self._latent_encoder(target_x, target_y2)
|
||||
z = dist_post.loc
|
||||
z = dist_post.rsample() if sample_latent else dist_post.loc
|
||||
else:
|
||||
z = dist_prior.loc
|
||||
z = dist_prior.rsample() if sample_latent else dist_prior.loc
|
||||
|
||||
num_targets = target_x.size(1)
|
||||
z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, T_target, H]
|
||||
|
||||
Reference in New Issue
Block a user