diff --git a/neural_processes/models/neural_process/model.py b/neural_processes/models/neural_process/model.py index ae83378..c1d00bc 100644 --- a/neural_processes/models/neural_process/model.py +++ b/neural_processes/models/neural_process/model.py @@ -324,7 +324,7 @@ class NeuralProcess(nn.Module): dist_prior, log_var_prior = self._latent_encoder(context_x, context_y) - if target_y is not None: + if (target_y is not None) and self.training: target_y2 = self.norm_y(target_y) if self._use_rnn: target_y2, _ = self._lstm_y(target_y2)