From 87cc800498b17432cfb7f5acb5e9a79f15c867fc Mon Sep 17 00:00:00 2001 From: lewtun Date: Mon, 5 Feb 2024 16:50:17 +0100 Subject: [PATCH] Apply quantization during DPO QLoRA (#115) * Add QLoRA fix * Update script --- recipes/zephyr-7b-beta/dpo/config_qlora.yaml | 8 ++++---- scripts/run_dpo.py | 14 ++++++-------- setup.py | 2 +- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/recipes/zephyr-7b-beta/dpo/config_qlora.yaml b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml index 65adcf3..3928341 100644 --- a/recipes/zephyr-7b-beta/dpo/config_qlora.yaml +++ b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml @@ -1,12 +1,12 @@ # Model arguments model_name_or_path: alignment-handbook/zephyr-7b-sft-qlora -torch_dtype: float16 +torch_dtype: bfloat16 # LoRA arguments use_peft: true load_in_4bit: true -lora_r: 16 -lora_alpha: 16 +lora_r: 128 +lora_alpha: 128 lora_dropout: 0.05 lora_target_modules: - q_proj @@ -32,7 +32,7 @@ beta: 0.01 do_eval: true evaluation_strategy: steps eval_steps: 100 -gradient_accumulation_steps: 2 +gradient_accumulation_steps: 4 gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py index 58b97ed..3a41c37 100644 --- a/scripts/run_dpo.py +++ b/scripts/run_dpo.py @@ -128,28 +128,26 @@ def main(): model = model_args.model_name_or_path if is_adapter_model(model, model_args.model_revision) is True: - # Load the base model, merge the adapter weights and unload the adapter - # Note: to run QLoRA, you will need to merge the base model separately as the merged model in 16bit - logger.info(f"Merging PEFT adapters for {model_args.model_name_or_path=}") - + logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}") peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision) - model_kwargs = dict( revision=model_args.base_model_revision, trust_remote_code=model_args.trust_remote_code, use_flash_attention_2=model_args.use_flash_attention_2, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, ) base_model = AutoModelForCausalLM.from_pretrained( peft_config.base_model_name_or_path, **model_kwargs, ) model = PeftModel.from_pretrained( - base_model, model_args.model_name_or_path, revision=model_args.model_revision + base_model, + model_args.model_name_or_path, + revision=model_args.model_revision, ) - model.eval() - model = model.merge_and_unload() model_kwargs = None ref_model = model diff --git a/setup.py b/setup.py index e5ed532..28ff9f3 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,7 @@ _deps = [ "tensorboard", "torch==2.1.2", "transformers==4.36.2", - "trl==0.7.7", + "trl==0.7.10", "jinja2>=3.0.0", "tqdm>=4.64.1", ]