From 3e9c4cc3bd3132648a8659e1a40bd2593c00c3fe Mon Sep 17 00:00:00 2001 From: Yu Meng Date: Sat, 13 Jul 2024 16:50:42 -0400 Subject: [PATCH] update trainer argument --- scripts/run_simpo.py | 2 ++ scripts/simpo_config.py | 2 ++ training_configs/llama-3-8b-base-sft.yaml | 2 +- training_configs/llama-3-8b-base-simpo.yaml | 2 +- training_configs/llama-3-8b-instruct-simpo-v2.yaml | 2 +- training_configs/llama-3-8b-instruct-simpo.yaml | 2 +- training_configs/mistral-7b-base-simpo.yaml | 2 +- training_configs/mistral-7b-instruct-simpo.yaml | 2 +- 8 files changed, 10 insertions(+), 6 deletions(-) diff --git a/scripts/run_simpo.py b/scripts/run_simpo.py index 9a9f174..4ff7457 100644 --- a/scripts/run_simpo.py +++ b/scripts/run_simpo.py @@ -218,6 +218,7 @@ def main(): use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, + attn_implementation=training_args.attn_implementation, ) model = model_args.model_name_or_path @@ -245,6 +246,7 @@ def main(): # ) # model_kwargs = None + training_args.model_init_kwargs = model_kwargs ######################### # Instantiate SimPO trainer ######################### diff --git a/scripts/simpo_config.py b/scripts/simpo_config.py index 5d98921..41247d5 100644 --- a/scripts/simpo_config.py +++ b/scripts/simpo_config.py @@ -68,3 +68,5 @@ class SimPOConfig(TrainingArguments): model_init_kwargs: Optional[Dict] = None dataset_num_proc: Optional[int] = None + + attn_implementation: str = None diff --git a/training_configs/llama-3-8b-base-sft.yaml b/training_configs/llama-3-8b-base-sft.yaml index 38601d7..792a37e 100644 --- a/training_configs/llama-3-8b-base-sft.yaml +++ b/training_configs/llama-3-8b-base-sft.yaml @@ -2,7 +2,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" diff --git a/training_configs/llama-3-8b-base-simpo.yaml b/training_configs/llama-3-8b-base-simpo.yaml index c23d0ae..e2e17ea 100644 --- a/training_configs/llama-3-8b-base-simpo.yaml +++ b/training_configs/llama-3-8b-base-simpo.yaml @@ -1,7 +1,7 @@ # Model arguments model_name_or_path: princeton-nlp/Llama-3-Base-8B-SFT torch_dtype: null -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments dataset_mixer: diff --git a/training_configs/llama-3-8b-instruct-simpo-v2.yaml b/training_configs/llama-3-8b-instruct-simpo-v2.yaml index 77a899f..842e0fb 100644 --- a/training_configs/llama-3-8b-instruct-simpo-v2.yaml +++ b/training_configs/llama-3-8b-instruct-simpo-v2.yaml @@ -1,7 +1,7 @@ # Model arguments model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct torch_dtype: null -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments dataset_mixer: diff --git a/training_configs/llama-3-8b-instruct-simpo.yaml b/training_configs/llama-3-8b-instruct-simpo.yaml index 665efa8..cb22996 100644 --- a/training_configs/llama-3-8b-instruct-simpo.yaml +++ b/training_configs/llama-3-8b-instruct-simpo.yaml @@ -1,7 +1,7 @@ # Model arguments model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct torch_dtype: null -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments dataset_mixer: diff --git a/training_configs/mistral-7b-base-simpo.yaml b/training_configs/mistral-7b-base-simpo.yaml index 86647ae..5828d9a 100644 --- a/training_configs/mistral-7b-base-simpo.yaml +++ b/training_configs/mistral-7b-base-simpo.yaml @@ -1,7 +1,7 @@ # Model arguments model_name_or_path: alignment-handbook/zephyr-7b-sft-full torch_dtype: null -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments dataset_mixer: diff --git a/training_configs/mistral-7b-instruct-simpo.yaml b/training_configs/mistral-7b-instruct-simpo.yaml index 00c4208..4ea8da7 100644 --- a/training_configs/mistral-7b-instruct-simpo.yaml +++ b/training_configs/mistral-7b-instruct-simpo.yaml @@ -1,7 +1,7 @@ # Model arguments model_name_or_path: mistralai/Mistral-7B-Instruct-v0.2 torch_dtype: null -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments dataset_mixer: