diff --git a/pts/model/simple_feedforward/simple_feedforward_estimator.py b/pts/model/simple_feedforward/simple_feedforward_estimator.py index 50adda2..87c6b22 100644 --- a/pts/model/simple_feedforward/simple_feedforward_estimator.py +++ b/pts/model/simple_feedforward/simple_feedforward_estimator.py @@ -158,7 +158,7 @@ class SimpleFeedForwardEstimator(PTSEstimator): batch_normalization=self.batch_normalization, mean_scaling=self.mean_scaling, num_parallel_samples=self.num_parallel_samples, - ) + ).to(device) copy_parameters(trained_network, prediction_network)