diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index db440d93e..4cade28c8 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -230,7 +230,7 @@ class TrainingOperator: """ features, target = batch # Create non_blocking tensors for distributed training - if torch.cuda.is_available(): + if self.use_gpu: features = features.cuda(non_blocking=True) target = target.cuda(non_blocking=True) @@ -309,7 +309,7 @@ class TrainingOperator: calculate averages. """ features, target = batch - if torch.cuda.is_available(): + if self.use_gpu: features = features.cuda(non_blocking=True) target = target.cuda(non_blocking=True) @@ -404,6 +404,11 @@ class TrainingOperator: """List of schedulers created by the ``scheduler_creator``.""" return self._schedulers + @property + def use_gpu(self): + """Returns True if cuda is available and use_gpu is True.""" + return self._use_gpu + @property def use_fp16(self): """bool: Whether the model and optimizer have been FP16 enabled."""