move predictor network to device

This commit is contained in:
Dr. Kashif Rasul
2020-02-12 17:18:02 +01:00
parent 1f3c2ef943
commit 923ba8978a
@@ -158,7 +158,7 @@ class SimpleFeedForwardEstimator(PTSEstimator):
batch_normalization=self.batch_normalization,
mean_scaling=self.mean_scaling,
num_parallel_samples=self.num_parallel_samples,
)
).to(device)
copy_parameters(trained_network, prediction_network)