diff --git a/python/ray/util/sgd/tests/test_torch_runner.py b/python/ray/util/sgd/tests/test_torch_runner.py index 20257526c..3a1065aad 100644 --- a/python/ray/util/sgd/tests/test_torch_runner.py +++ b/python/ray/util/sgd/tests/test_torch_runner.py @@ -1,10 +1,14 @@ import numpy as np +import os import torch import torch.nn as nn import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch +import ray from ray.util.sgd.torch.training_operator import TrainingOperator +from ray.util.sgd.torch.distributed_torch_runner import ( + LocalDistributedRunner, clear_dummy_actor) from ray.util.sgd.torch.torch_runner import TorchRunner @@ -170,3 +174,81 @@ class TestTorchRunner(unittest.TestCase): with self.assertRaises(ValueError): runner.setup() + + +class TestLocalDistributedRunner(unittest.TestCase): + def setUp(self): + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + ray.init(num_gpus=4) + + def tearDown(self): + clear_dummy_actor() + ray.shutdown() + + def _testWithInitialized(self, init_mock): + mock_runner = MagicMock() + mock_runner._set_cuda_device = MagicMock() + preset_devices = os.environ.get("CUDA_VISIBLE_DEVICES") + + LocalDistributedRunner._try_reserve_and_set_cuda(mock_runner) + + self.assertTrue(mock_runner._set_cuda_device.called) + local_device = mock_runner._set_cuda_device.call_args[0][0] + env_set_device = os.environ["CUDA_VISIBLE_DEVICES"] + self.assertEquals(len(env_set_device), 1) + + if preset_devices: + self.assertIn(env_set_device, preset_devices.split(",")) + self.assertEquals(local_device, "0") + else: + self.assertEquals(local_device, env_set_device) + + def testNoVisibleWithInitialized(self): + with patch("torch.cuda.is_initialized") as init_mock: + init_mock.return_value = True + self._testWithInitialized(init_mock) + + def test2VisibleWithInitialized(self): + os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" + with patch("torch.cuda.is_initialized") as init_mock: + init_mock.return_value = True + self._testWithInitialized(init_mock) + + def test1VisibleWithInitialized(self): + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + with patch("torch.cuda.is_initialized") as init_mock: + init_mock.return_value = True + self._testWithInitialized(init_mock) + + def _testNotInitialized(self, init_mock): + mock_runner = MagicMock() + mock_runner._set_cuda_device = MagicMock() + LocalDistributedRunner._try_reserve_and_set_cuda(mock_runner) + mock_runner._set_cuda_device.assert_called_with("0") + self.assertEquals(len(os.environ["CUDA_VISIBLE_DEVICES"]), 1) + + def testNoVisibleNotInitialized(self): + with patch("torch.cuda.is_initialized") as init_mock: + init_mock.return_value = False + self._testNotInitialized(init_mock) + + def test2VisibleNotInitialized(self): + os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" + with patch("torch.cuda.is_initialized") as init_mock: + init_mock.return_value = False + self._testNotInitialized(init_mock) + + def test1VisibleNotInitialized(self): + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + with patch("torch.cuda.is_initialized") as init_mock: + init_mock.return_value = False + self._testNotInitialized(init_mock) + + @patch("torch.cuda.set_device") + def testSetDevice(self, set_mock): + mock_runner = MagicMock() + mock_runner._is_set = False + LocalDistributedRunner._set_cuda_device(mock_runner, "123") + self.assertEquals(mock_runner.local_device, "123") + self.assertTrue(set_mock.called) + set_mock.assert_called_with(123) diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index abf2f0d0e..b73bb3933 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -193,6 +193,62 @@ class _DummyActor: _dummy_actor = None +def clear_dummy_actor(): + global _dummy_actor + if _dummy_actor: + try: + _dummy_actor.__ray_terminate__.remote() + except Exception as exc: + logger.info("Tried to clear dummy actor: %s", str(exc)) + + _dummy_actor = None + + +def reserve_cuda_device(retries=20): + ip = ray.services.get_node_ip_address() + reserved_device = None + + cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES") + cuda_device_set = {} + match_devices = bool(cuda_devices) + if match_devices: + logger.debug("Found set devices: {}".format(cuda_devices)) + assert isinstance(cuda_devices, str) + cuda_device_set = set(cuda_devices.split(",")) + + global _dummy_actor + unused_actors = [] + + success = False + for _ in range(retries): + if _dummy_actor is None: + _dummy_actor = ray.remote( + num_gpus=1, + resources={"node:" + ip: 0.1})(_DummyActor).remote() + + reserved_device = ray.get(_dummy_actor.cuda_devices.remote()) + + if match_devices and reserved_device not in cuda_device_set: + logger.debug("Device %s not in list of visible devices %s", + reserved_device, cuda_device_set) + unused_actors.append(_dummy_actor) + _dummy_actor = None + else: + logger.debug("Found matching device %s", reserved_device) + success = True + for actor in unused_actors: + actor.__ray_terminate__.remote() + break + + if not success: + raise RuntimeError( + "Unable to reserve the set CUDA VISIBLE DEVICES on Ray. Please " + "make sure that Ray has access to all the visible devices: " + "{}".format(os.environ.get("CUDA_VISIBLE_DEVICES"))) + + return reserved_device + + class LocalDistributedRunner(DistributedTorchRunner): """A wrapper for running a distributed Runner on the driver. @@ -202,43 +258,55 @@ class LocalDistributedRunner(DistributedTorchRunner): """ def __init__(self, *args, num_cpus=None, num_gpus=None, **kwargs): - ip = ray.services.get_node_ip_address() # Reserve a local GPU or CPU for the local worker # TODO: we should make sure this NEVER dies. self.local_device = "0" - global _dummy_actor - if not self.is_actor(): - if _dummy_actor is None: - _dummy_actor = ray.remote( - num_cpus=num_cpus, - num_gpus=num_gpus, - resources={"node:" + ip: 0.1})(_DummyActor).remote() + self._is_set = False + if num_gpus: + assert num_gpus == 1, "Does not support multi-gpu workers" - self.local_device = ray.get(_dummy_actor.cuda_devices.remote()) - - # This is a pretty annoying workaround. To enable SyncBatchNorm, - # we need to signify that we are using only 1 CUDA device (via - # the DDP constructor). - # However, on the local worker, we have to set the - # CUDA_VISIBLE_DEVICES at runtime rather at process start. - - # You can only call setdevice(int > 0) after you've interacted with - # torch.cuda. But you can't guarantee that you _haven't_ interacted - # with it (user can do arbitrary things), so we force an - # interaction. - _init_cuda_context() - os.environ["CUDA_VISIBLE_DEVICES"] = self.local_device - - if self.local_device: - try: - torch.cuda.set_device(int(self.local_device)) - except RuntimeError: - logger.error("This happens if cuda is not initialized.") - raise + if not self.is_actor() and num_gpus > 0: + self._try_reserve_and_set_cuda() super(LocalDistributedRunner, self).__init__(*args, **kwargs) + def _try_reserve_and_set_cuda(self): + use_found_device = os.environ.get("CUDA_VISIBLE_DEVICES") is None \ + and torch.cuda.is_initialized() + device = reserve_cuda_device() + # This needs to be set even if torch.cuda is already + # initialized because the env var is used later when + # starting the DDP setup. + os.environ["CUDA_VISIBLE_DEVICES"] = device + if use_found_device: + # Once cuda is initialized, torch.device ignores the os.env + # so we have to set the right actual device. + self._set_cuda_device(device) + else: + # if CUDA is not initialized, we can set the os.env. + # Even if initialized, we want to set the device to use BatchNorm. + # and make Torch think it only sees 1 GPU. + self._set_cuda_device("0") + + def _set_cuda_device(self, device_str): + """Sets the CUDA device for this current local worker.""" + if self._is_set: + raise RuntimeError("CUDA devices already set.") + self._is_set = True + + # This is idempotent. We need to call it + # before we call 'set_device'. + _init_cuda_context() + assert isinstance(device_str, str) + self.local_device = device_str + logger.debug("Setting local device: %s", self.local_device) + try: + torch.cuda.set_device(int(self.local_device)) + except RuntimeError: + logger.error("Failed to set local device.") + raise + def get_device_ids(self): return [int(self.local_device)]