Mirror of https://github.com/wassname/alignment-handbook.git (synced 2026-06-27 17:29:09 +08:00)
Apply quantization during DPO QLoRA (#115)
* Add QLoRA fix
* Update script
This commit is contained in:
@@ -1,12 +1,12 @@
|
|||||||
# Model arguments
|
# Model arguments
|
||||||
model_name_or_path: alignment-handbook/zephyr-7b-sft-qlora
|
model_name_or_path: alignment-handbook/zephyr-7b-sft-qlora
|
||||||
torch_dtype: float16
|
torch_dtype: bfloat16
|
||||||
|
|
||||||
# LoRA arguments
|
# LoRA arguments
|
||||||
use_peft: true
|
use_peft: true
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
lora_r: 16
|
lora_r: 128
|
||||||
lora_alpha: 16
|
lora_alpha: 128
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
lora_target_modules:
|
lora_target_modules:
|
||||||
- q_proj
|
- q_proj
|
||||||
@@ -32,7 +32,7 @@ beta: 0.01
|
|||||||
do_eval: true
|
do_eval: true
|
||||||
evaluation_strategy: steps
|
evaluation_strategy: steps
|
||||||
eval_steps: 100
|
eval_steps: 100
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 4
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
|
|||||||
+6
-8
@@ -128,28 +128,26 @@ def main():
|
|||||||
|
|
||||||
model = model_args.model_name_or_path
|
model = model_args.model_name_or_path
|
||||||
if is_adapter_model(model, model_args.model_revision) is True:
|
if is_adapter_model(model, model_args.model_revision) is True:
|
||||||
# Load the base model, merge the adapter weights and unload the adapter
|
logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
|
||||||
# Note: to run QLoRA, you will need to merge the base model separately as the merged model in 16bit
|
|
||||||
logger.info(f"Merging PEFT adapters for {model_args.model_name_or_path=}")
|
|
||||||
|
|
||||||
peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
|
peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
|
||||||
|
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
revision=model_args.base_model_revision,
|
revision=model_args.base_model_revision,
|
||||||
trust_remote_code=model_args.trust_remote_code,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
use_cache=False if training_args.gradient_checkpointing else True,
|
use_cache=False if training_args.gradient_checkpointing else True,
|
||||||
|
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||||
|
quantization_config=quantization_config,
|
||||||
)
|
)
|
||||||
base_model = AutoModelForCausalLM.from_pretrained(
|
base_model = AutoModelForCausalLM.from_pretrained(
|
||||||
peft_config.base_model_name_or_path,
|
peft_config.base_model_name_or_path,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
base_model, model_args.model_name_or_path, revision=model_args.model_revision
|
base_model,
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
revision=model_args.model_revision,
|
||||||
)
|
)
|
||||||
model.eval()
|
|
||||||
model = model.merge_and_unload()
|
|
||||||
model_kwargs = None
|
model_kwargs = None
|
||||||
|
|
||||||
ref_model = model
|
ref_model = model
|
||||||
|
|||||||
Reference in New Issue
Block a user