This commit is contained in:
wassname
2020-07-07 17:50:41 +08:00
parent f7e21ec6ef
commit ed31c5a3a3
2 changed files with 10787 additions and 12977 deletions
@@ -324,12 +324,15 @@ class NeuralProcess(nn.Module):
dist_prior, log_var_prior = self._latent_encoder(context_x, context_y)
if (target_y is not None) and self.training:
if (target_y is not None):
target_y2 = self.norm_y(target_y)
if self._use_rnn:
target_y2, _ = self._lstm_y(target_y2)
dist_post, log_var_post = self._latent_encoder(target_x, target_y2)
z = dist_post.rsample() if sample_latent else dist_post.loc
if self.training:
z = dist_post.rsample() if sample_latent else dist_post.loc
else:
z = dist_prior.rsample() if sample_latent else dist_prior.loc
else:
z = dist_prior.rsample() if sample_latent else dist_prior.loc
+10782 -12975
View File
File diff suppressed because one or more lines are too long