diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index 85ba958c0..aad6db064 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -257,8 +257,8 @@ class TrainingOperator: if self.use_fp16 and amp: logger.debug("Setting up Apex.") - self._models, self._optimizers = amp.initialize( - self._models, self._optimizers, **self._apex_args) + self._original_models, self._optimizers = amp.initialize( + self._original_models, self._optimizers, **self._apex_args) self._amp = amp if self._wrap_ddp: