mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 17:14:25 +08:00
Allow running DPO from a local model (#49)
* Update model_utils.py Check if a model is adapter model when a local path is supplied instead of HF model * Cleaner solution, thanks to lewtun
This commit is contained in:
committed by
GitHub
parent
f025057ce4
commit
80e952ec47
@@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
@@ -20,6 +20,7 @@ from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
|
||||
|
||||
from accelerate import Accelerator
|
||||
from huggingface_hub import list_repo_files
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from peft import LoraConfig, PeftConfig
|
||||
|
||||
from .configs import DataArguments, ModelArguments
|
||||
@@ -96,5 +97,10 @@ def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
|
||||
|
||||
|
||||
def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
|
||||
repo_files = list_repo_files(model_name_or_path, revision=revision)
|
||||
try:
|
||||
# Try first if model on a Hub repo
|
||||
repo_files = list_repo_files(model_name_or_path, revision=revision)
|
||||
except HFValidationError:
|
||||
# If not, check local repo
|
||||
repo_files = os.listdir(model_name_or_path)
|
||||
return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
|
||||
|
||||
Reference in New Issue
Block a user