diff --git a/pts/model/transformer/transformer_estimator.py b/pts/model/transformer/transformer_estimator.py index 4a9d7c4..4adc188 100644 --- a/pts/model/transformer/transformer_estimator.py +++ b/pts/model/transformer/transformer_estimator.py @@ -251,7 +251,7 @@ class TransformerEstimator(PyTorchEstimator): copy_parameters(trained_network, prediction_network) input_names = get_module_forward_input_names(prediction_network) - prediction_splitter = self._create_instance_splitter("test") + prediction_splitter = self.create_instance_splitter("test") return PyTorchPredictor( input_transform=transformation + prediction_splitter,