mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 18:06:19 +08:00
fixed typo
This commit is contained in:
@@ -251,7 +251,7 @@ class TransformerEstimator(PyTorchEstimator):
|
||||
|
||||
copy_parameters(trained_network, prediction_network)
|
||||
input_names = get_module_forward_input_names(prediction_network)
|
||||
prediction_splitter = self._create_instance_splitter("test")
|
||||
prediction_splitter = self.create_instance_splitter("test")
|
||||
|
||||
return PyTorchPredictor(
|
||||
input_transform=transformation + prediction_splitter,
|
||||
|
||||
Reference in New Issue
Block a user