mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 17:14:25 +08:00
adds auto adapter merge to dpo script
This commit is contained in:
+28
-3
@@ -34,7 +34,9 @@ from alignment import (
|
||||
get_tokenizer,
|
||||
)
|
||||
from trl import DPOTrainer
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
from alignment.model_utils import is_adapter_model
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -112,8 +114,31 @@ def main():
|
||||
device_map=get_kbit_device_map(),
|
||||
quantization_config=get_quantization_config(model_args),
|
||||
)
|
||||
|
||||
model = model_args.model_name_or_path
|
||||
if is_adapter_model(model, model_args.model_revision):
|
||||
# load the model, merge the adapter weights and unload the adapter
|
||||
# Note: to run QLoRA, you will need to merge the base model separately, as the merged model is in 16bit
|
||||
logger.info(f"Merging peft adapters for {model_args.model_name_or_path=}")
|
||||
|
||||
peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
|
||||
|
||||
model_kwargs = dict(
|
||||
revision=model_args.base_model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
)
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
peft_config.base_model_name_or_path, **model_kwargs,
|
||||
)
|
||||
model = PeftModel.from_pretrained(base_model, model_args.model_name_or_path, revision=model_args.model_revision)
|
||||
model.eval()
|
||||
model = model.merge_and_unload()
|
||||
model_kwargs = None
|
||||
|
||||
ref_model = model_args.model_name_or_path
|
||||
ref_model = model
|
||||
ref_model_kwargs = model_kwargs
|
||||
|
||||
if model_args.use_peft is True:
|
||||
@@ -124,7 +149,7 @@ def main():
|
||||
# Instantiate DPO trainer
|
||||
#########################
|
||||
dpo_trainer = DPOTrainer(
|
||||
model_args.model_name_or_path,
|
||||
model,
|
||||
ref_model,
|
||||
model_init_kwargs=model_kwargs,
|
||||
ref_model_init_kwargs=ref_model_kwargs,
|
||||
|
||||
@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
|
||||
|
||||
from accelerate import Accelerator
|
||||
from peft import LoraConfig, PeftConfig
|
||||
|
||||
from huggingface_hub import list_repo_files
|
||||
from .configs import DataArguments, ModelArguments
|
||||
from .data import DEFAULT_CHAT_TEMPLATE
|
||||
|
||||
@@ -77,3 +77,7 @@ def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
|
||||
)
|
||||
|
||||
return peft_config
|
||||
|
||||
def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
    """Return whether ``model_name_or_path`` contains PEFT adapter weights.

    A location counts as an adapter model when its file listing includes
    ``adapter_model.safetensors`` or ``adapter_model.bin``.

    Args:
        model_name_or_path: Either a path to a local checkpoint directory or a
            Hugging Face Hub repo id.
        revision: Hub revision to inspect. Only used for the Hub lookup; it is
            ignored when ``model_name_or_path`` is a local directory.

    Returns:
        ``True`` if one of the adapter weight files is present, else ``False``.

    Raises:
        Whatever ``list_repo_files`` raises when ``model_name_or_path`` is not
        a local directory and cannot be resolved on the Hub.
    """
    # Function-scope import so this block does not depend on the module's
    # top-level import list.
    import os

    if os.path.isdir(model_name_or_path):
        # Local checkpoints (e.g. the output dir of a previous training stage)
        # are not Hub repo ids, so list the directory directly instead of
        # sending the path to the Hub API.
        repo_files = os.listdir(model_name_or_path)
    else:
        repo_files = list_repo_files(model_name_or_path, revision=revision)

    return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
|
||||
Reference in New Issue
Block a user