mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 17:29:09 +08:00
* Fixes #96 by handling RepositoryNotFoundError * Update src/alignment/model_utils.py Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Remove redundant code * Add unit test * Reformat file * make style --------- Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from accelerate import Accelerator
|
||||
from huggingface_hub import list_repo_files
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from peft import LoraConfig, PeftConfig
|
||||
|
||||
@@ -106,7 +107,7 @@ def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
|
||||
try:
|
||||
# Try first if model on a Hub repo
|
||||
repo_files = list_repo_files(model_name_or_path, revision=revision)
|
||||
except HFValidationError:
|
||||
except (HFValidationError, RepositoryNotFoundError):
|
||||
# If not, check local repo
|
||||
repo_files = os.listdir(model_name_or_path)
|
||||
return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
|
||||
|
||||
@@ -17,7 +17,14 @@ import unittest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer
|
||||
from alignment import (
|
||||
DataArguments,
|
||||
ModelArguments,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
get_tokenizer,
|
||||
is_adapter_model,
|
||||
)
|
||||
from alignment.data import DEFAULT_CHAT_TEMPLATE
|
||||
|
||||
|
||||
@@ -88,3 +95,10 @@ class GetPeftConfigTest(unittest.TestCase):
|
||||
model_args = ModelArguments(use_peft=False)
|
||||
peft_config = get_peft_config(model_args)
|
||||
self.assertIsNone(peft_config)
|
||||
|
||||
|
||||
class IsAdapterModelTest(unittest.TestCase):
|
||||
def test_is_adapter_model_calls_listdir(self):
|
||||
# Assert that for an invalid repo name it gets to the point where it calls os.listdir,
|
||||
# which is expected to raise a FileNotFoundError
|
||||
self.assertRaises(FileNotFoundError, is_adapter_model, "nonexistent/model")
|
||||
|
||||
Reference in New Issue
Block a user