From 5ad6db0c793d38422ed82a4e2f9a0456ee5760f5 Mon Sep 17 00:00:00 2001 From: Traun Leyden Date: Thu, 1 Feb 2024 15:47:14 +0100 Subject: [PATCH] Fixes #96 by handling RepositoryNotFoundError (#97) * Fixes #96 by handling RepositoryNotFoundError * Update src/alignment/model_utils.py Co-authored-by: lewtun * Remove redundant code * Add unit test * Reformat file * make style --------- Co-authored-by: lewtun --- src/alignment/model_utils.py | 3 ++- tests/test_model_utils.py | 16 +++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index d538fb2..7126bdc 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -22,6 +22,7 @@ from transformers.trainer_utils import get_last_checkpoint from accelerate import Accelerator from huggingface_hub import list_repo_files +from huggingface_hub.utils._errors import RepositoryNotFoundError from huggingface_hub.utils._validators import HFValidationError from peft import LoraConfig, PeftConfig @@ -106,7 +107,7 @@ def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool: try: # Try first if model on a Hub repo repo_files = list_repo_files(model_name_or_path, revision=revision) - except HFValidationError: + except (HFValidationError, RepositoryNotFoundError): # If not, check local repo repo_files = os.listdir(model_name_or_path) return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py index 688b3bd..ce71993 100644 --- a/tests/test_model_utils.py +++ b/tests/test_model_utils.py @@ -17,7 +17,14 @@ import unittest import torch from transformers import AutoTokenizer -from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer +from alignment import ( + DataArguments, + ModelArguments, + get_peft_config, + get_quantization_config, + get_tokenizer, + is_adapter_model, +) from alignment.data import DEFAULT_CHAT_TEMPLATE @@ -88,3 +95,10 @@ class GetPeftConfigTest(unittest.TestCase): model_args = ModelArguments(use_peft=False) peft_config = get_peft_config(model_args) self.assertIsNone(peft_config) + + +class IsAdapterModelTest(unittest.TestCase): + def test_is_adapter_model_calls_listdir(self): + # Assert that for an invalid repo name it gets to the point where it calls os.listdir, + # which is expected to raise a FileNotFoundError + self.assertRaises(FileNotFoundError, is_adapter_model, "nonexistent/model")