[RaySGD] Honor the use_gpu flag (#7942)

This commit is contained in:
David Chan
2020-04-08 20:20:09 -07:00
committed by GitHub
parent 44825d81e9
commit 6521e92a95
@@ -230,7 +230,7 @@ class TrainingOperator:
"""
features, target = batch
# Create non_blocking tensors for distributed training
if torch.cuda.is_available():
if self.use_gpu:
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
@@ -309,7 +309,7 @@ class TrainingOperator:
calculate averages.
"""
features, target = batch
if torch.cuda.is_available():
if self.use_gpu:
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
@@ -404,6 +404,11 @@ class TrainingOperator:
"""List of schedulers created by the ``scheduler_creator``."""
return self._schedulers
@property
def use_gpu(self):
"""Returns True if cuda is available and use_gpu is True."""
return self._use_gpu
@property
def use_fp16(self):
"""bool: Whether the model and optimizer have been FP16 enabled."""