diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index b9d2315..ab2025f 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from typing import Dict import torch @@ -20,6 +20,7 @@ from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer from accelerate import Accelerator from huggingface_hub import list_repo_files +from huggingface_hub.utils._validators import HFValidationError from peft import LoraConfig, PeftConfig from .configs import DataArguments, ModelArguments @@ -96,5 +97,10 @@ def get_peft_config(model_args: ModelArguments) -> PeftConfig | None: def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool: - repo_files = list_repo_files(model_name_or_path, revision=revision) + try: + # Try first if model on a Hub repo + repo_files = list_repo_files(model_name_or_path, revision=revision) + except HFValidationError: + # If not, check local repo + repo_files = os.listdir(model_name_or_path) return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files