Fixes #96 by handling RepositoryNotFoundError (#97)

* Fixes #96 by handling RepositoryNotFoundError

* Update src/alignment/model_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Remove redundant code

* Add unit test

* Reformat file

* make style

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Traun Leyden
2024-02-01 15:47:14 +01:00
committed by GitHub
parent ad3d43aeea
commit 5ad6db0c79
2 changed files with 17 additions and 2 deletions
+2 -1
View File
@@ -22,6 +22,7 @@ from transformers.trainer_utils import get_last_checkpoint
from accelerate import Accelerator
from huggingface_hub import list_repo_files
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from peft import LoraConfig, PeftConfig
@@ -106,7 +107,7 @@ def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
try:
# Try first if model on a Hub repo
repo_files = list_repo_files(model_name_or_path, revision=revision)
except HFValidationError:
except (HFValidationError, RepositoryNotFoundError):
# If not, check local repo
repo_files = os.listdir(model_name_or_path)
return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
+15 -1
View File
@@ -17,7 +17,14 @@ import unittest
import torch
from transformers import AutoTokenizer
from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer
from alignment import (
DataArguments,
ModelArguments,
get_peft_config,
get_quantization_config,
get_tokenizer,
is_adapter_model,
)
from alignment.data import DEFAULT_CHAT_TEMPLATE
@@ -88,3 +95,10 @@ class GetPeftConfigTest(unittest.TestCase):
model_args = ModelArguments(use_peft=False)
peft_config = get_peft_config(model_args)
self.assertIsNone(peft_config)
class IsAdapterModelTest(unittest.TestCase):
def test_is_adapter_model_calls_listdir(self):
# Assert that for an invalid repo name it gets to the point where it calls os.listdir,
# which is expected to raise a FileNotFoundError
self.assertRaises(FileNotFoundError, is_adapter_model, "nonexistent/model")