mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:10:40 +08:00
[RaySGD] Honor the use_gpu flag (#7942)
This commit is contained in:
@@ -230,7 +230,7 @@ class TrainingOperator:
|
||||
"""
|
||||
features, target = batch
|
||||
# Create non_blocking tensors for distributed training
|
||||
if torch.cuda.is_available():
|
||||
if self.use_gpu:
|
||||
features = features.cuda(non_blocking=True)
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
@@ -309,7 +309,7 @@ class TrainingOperator:
|
||||
calculate averages.
|
||||
"""
|
||||
features, target = batch
|
||||
if torch.cuda.is_available():
|
||||
if self.use_gpu:
|
||||
features = features.cuda(non_blocking=True)
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
@@ -404,6 +404,11 @@ class TrainingOperator:
|
||||
"""List of schedulers created by the ``scheduler_creator``."""
|
||||
return self._schedulers
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
"""Returns True if cuda is available and use_gpu is True."""
|
||||
return self._use_gpu
|
||||
|
||||
@property
|
||||
def use_fp16(self):
|
||||
"""bool: Whether the model and optimizer have been FP16 enabled."""
|
||||
|
||||
Reference in New Issue
Block a user