mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 13:57:20 +08:00
rllib: use pytorch's fn to see if gpu is available (#5890)
This commit is contained in:
committed by
Eric Liang
parent
898652837c
commit
0110941de5
@@ -3,7 +3,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from gym.spaces import Tuple, Discrete, Dict
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch as th
|
||||
@@ -172,8 +171,7 @@ class QMixTorchPolicy(Policy):
|
||||
self.has_env_global_state = False
|
||||
self.has_action_mask = False
|
||||
self.device = (th.device("cuda")
|
||||
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
||||
else th.device("cpu"))
|
||||
if th.cuda.is_available() else th.device("cpu"))
|
||||
|
||||
agent_obs_space = obs_space.original_space.spaces[0]
|
||||
if isinstance(agent_obs_space, Dict):
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
try:
|
||||
import torch
|
||||
@@ -49,8 +48,7 @@ class TorchPolicy(Policy):
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.device = (torch.device("cuda")
|
||||
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
||||
else torch.device("cpu"))
|
||||
if torch.cuda.is_available() else torch.device("cpu"))
|
||||
self.model = model.to(self.device)
|
||||
self._loss = loss
|
||||
self._optimizer = self.optimizer()
|
||||
|
||||
Reference in New Issue
Block a user