update with chat templates

This commit is contained in:
Yu Meng
2024-05-23 23:53:20 -04:00
parent 7da63e9a6e
commit 18c99b9e3b
5 changed files with 13 additions and 8 deletions
+5
View File
@@ -0,0 +1,5 @@
{
"llama3": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"mistral-instruct": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"mistral-base": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
}
+2 -2
View File
@@ -1,5 +1,5 @@
# Model arguments # Model arguments
model_name_or_path: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-sft-full-lr-2e-5 model_name_or_path: princeton-nlp/Llama-3-Base-8B-SFT
torch_dtype: null torch_dtype: null
use_flash_attention_2: true use_flash_attention_2: true
@@ -31,7 +31,7 @@ max_length: 2048
max_prompt_length: 1800 max_prompt_length: 1800
num_train_epochs: 1 num_train_epochs: 1
optim: adamw_torch optim: adamw_torch
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-base-simpo output_dir: outputs/llama-3-8b-base-simpo
run_name: llama-3-8b-base-simpo run_name: llama-3-8b-base-simpo
per_device_train_batch_size: 2 per_device_train_batch_size: 2
per_device_eval_batch_size: 4 per_device_eval_batch_size: 4
@@ -1,11 +1,11 @@
# Model arguments # Model arguments
model_name_or_path: /scratch/gpfs/DANQIC/ym0081/hf_cache/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
torch_dtype: null torch_dtype: null
use_flash_attention_2: true use_flash_attention_2: true
# Data training arguments # Data training arguments
dataset_mixer: dataset_mixer:
/scratch/gpfs/DANQIC/ym0081/hf_cache/llama3-ultrafeedback: 1.0 princeton-nlp/llama3-ultrafeedback: 1.0
dataset_splits: dataset_splits:
- train - train
- test - test
@@ -31,7 +31,7 @@ max_length: 2048
max_prompt_length: 1800 max_prompt_length: 1800
num_train_epochs: 1 num_train_epochs: 1
optim: adamw_torch optim: adamw_torch
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-instruct-simpo output_dir: outputs/llama-3-8b-instruct-simpo
run_name: llama-3-8b-instruct-simpo run_name: llama-3-8b-instruct-simpo
per_device_train_batch_size: 2 per_device_train_batch_size: 2
per_device_eval_batch_size: 4 per_device_eval_batch_size: 4
+1 -1
View File
@@ -31,7 +31,7 @@ max_length: 1024
max_prompt_length: 512 max_prompt_length: 512
num_train_epochs: 1 num_train_epochs: 1
optim: adamw_torch optim: adamw_torch
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/mistral-7b-base-simpo output_dir: outputs/mistral-7b-base-simpo
run_name: mistral-7b-base-simpo run_name: mistral-7b-base-simpo
per_device_train_batch_size: 2 per_device_train_batch_size: 2
per_device_eval_batch_size: 4 per_device_eval_batch_size: 4
@@ -1,5 +1,5 @@
# Model arguments # Model arguments
model_name_or_path: /scratch/gpfs/DANQIC/ym0081/hf_cache/Mistral-7B-Instruct-v0.2 model_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
torch_dtype: null torch_dtype: null
use_flash_attention_2: true use_flash_attention_2: true
@@ -31,7 +31,7 @@ max_length: 2048
max_prompt_length: 1800 max_prompt_length: 1800
num_train_epochs: 1 num_train_epochs: 1
optim: adamw_torch optim: adamw_torch
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/mistral-7b-instruct-simpo output_dir: outputs/mistral-7b-instruct-simpo
run_name: mistral-7b-instruct-simpo run_name: mistral-7b-instruct-simpo
per_device_train_batch_size: 2 per_device_train_batch_size: 2
per_device_eval_batch_size: 4 per_device_eval_batch_size: 4