mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 14:55:50 +08:00
[core] Return all GPUs for Local Mode (#9108)
This commit is contained in:
@@ -661,6 +661,26 @@ def test_specific_gpus(save_gpu_ids_shutdown_only):
|
||||
ray.get([g.remote() for _ in range(100)])
|
||||
|
||||
|
||||
def test_local_mode_gpus(save_gpu_ids_shutdown_only):
|
||||
allowed_gpu_ids = [4, 5, 6, 7, 8]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
|
||||
[str(i) for i in allowed_gpu_ids])
|
||||
|
||||
from importlib import reload
|
||||
reload(ray.worker)
|
||||
|
||||
ray.init(num_gpus=3, local_mode=True)
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
assert len(gpu_ids) == 3
|
||||
for gpu in gpu_ids:
|
||||
assert gpu in allowed_gpu_ids
|
||||
|
||||
ray.get([f.remote() for _ in range(100)])
|
||||
|
||||
|
||||
def test_blocking_tasks(ray_start_regular):
|
||||
@ray.remote
|
||||
def f(i, j):
|
||||
|
||||
@@ -395,6 +395,10 @@ def get_gpu_ids():
|
||||
assigned_ids = [
|
||||
global_worker.original_gpu_ids[gpu_id] for gpu_id in assigned_ids
|
||||
]
|
||||
# Give all GPUs in local_mode.
|
||||
if global_worker.mode == LOCAL_MODE:
|
||||
max_gpus = global_worker.node.get_resource_spec().num_gpus
|
||||
return global_worker.original_gpu_ids[:max_gpus]
|
||||
|
||||
return assigned_ids
|
||||
|
||||
|
||||
Reference in New Issue
Block a user