diff --git a/neural_processes/models/neural_process/model.py b/neural_processes/models/neural_process/model.py index 09cc868..f379042 100644 --- a/neural_processes/models/neural_process/model.py +++ b/neural_processes/models/neural_process/model.py @@ -303,7 +303,10 @@ class NeuralProcess(nn.Module): self._use_deterministic_path = use_deterministic_path self._use_lvar = use_lvar - def forward(self, context_x, context_y, target_x, target_y=None): + def forward(self, context_x, context_y, target_x, target_y=None, sample_latent=None): + if sample_latent is None: + sample_latent = self.training + device = next(self.parameters()).device # https://stackoverflow.com/a/46772183/221742 @@ -318,7 +321,6 @@ class NeuralProcess(nn.Module): context_x, _ = self._lstm_x(context_x) context_y, _ = self._lstm_y(context_y) - dist_prior, log_var_prior = self._latent_encoder(context_x, context_y) if target_y is not None: @@ -326,9 +328,9 @@ class NeuralProcess(nn.Module): if self._use_rnn: target_y2, _ = self._lstm_y(target_y2) dist_post, log_var_post = self._latent_encoder(target_x, target_y2) - z = dist_post.loc + z = dist_post.rsample() if sample_latent else dist_post.loc else: - z = dist_prior.loc + z = dist_prior.rsample() if sample_latent else dist_prior.loc num_targets = target_x.size(1) z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, T_target, H]