Allow running DPO from a local model (#49)

* Update model_utils.py

Check if a model is adapter model when a local path is supplied instead of HF model

* Cleaner solution, thanks to lewtun
This commit is contained in:
Dragan Milchevski
2023-11-27 11:31:09 +01:00
committed by GitHub
parent f025057ce4
commit 80e952ec47
+8 -2
View File
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict
import torch
@@ -20,6 +20,7 @@ from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
from accelerate import Accelerator
from huggingface_hub import list_repo_files
from huggingface_hub.utils._validators import HFValidationError
from peft import LoraConfig, PeftConfig
from .configs import DataArguments, ModelArguments
@@ -96,5 +97,10 @@ def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
repo_files = list_repo_files(model_name_or_path, revision=revision)
try:
# Try first if model on a Hub repo
repo_files = list_repo_files(model_name_or_path, revision=revision)
except HFValidationError:
# If not, check local repo
repo_files = os.listdir(model_name_or_path)
return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files