mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:23:10 +08:00
[windows] Improve GPU detection (#9300)
Co-authored-by: Mehrdad <noreply@github.com>
This commit is contained in:
@@ -3,6 +3,8 @@ from collections import namedtuple
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import ray
|
||||
import ray.ray_constants as ray_constants
|
||||
@@ -229,12 +231,23 @@ class ResourceSpec(
|
||||
def _autodetect_num_gpus():
|
||||
"""Attempt to detect the number of GPUs on this machine.
|
||||
|
||||
TODO(rkn): This currently assumes Nvidia GPUs and Linux.
|
||||
TODO(rkn): This currently assumes NVIDIA GPUs on Linux.
|
||||
TODO(mehrdadn): This currently does not work on macOS.
|
||||
TODO(mehrdadn): Use a better mechanism for Windows.
|
||||
|
||||
Possibly useful: tensorflow.config.list_physical_devices()
|
||||
|
||||
Returns:
|
||||
The number of GPUs if any were detected, otherwise 0.
|
||||
"""
|
||||
proc_gpus_path = "/proc/driver/nvidia/gpus"
|
||||
if os.path.isdir(proc_gpus_path):
|
||||
return len(os.listdir(proc_gpus_path))
|
||||
return 0
|
||||
result = 0
|
||||
if sys.platform.startswith("linux"):
|
||||
proc_gpus_path = "/proc/driver/nvidia/gpus"
|
||||
if os.path.isdir(proc_gpus_path):
|
||||
result = len(os.listdir(proc_gpus_path))
|
||||
elif sys.platform == "win32":
|
||||
props = "AdapterCompatibility"
|
||||
cmdargs = ["WMIC", "PATH", "Win32_VideoController", "GET", props]
|
||||
lines = subprocess.check_output(cmdargs).splitlines()[1:]
|
||||
result = len([l.rstrip() for l in lines if l.startswith(b"NVIDIA")])
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user