rllib: use pytorch's fn to see if gpu is available (#5890)

This commit is contained in:
Matthew A. Wright
2019-10-12 00:13:00 -07:00
committed by Eric Liang
parent 898652837c
commit 0110941de5
2 changed files with 2 additions and 6 deletions
+1 -3
View File
@@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function
from gym.spaces import Tuple, Discrete, Dict
import os
import logging
import numpy as np
import torch as th
@@ -172,8 +171,7 @@ class QMixTorchPolicy(Policy):
self.has_env_global_state = False
self.has_action_mask = False
self.device = (th.device("cuda")
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
else th.device("cpu"))
if th.cuda.is_available() else th.device("cpu"))
agent_obs_space = obs_space.original_space.spaces[0]
if isinstance(agent_obs_space, Dict):
+1 -3
View File
@@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import os
try:
import torch
@@ -49,8 +48,7 @@ class TorchPolicy(Policy):
self.observation_space = observation_space
self.action_space = action_space
self.device = (torch.device("cuda")
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
else torch.device("cpu"))
if torch.cuda.is_available() else torch.device("cpu"))
self.model = model.to(self.device)
self._loss = loss
self._optimizer = self.optimizer()