From 7f1a14e0d4aec284a6afdfffe5d7fd23989748e7 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Fri, 10 Nov 2023 14:15:44 +0100 Subject: [PATCH] adds auto adapter merge to dpo script --- scripts/run_dpo.py | 31 ++++++++++++++++++++++++++++--- src/alignment/model_utils.py | 6 +++++- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py index 704ce1f..2a428ac 100644 --- a/scripts/run_dpo.py +++ b/scripts/run_dpo.py @@ -34,7 +34,9 @@ from alignment import ( get_tokenizer, ) from trl import DPOTrainer - +from transformers import AutoModelForCausalLM +from alignment.model_utils import is_adapter_model +from peft import PeftConfig, PeftModel logger = logging.getLogger(__name__) @@ -112,8 +114,31 @@ def main(): device_map=get_kbit_device_map(), quantization_config=get_quantization_config(model_args), ) + + model = model_args.model_name_or_path + if is_adapter_model(model, model_args.model_revision): + # load the model, merge the adapter weights and unload the adapter + # Note: to run QLora, you will need to merge the based model separately as the merged model in 16bit + logger.info(f"Merging peft adapters for {model_args.model_name_or_path=}") + + peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision) + + model_kwargs = dict( + revision=model_args.base_model_revision, + trust_remote_code=model_args.trust_remote_code, + use_flash_attention_2=model_args.use_flash_attention_2, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + ) + base_model = AutoModelForCausalLM.from_pretrained( + peft_config.base_model_name_or_path, **model_kwargs, + ) + model = PeftModel.from_pretrained(base_model, model_args.model_name_or_path, revision=model_args.model_revision) + model.eval() + model = model.merge_and_unload() + model_kwargs = None - ref_model = model_args.model_name_or_path + ref_model = model ref_model_kwargs = model_kwargs if model_args.use_peft is True: @@ -124,7 +149,7 @@ def main(): # Instantiate DPO trainer ######################### dpo_trainer = DPOTrainer( - model_args.model_name_or_path, + model, ref_model, model_init_kwargs=model_kwargs, ref_model_init_kwargs=ref_model_kwargs, diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index f237745..d35e037 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -5,7 +5,7 @@ from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer from accelerate import Accelerator from peft import LoraConfig, PeftConfig - +from huggingface_hub import list_repo_files from .configs import DataArguments, ModelArguments from .data import DEFAULT_CHAT_TEMPLATE @@ -77,3 +77,7 @@ def get_peft_config(model_args: ModelArguments) -> PeftConfig | None: ) return peft_config + +def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool: + repo_files = list_repo_files(model_name_or_path, revision=revision) + return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files \ No newline at end of file