mirror of
https://github.com/wassname/attentive-neural-processes.git
synced 2026-06-27 16:44:27 +08:00
fix leak
This commit is contained in:
@@ -324,12 +324,15 @@ class NeuralProcess(nn.Module):
|
||||
|
||||
dist_prior, log_var_prior = self._latent_encoder(context_x, context_y)
|
||||
|
||||
if (target_y is not None) and self.training:
|
||||
if (target_y is not None):
|
||||
target_y2 = self.norm_y(target_y)
|
||||
if self._use_rnn:
|
||||
target_y2, _ = self._lstm_y(target_y2)
|
||||
dist_post, log_var_post = self._latent_encoder(target_x, target_y2)
|
||||
z = dist_post.rsample() if sample_latent else dist_post.loc
|
||||
if self.training:
|
||||
z = dist_post.rsample() if sample_latent else dist_post.loc
|
||||
else:
|
||||
z = dist_prior.rsample() if sample_latent else dist_prior.loc
|
||||
else:
|
||||
z = dist_prior.rsample() if sample_latent else dist_prior.loc
|
||||
|
||||
|
||||
+10782
-12975
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user