diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index 688261fdb..02daa858f 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -462,27 +462,29 @@ def load_newest_checkpoint(dirpath: str, ckpt_pattern: str) -> dict: return checkpoint_state -def wait_for_gpu(gpu_id=None, gpu_memory_limit=0.1, retry=20): +def wait_for_gpu(gpu_id=None, + target_util=0.01, + retry=20, + gpu_memory_limit=None): """Checks if a given GPU has freed memory. Requires ``gputil`` to be installed: ``pip install gputil``. Args: - gpu_id (Optional[str]): GPU id to check. Must be found - within GPUtil.getGPUs(). If none, resorts to + gpu_id (Optional[Union[int, str]]): GPU id or uuid to check. + Must be found within GPUtil.getGPUs(). If none, resorts to the first item returned from `ray.get_gpu_ids()`. - gpu_memory_limit (float): If memory usage is below - this quantity, the check will break. + target_util (float): The utilization threshold to reach to unblock. + Set this to 0 to block until the GPU is completely free. retry (int): Number of times to check GPU limit. Sleeps 5 seconds between checks. + gpu_memory_limit (float): Deprecated. Returns: - bool - True if free. + bool: True if free. Raises: - RuntimeError - If GPUtil is not found, if no GPUs are detected + RuntimeError: If GPUtil is not found, if no GPUs are detected or if the check fails. Example: @@ -495,20 +497,43 @@ def wait_for_gpu(gpu_id=None, gpu_memory_limit=0.1, retry=20): tune.run(tune_func, resources_per_trial={"GPU": 1}, num_samples=10) """ + if gpu_memory_limit: + raise ValueError("'gpu_memory_limit' is deprecated. " + "Use 'target_util' instead.") if GPUtil is None: raise RuntimeError( "GPUtil must be installed if calling `wait_for_gpu`.") - if not gpu_id: + if gpu_id is None: gpu_id_list = ray.get_gpu_ids() if not gpu_id_list: raise RuntimeError(f"No GPU ids found from {ray.get_gpu_ids()}. " "Did you set Tune resources correctly?") gpu_id = gpu_id_list[0] - gpu_object = GPUtil.getGPUs()[gpu_id] + + if isinstance(gpu_id, int): + list_gpu_ids = [g.id for g in GPUtil.getGPUs()] + if gpu_id not in list_gpu_ids: + raise ValueError( + f"{gpu_id} (int) not found in GPU ids: {list_gpu_ids}. " + "wait_for_gpu takes either int (gpu id) or str (gpu uuid).") + elif isinstance(gpu_id, str): + list_uuids = [g.uuid for g in GPUtil.getGPUs()] + if gpu_id not in list_uuids: + raise ValueError( + f"{gpu_id} (str) not found in GPU uuids: {list_uuids}. " + "wait_for_gpu takes either int (gpu id) or str (gpu uuid).") + else: + raise ValueError(f"gpu_id must be int or str -- got ({type(gpu_id)})") + for i in range(int(retry)): - if gpu_object.memoryUsed > gpu_memory_limit: - logger.info(f"Waiting for GPU {gpu_id} memory to free. " - f"Mem: {gpu_object.memoryUsed:0.3f}") + if isinstance(gpu_id, int): + gpu_object = [g for g in GPUtil.getGPUs() if g.id == gpu_id][0] + else: + gpu_object = [g for g in GPUtil.getGPUs() if g.uuid == gpu_id][0] + + if gpu_object.memoryUtil > target_util: + logger.info(f"Waiting for GPU util to reach {target_util}. " + f"Util: {gpu_object.memoryUtil:0.3f}") time.sleep(5) else: return True