From 6521e92a95dfa5d0212a571bcc4bcb35b5e1e2d6 Mon Sep 17 00:00:00 2001 From: David Chan Date: Wed, 8 Apr 2020 20:20:09 -0700 Subject: [PATCH] [RaySGD] Honor the use_gpu flag (#7942) --- python/ray/util/sgd/torch/training_operator.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index db440d93e..4cade28c8 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -230,7 +230,7 @@ class TrainingOperator: """ features, target = batch # Create non_blocking tensors for distributed training - if torch.cuda.is_available(): + if self.use_gpu: features = features.cuda(non_blocking=True) target = target.cuda(non_blocking=True) @@ -309,7 +309,7 @@ class TrainingOperator: calculate averages. """ features, target = batch - if torch.cuda.is_available(): + if self.use_gpu: features = features.cuda(non_blocking=True) target = target.cuda(non_blocking=True) @@ -404,6 +404,11 @@ class TrainingOperator: """List of schedulers created by the ``scheduler_creator``.""" return self._schedulers + @property + def use_gpu(self): + """Returns True if cuda is available and use_gpu is True.""" + return self._use_gpu + @property def use_fp16(self): """bool: Whether the model and optimizer have been FP16 enabled."""