From 18c99b9e3b059d400afd33f45d1d25283b457f36 Mon Sep 17 00:00:00 2001 From: Yu Meng Date: Thu, 23 May 2024 23:53:20 -0400 Subject: [PATCH] update with chat templates --- chat_templates.json | 5 +++++ training_configs/llama-3-8b-base-simpo.yaml | 4 ++-- training_configs/llama-3-8b-instruct-simpo.yaml | 6 +++--- training_configs/mistral-7b-base-simpo.yaml | 2 +- training_configs/mistral-7b-instruct-simpo.yaml | 4 ++-- 5 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 chat_templates.json diff --git a/chat_templates.json b/chat_templates.json new file mode 100644 index 0000000..44651b7 --- /dev/null +++ b/chat_templates.json @@ -0,0 +1,5 @@ +{ + "llama3": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "mistral-instruct": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + "mistral-base": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" +} \ No newline at end of file diff --git a/training_configs/llama-3-8b-base-simpo.yaml b/training_configs/llama-3-8b-base-simpo.yaml index d56ef87..f78bfa5 100644 --- a/training_configs/llama-3-8b-base-simpo.yaml +++ b/training_configs/llama-3-8b-base-simpo.yaml @@ -1,5 +1,5 @@ # Model arguments -model_name_or_path: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-sft-full-lr-2e-5 +model_name_or_path: princeton-nlp/Llama-3-Base-8B-SFT torch_dtype: null use_flash_attention_2: true @@ -31,7 +31,7 @@ max_length: 2048 max_prompt_length: 1800 num_train_epochs: 1 optim: adamw_torch -output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-base-simpo +output_dir: outputs/llama-3-8b-base-simpo run_name: llama-3-8b-base-simpo per_device_train_batch_size: 2 per_device_eval_batch_size: 4 diff --git a/training_configs/llama-3-8b-instruct-simpo.yaml b/training_configs/llama-3-8b-instruct-simpo.yaml index fe51723..d0c91ad 100644 --- a/training_configs/llama-3-8b-instruct-simpo.yaml +++ b/training_configs/llama-3-8b-instruct-simpo.yaml @@ -1,11 +1,11 @@ # Model arguments -model_name_or_path: /scratch/gpfs/DANQIC/ym0081/hf_cache/Meta-Llama-3-8B-Instruct +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct torch_dtype: null use_flash_attention_2: true # Data training arguments dataset_mixer: - /scratch/gpfs/DANQIC/ym0081/hf_cache/llama3-ultrafeedback: 1.0 + princeton-nlp/llama3-ultrafeedback: 1.0 dataset_splits: - train - test @@ -31,7 +31,7 @@ max_length: 2048 max_prompt_length: 1800 num_train_epochs: 1 optim: adamw_torch -output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-instruct-simpo +output_dir: outputs/llama-3-8b-instruct-simpo run_name: llama-3-8b-instruct-simpo per_device_train_batch_size: 2 per_device_eval_batch_size: 4 diff --git a/training_configs/mistral-7b-base-simpo.yaml b/training_configs/mistral-7b-base-simpo.yaml index 87d329c..2ffe421 100644 --- a/training_configs/mistral-7b-base-simpo.yaml +++ b/training_configs/mistral-7b-base-simpo.yaml @@ -31,7 +31,7 @@ max_length: 1024 max_prompt_length: 512 num_train_epochs: 1 optim: adamw_torch -output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/mistral-7b-base-simpo +output_dir: outputs/mistral-7b-base-simpo run_name: mistral-7b-base-simpo per_device_train_batch_size: 2 per_device_eval_batch_size: 4 diff --git a/training_configs/mistral-7b-instruct-simpo.yaml b/training_configs/mistral-7b-instruct-simpo.yaml index 08f6842..cc42119 100644 --- a/training_configs/mistral-7b-instruct-simpo.yaml +++ b/training_configs/mistral-7b-instruct-simpo.yaml @@ -1,5 +1,5 @@ # Model arguments -model_name_or_path: /scratch/gpfs/DANQIC/ym0081/hf_cache/Mistral-7B-Instruct-v0.2 +model_name_or_path: mistralai/Mistral-7B-Instruct-v0.2 torch_dtype: null use_flash_attention_2: true @@ -31,7 +31,7 @@ max_length: 2048 max_prompt_length: 1800 num_train_epochs: 1 optim: adamw_torch -output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/mistral-7b-instruct-simpo +output_dir: outputs/mistral-7b-instruct-simpo run_name: mistral-7b-instruct-simpo per_device_train_batch_size: 2 per_device_eval_batch_size: 4