diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py index 05a348c9b..407e5e821 100644 --- a/rllib/agents/qmix/qmix_policy.py +++ b/rllib/agents/qmix/qmix_policy.py @@ -3,7 +3,6 @@ from __future__ import division from __future__ import print_function from gym.spaces import Tuple, Discrete, Dict -import os import logging import numpy as np import torch as th @@ -172,8 +171,7 @@ class QMixTorchPolicy(Policy): self.has_env_global_state = False self.has_action_mask = False self.device = (th.device("cuda") - if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None)) - else th.device("cpu")) + if th.cuda.is_available() else th.device("cpu")) agent_obs_space = obs_space.original_space.spaces[0] if isinstance(agent_obs_space, Dict): diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index ea81e4e29..8278d7d28 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -3,7 +3,6 @@ from __future__ import division from __future__ import print_function import numpy as np -import os try: import torch @@ -49,8 +48,7 @@ class TorchPolicy(Policy): self.observation_space = observation_space self.action_space = action_space self.device = (torch.device("cuda") - if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None)) - else torch.device("cpu")) + if torch.cuda.is_available() else torch.device("cpu")) self.model = model.to(self.device) self._loss = loss self._optimizer = self.optimizer()