mirror of
https://github.com/wassname/ray.git
synced 2026-07-06 04:44:08 +08:00
[sgd] simplify cuda visible device setting (#8775)
This commit is contained in:
@@ -276,22 +276,28 @@ class LocalDistributedRunner(DistributedTorchRunner):
|
||||
super(LocalDistributedRunner, self).__init__(*args, **kwargs)
|
||||
|
||||
def _try_reserve_and_set_cuda(self):
|
||||
use_found_device = os.environ.get("CUDA_VISIBLE_DEVICES") is None \
|
||||
and torch.cuda.is_initialized()
|
||||
device = reserve_cuda_device()
|
||||
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
|
||||
reserved_device = reserve_cuda_device()
|
||||
# This needs to be set even if torch.cuda is already
|
||||
# initialized because the env var is used later when
|
||||
# starting the DDP setup.
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = device
|
||||
if use_found_device:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = reserved_device
|
||||
if visible_devices:
|
||||
# We want to set the index on the visible devices list.
|
||||
if reserved_device not in visible_devices:
|
||||
raise RuntimeError(
|
||||
"TorchTrainer reserved a device {} that was not in the "
|
||||
"CUDA_VISIBLE_DEVICES {}. This may be because the "
|
||||
"Ray cluster is not set with the right env vars. "
|
||||
"If that is not the issue, please raise a "
|
||||
"Github issue.".format(reserved_device, visible_devices))
|
||||
devices = visible_devices.split(",")
|
||||
scoped_index = devices.index(reserved_device)
|
||||
self._set_cuda_device(str(scoped_index))
|
||||
else:
|
||||
# Once cuda is initialized, torch.device ignores the os.env
|
||||
# so we have to set the right actual device.
|
||||
self._set_cuda_device(device)
|
||||
else:
|
||||
# if CUDA is not initialized, we can set the os.env.
|
||||
# Even if initialized, we want to set the device to use BatchNorm.
|
||||
# and make Torch think it only sees 1 GPU.
|
||||
self._set_cuda_device("0")
|
||||
self._set_cuda_device(reserved_device)
|
||||
|
||||
def _set_cuda_device(self, device_str):
|
||||
"""Sets the CUDA device for this current local worker."""
|
||||
|
||||
Reference in New Issue
Block a user