mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 17:14:25 +08:00
@@ -0,0 +1,31 @@
|
||||
name: Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- v*-release
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
|
||||
unit-tests:
|
||||
name: Run unit tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
- name: Setup Python environment
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.10.10
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ".[dev, torch]"
|
||||
- name: Run unit tests
|
||||
run: HF_TOKEN=$HF_TOKEN pytest -sv tests/
|
||||
@@ -158,3 +158,7 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# Temp folders
|
||||
data/
|
||||
wandb/
|
||||
@@ -3,31 +3,31 @@
|
||||
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
|
||||
export PYTHONPATH = src
|
||||
|
||||
check_dirs := src tests
|
||||
check_dirs := src tests scripts
|
||||
|
||||
style:
|
||||
python -m black --line-length 119 --target-version py310 $(check_dirs) setup.py
|
||||
python -m isort $(check_dirs) setup.py
|
||||
black --line-length 119 --target-version py310 $(check_dirs) setup.py
|
||||
isort $(check_dirs) setup.py
|
||||
|
||||
quality:
|
||||
python -m black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
|
||||
python -m isort --check-only $(check_dirs) setup.py
|
||||
python -m flake8 --max-line-length 119 $(check_dirs) setup.py
|
||||
black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
|
||||
isort --check-only $(check_dirs) setup.py
|
||||
flake8 --max-line-length 119 $(check_dirs) setup.py
|
||||
|
||||
|
||||
# Release stuff
|
||||
|
||||
pre-release:
|
||||
python src/alignment/utils/release.py
|
||||
python src/alignment/release.py
|
||||
|
||||
pre-patch:
|
||||
python src/alignment/utils/release.py --patch
|
||||
python src/alignment/release.py --patch
|
||||
|
||||
post-release:
|
||||
python src/alignment/utils/release.py --post_release
|
||||
python src/alignment/release.py --post_release
|
||||
|
||||
post-patch:
|
||||
python src/alignment/utils/release.py --post_release --patch
|
||||
python src/alignment/release.py --post_release --patch
|
||||
|
||||
wheels:
|
||||
python setup.py bdist_wheel && python setup.py sdist
|
||||
|
||||
@@ -10,6 +10,10 @@ However, we know from the [InstructGPT](https://huggingface.co/papers/2203.02155
|
||||
|
||||
The Alignment Handbook aims to fill that gap by providing the community with a series of robust training recipes that span the whole pipeline.
|
||||
|
||||
## News 🗞️
|
||||
|
||||
* November 10, 2023: We release all the training code to replicate Zephyr-7b-β 🪁!
|
||||
|
||||
## Links 🔗
|
||||
|
||||
* [Zephyr 7B models, datasets, and demos](https://huggingface.co/collections/HuggingFaceH4/zephyr-7b-6538c6d6d5ddd1cbb1744a66)
|
||||
@@ -32,13 +36,20 @@ To run the code in this project, first create a Python virtual environment using
|
||||
conda create -n handbook python=3.10 && conda activate handbook
|
||||
```
|
||||
|
||||
Next, install PyTorch v2.1.0. Since this hardware-dependent, we
|
||||
Next, install PyTorch `v2.1.0` - the precise version is important for reproducibility! Since this is hardware-dependent, we
|
||||
direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/locally/).
|
||||
|
||||
Once PyTorch is installed, you can install the remaining package dependencies as follows:
|
||||
You can then install the remaining package dependencies as follows:
|
||||
|
||||
```shell
|
||||
pip install .
|
||||
python -m pip install .
|
||||
```
|
||||
|
||||
You will also need Flash Attention 2 installed, which can be done by running:
|
||||
_Note: If your machine has less than 96GB of RAM and many CPU cores, reduce the MAX_JOBS., e.g. `MAX_JOBS=4 pip install flash-attn --no-build-isolation` _
|
||||
|
||||
```shell
|
||||
python -m pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Next, log into your Hugging Face account as follows:
|
||||
@@ -53,6 +64,23 @@ Finally, install Git LFS so that you can push models to the Hugging Face Hub:
|
||||
sudo apt-get install git-lfs
|
||||
```
|
||||
|
||||
You can now checkout the `scripts` and `recipes` directories for instructions on how to train some models 🪁!
|
||||
|
||||
## Project structure
|
||||
|
||||
```
|
||||
├── LICENSE
|
||||
├── Makefile <- Makefile with commands like `make style`
|
||||
├── README.md <- The top-level README for developers using this project
|
||||
├── chapters <- Educational content to render on hf.co/learn
|
||||
├── recipes <- Recipe configs, accelerate configs, slurm scripts
|
||||
├── scripts <- Scripts to train and evaluate chat models
|
||||
├── setup.cfg <- Installation config (mostly used for configuring code quality & tests)
|
||||
├── setup.py <- Makes project pip installable (pip install -e .) so `alignment` can be imported
|
||||
├── src <- Source code for use in this project
|
||||
└── tests <- Unit tests
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
If you find the content of this repo useful in your work, please cite it as follows:
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: true
|
||||
zero3_save_16bit_model: true
|
||||
zero_stage: 3
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,16 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,93 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --partition=production-cluster # Adjust this for your cluster
|
||||
#SBATCH --output=/fsx/h4/logs/%x-%j.out # Adjust this for your cluster
|
||||
#SBATCH --err=/fsx/h4/logs/%x-%j.err # Adjust this for your cluster
|
||||
|
||||
set -x -e
|
||||
|
||||
source ~/.bashrc
|
||||
conda activate handbook
|
||||
echo "START TIME: $(date)"
|
||||
|
||||
MODEL=$1
|
||||
TASK=$2
|
||||
PRECISION=$3
|
||||
ACCELERATOR=$4
|
||||
OPTIONAL_ARGS=$5
|
||||
|
||||
# Training setup
|
||||
NUM_NODES=$SLURM_NNODES
|
||||
GPUS_PER_NODE=8
|
||||
WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))
|
||||
# Due to conflicts between Accelerate's DeepSpeed configs and Transformers' TrainingArguments, we need to parse the gradient accumulation steps from the config file to ensure they match
|
||||
CONFIG_FILE=recipes/$MODEL/$TASK/config_$PRECISION.yaml
|
||||
GRAD_ACC_STEPS=$(grep 'gradient_accumulation_steps' $CONFIG_FILE | awk '{print $2}')
|
||||
|
||||
# Split the string into individual arguments
|
||||
IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS"
|
||||
|
||||
# Loop through the arguments and find the one with "--gradient_accumulation_steps"
|
||||
for arg in "${ARGS[@]}"; do
|
||||
if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then
|
||||
# Extract the value after the equals sign
|
||||
GRAD_ACC_STEPS="${arg#*=}"
|
||||
break # Exit the loop once we find the desired argument
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Gradient accumulation steps: $GRAD_ACC_STEPS"
|
||||
# so processes know who to talk to
|
||||
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
|
||||
MASTER_PORT=6000
|
||||
|
||||
export CMD=" \
|
||||
scripts/run_$TASK.py $CONFIG_FILE $OPTIONAL_ARGS
|
||||
"
|
||||
|
||||
export LAUNCHER="ACCELERATE_LOG_LEVEL=info accelerate launch \
|
||||
--config_file recipes/accelerate_configs/$ACCELERATOR.yaml \
|
||||
--gradient_accumulation_steps $GRAD_ACC_STEPS \
|
||||
--num_machines $NUM_NODES \
|
||||
--num_processes $WORLD_SIZE \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port $MASTER_PORT \
|
||||
--machine_rank \$SLURM_PROCID \
|
||||
--rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
|
||||
--max_restarts 1 \
|
||||
--role \$(hostname -s): \
|
||||
--tee 3 \
|
||||
"
|
||||
|
||||
# force crashing on nccl issues like hanging broadcast
|
||||
export NCCL_ASYNC_ERROR_HANDLING=1
|
||||
# export NCCL_DEBUG=INFO
|
||||
# export NCCL_DEBUG_SUBSYS=COLL
|
||||
# export NCCL_SOCKET_NTHREADS=1
|
||||
# export NCCL_NSOCKS_PERTHREAD=1
|
||||
# export CUDA_LAUNCH_BLOCKING=1
|
||||
|
||||
# Specific configuration optimized for the Hugging Face Compute Cluster
|
||||
# Be ye warned this may not work on other clusters!
|
||||
export NCCL_PROTO=simple
|
||||
export RDMAV_FORK_SAFE=1
|
||||
export FI_EFA_FORK_SAFE=1
|
||||
export FI_EFA_USE_DEVICE_RDMA=1
|
||||
export FI_PROVIDER=efa
|
||||
export FI_LOG_LEVEL=1
|
||||
export NCCL_IB_DISABLE=1
|
||||
export NCCL_SOCKET_IFNAME=ens
|
||||
|
||||
# srun error handling:
|
||||
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
|
||||
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
|
||||
SRUN_ARGS=" \
|
||||
--wait=60 \
|
||||
--kill-on-bad-exit=1 \
|
||||
"
|
||||
|
||||
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1
|
||||
|
||||
echo "END TIME: $(date)"
|
||||
@@ -0,0 +1,29 @@
|
||||
|
||||
# Instructions to Replicate Zephyr 7B
|
||||
|
||||
As described in the Zephyr [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps:
|
||||
|
||||
1. Apply SFT to fine-tune Mistral 7B on a filtered version of the UltraChat dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)).
|
||||
2. Align the SFT model to AI feedback via DPO on a preprocessed version of the UltraFeedback dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)).
|
||||
|
||||
See below for commands to train these models using either DeepSpeed ZeRO-3 or LoRA.
|
||||
|
||||
## Full training examples
|
||||
|
||||
```shell
|
||||
# Step 1 - SFT
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_full.yaml
|
||||
|
||||
# Step 2 - DPO
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-beta/beta/beta/beta/dpo/config_full.yaml
|
||||
```
|
||||
|
||||
## LoRA training examples
|
||||
|
||||
```shell
|
||||
# Step 1 - SFT
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/beta/sft/config_lora.yaml
|
||||
|
||||
# Step 2 - DPO
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_lora.yaml
|
||||
```
|
||||
@@ -0,0 +1,37 @@
|
||||
# Model arguments
|
||||
model_name_or_path: alignment-handbook/zephyr-7b-sft-full
|
||||
|
||||
# Data training arguments
|
||||
# For definitions, see: src/h4/training/config.py
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 1.0
|
||||
dataset_splits:
|
||||
- train_prefs
|
||||
- test_prefs
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# DPOTrainer arguments
|
||||
bf16: true
|
||||
beta: 0.1
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_checkpointing: true
|
||||
hub_model_id: zephyr-7b-dpo-full
|
||||
learning_rate: 5.0e-7
|
||||
log_level: info
|
||||
logging_steps: 10
|
||||
lr_scheduler_type: linear
|
||||
max_length: 1024
|
||||
max_prompt_length: 512
|
||||
num_train_epochs: 3
|
||||
optim: rmsprop
|
||||
output_dir: data/zephyr-7b-dpo-full
|
||||
per_device_train_batch_size: 8
|
||||
per_device_eval_batch_size: 4
|
||||
push_to_hub: true
|
||||
save_strategy: "no"
|
||||
save_total_limit: null
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
@@ -0,0 +1,51 @@
|
||||
# Model arguments
|
||||
model_name_or_path: alignment-handbook/zephyr-7b-sft-lora
|
||||
torch_dtype: auto
|
||||
|
||||
# LoRA arguments
|
||||
use_peft: true
|
||||
lora_r: 64
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.1
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
|
||||
# Data training arguments
|
||||
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 1.0
|
||||
dataset_splits:
|
||||
- train_prefs
|
||||
- test_prefs
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# DPOTrainer arguments
|
||||
bf16: true
|
||||
beta: 0.1
|
||||
do_eval: true
|
||||
evaluation_strategy: epoch
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 32
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: zephyr-7b-dpo-lora
|
||||
learning_rate: 5.0e-7
|
||||
log_level: info
|
||||
logging_steps: 10
|
||||
lr_scheduler_type: linear
|
||||
max_length: 1024
|
||||
max_prompt_length: 512
|
||||
num_train_epochs: 3
|
||||
optim: rmsprop
|
||||
output_dir: data/zephyr-7b-dpo-lora # It is handy to append `hub_model_revision` to keep track of your local experiments
|
||||
per_device_train_batch_size: 2
|
||||
per_device_eval_batch_size: 4
|
||||
push_to_hub: true
|
||||
save_strategy: "no"
|
||||
save_total_limit: null
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
@@ -0,0 +1,42 @@
|
||||
# Model arguments
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
|
||||
# Data training arguments
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrachat_200k: 1.0
|
||||
dataset_splits:
|
||||
- train_sft
|
||||
- test_sft
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: true
|
||||
evaluation_strategy: epoch
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_checkpointing: true
|
||||
hub_model_id: zephyr-7b-sft-full
|
||||
hub_strategy: every_save
|
||||
learning_rate: 2.0e-05
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_seq_length: 2048
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
output_dir: data/zephyr-7b-sft-full
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 16
|
||||
per_device_train_batch_size: 32
|
||||
push_to_hub: true
|
||||
remove_unused_columns: true
|
||||
report_to:
|
||||
- tensorboard
|
||||
save_strategy: "no"
|
||||
save_total_limit: null
|
||||
seed: 42
|
||||
tf32: true
|
||||
@@ -0,0 +1,52 @@
|
||||
# Model arguments
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
torch_dtype: auto
|
||||
use_flash_attention_2: true
|
||||
|
||||
# LoRA arguments
|
||||
use_peft: true
|
||||
lora_r: 64
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.1
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
|
||||
# Data training arguments
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrachat_200k: 1.0
|
||||
dataset_splits:
|
||||
- train_sft
|
||||
- test_sft
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: true
|
||||
evaluation_strategy: epoch
|
||||
gradient_accumulation_steps: 128
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: zephyr-7b-sft-lora
|
||||
hub_strategy: every_save
|
||||
learning_rate: 2.0e-05
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_seq_length: 2048
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
output_dir: data/zephyr-7b-sft-lora
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 8
|
||||
per_device_train_batch_size: 4
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- tensorboard
|
||||
save_strategy: "no"
|
||||
save_total_limit: null
|
||||
seed: 42
|
||||
@@ -0,0 +1,105 @@
|
||||
|
||||
# Scripts to Train and Evaluate Chat Models
|
||||
|
||||
## Fine-tuning
|
||||
|
||||
In the handbook, we provide three main ways to align LLMs for chat:
|
||||
|
||||
- Full fine-tuning on a multi-GPU machine with DeepSpeed ZeRO-3 (tested on an 8 x A100 (80GB) node).
|
||||
- LoRA or QLoRA fine-tuning on a single consumer 24GB GPU (tested on a RTX 4090).
|
||||
- LoRA fine-tuning on a multi-GPU machine with DeepSpeed ZeRO-3 (tested on a 2 x A100s (80GB)).
|
||||
|
||||
In practice, we find comparable performance for both full and LoRA fine-tuning, with the latter having the advantage of producing small adapter weights that are fast to upload and download from the Hugging Face Hub. Here's the two general commands to fine-tune your models:
|
||||
|
||||
```shell
|
||||
# Full training with ZeRO-3 on 8 GPUs
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml
|
||||
|
||||
# LoRA training on a single GPU
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_{task}.py recipes/{model_name}/{task}/config_lora.yaml
|
||||
|
||||
# QLoRA 4-bit training on a single GPU
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_{task}.py recipes/{model_name}/{task}/config_lora.yaml --load_in_4bit=true
|
||||
|
||||
# LoRA training with ZeRO-3 on two or more GPUs
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_lora.yaml
|
||||
```
|
||||
|
||||
Here `{task}` refers to type of training you wish to run (SFT, DPO, etc), while `{model_name}` refers to the choice of recipe in the `recipes` directory. For example, to replicate Zephyr-7B-β you can run:
|
||||
|
||||
```shell
|
||||
# Step 1 - train SFT policy
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_full.yaml
|
||||
|
||||
# Step 2 - align with DPO
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_full.yaml
|
||||
```
|
||||
|
||||
** 💡 Tip:** If you scale up/down the number of GPUs, we recommend also scaling up the per-device batch size or number of gradient accumulation steps to keep the global batch size constant (and thus replicate our results).
|
||||
|
||||
By default, these scripts will push each model to your Hugging Face Hub username, i.e. `{username}/{model_name}-{task}`. You can override the parameters in each YAML config by appending them to the command as follows:
|
||||
|
||||
```shell
|
||||
# Change batch size, number of epochs etc
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml --per_device_train_batch_size=42 --num_train_epochs=5
|
||||
```
|
||||
|
||||
By default all training metrics are logged with TensorBoard. If you have a [Weights and Biases](https://wandb.ai/site) account and are logged in, you can view the training metrics by appending `--report_to=wandb`, e.g.
|
||||
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml --report_to=wandb
|
||||
```
|
||||
|
||||
## Launching jobs on a Slurm cluster
|
||||
|
||||
If you have access to a Slurm cluster, we provide a `recipes/launch.slurm` script that will automatically queue training jobs for you. Here's how you can use it:
|
||||
|
||||
```shell
|
||||
sbatch --job-name=handbook_{task} --nodes=1 recipes/launch.slurm {model_name} {task} {precision} {accelerator}
|
||||
```
|
||||
|
||||
Here `{model_name}` and `{task}` are defined as above, while `{precision}` refers to the type of training (full vs LoRA) and `{accelerator}` refers to the choice of 🤗 Accelerate config in `recipes/accelerate_configs`. If you wish to override the default config parameters, you can provide them by appending a space-separated string like `'--arg1=value1 --arg2=value2'. Here's a concrete example to run SFT on 1 node of 8 GPUs:
|
||||
|
||||
```shell
|
||||
# Launch on Slurm and override default hyperparameters
|
||||
sbatch --job-name=handbook_sft --nodes=1 recipes/launch.slurm zephyr-7b-beta sft full deepspeed_zero3 '--per_device_train_batch_size=42 --num_train_epochs=5'
|
||||
```
|
||||
|
||||
You can scale the number of nodes by increasing the `--nodes` flag.
|
||||
|
||||
**⚠️ Note:** the configuration in `recipes/launch.slurm` is optimised for the Hugging Face Compute Cluster and may require tweaking to be adapted to your own compute nodes.
|
||||
|
||||
## Fine-tuning on your datasets
|
||||
|
||||
Under the hood, each training script uses the `get_datasets()` function which allows one to easily combing multiple datasets with varying proportions. For instance, this is how one can specify multiple datasets and which splits to combine in one of the YAML configs:
|
||||
|
||||
```yaml
|
||||
datasets_mixer:
|
||||
dataset_1: 0.5 # Use 50% of the training examples
|
||||
dataset_2: 0.66 # Use 66% of the training examples
|
||||
dataset_3: 0.10 # Use 10% of the training examples
|
||||
dataset_splits:
|
||||
- train_xxx # The training splits to mix
|
||||
- test_xxx # The test splits to mix
|
||||
```
|
||||
|
||||
If you want to fine-tune on your own datasets, the main thing to keep in mind is how the chat templates are applied to the dataset blend. Since each task (SFT, DPO, etc), requires a different format, we assume the datasets have the following columns:
|
||||
|
||||
**SFT**
|
||||
|
||||
* `messages`: A list of `dicts` in the form `{"role": "{role}", "content": {content}}`.
|
||||
* See [ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) for an example.
|
||||
|
||||
**DPO**
|
||||
|
||||
* `chosen`: A list of `dicts` in the form `{"role": "{role}", "content": {content}}` corresponding to the preferred dialogue.
|
||||
* `rejected`: A list of `dicts` in the form `{"role": "{role}", "content": {content}}` corresponding to the dispreferred dialogue.
|
||||
* See [ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) for an example.
|
||||
|
||||
We also find it useful to include dedicated splits per task in our datasets, so e.g. we have:
|
||||
|
||||
* `{train,test}_sft`: Splits for SFT training.
|
||||
* `{train,test}_gen`: Splits for generation ranking like rejection sampling or PPO.
|
||||
* `{train,test}_prefs`: Splits for preference modelling, like reward modelling or DPO.
|
||||
|
||||
If you format your dataset in the same way, our training scripts should work out of the box!
|
||||
@@ -0,0 +1,224 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM, set_seed
|
||||
|
||||
from accelerate import Accelerator
|
||||
from alignment import (
|
||||
DataArguments,
|
||||
DPOConfig,
|
||||
H4ArgumentParser,
|
||||
ModelArguments,
|
||||
apply_chat_template,
|
||||
get_datasets,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
get_tokenizer,
|
||||
is_adapter_model,
|
||||
)
|
||||
from peft import PeftConfig, PeftModel
|
||||
from trl import DPOTrainer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig))
|
||||
model_args, data_args, training_args = parser.parse()
|
||||
|
||||
#######
|
||||
# Setup
|
||||
#######
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info(f"Model parameters {model_args}")
|
||||
logger.info(f"Data parameters {data_args}")
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set seed for reproducibility
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Increase distributed timeout to 3h to enable push to Hub to complete
|
||||
accelerator = Accelerator()
|
||||
|
||||
###############
|
||||
# Load datasets
|
||||
###############
|
||||
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
|
||||
logger.info(
|
||||
f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
|
||||
)
|
||||
column_names = list(raw_datasets["train"].features)
|
||||
|
||||
#####################################
|
||||
# Load tokenizer and process datasets
|
||||
#####################################
|
||||
data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn
|
||||
tokenizer = get_tokenizer(model_args, data_args)
|
||||
|
||||
#####################
|
||||
# Apply chat template
|
||||
#####################
|
||||
raw_datasets = raw_datasets.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={"tokenizer": tokenizer, "task": "dpo"},
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Formatting comparisons with prompt template",
|
||||
)
|
||||
|
||||
# Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
|
||||
for split in ["train", "test"]:
|
||||
raw_datasets[split] = raw_datasets[split].rename_columns(
|
||||
{"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
|
||||
)
|
||||
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map(),
|
||||
quantization_config=get_quantization_config(model_args),
|
||||
)
|
||||
|
||||
model = model_args.model_name_or_path
|
||||
if is_adapter_model(model, model_args.model_revision):
|
||||
# load the model, merge the adapter weights and unload the adapter
|
||||
# Note: to run QLora, you will need to merge the based model separately as the merged model in 16bit
|
||||
logger.info(f"Merging peft adapters for {model_args.model_name_or_path=}")
|
||||
|
||||
peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
|
||||
|
||||
model_kwargs = dict(
|
||||
revision=model_args.base_model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
)
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
peft_config.base_model_name_or_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model, model_args.model_name_or_path, revision=model_args.model_revision
|
||||
)
|
||||
model.eval()
|
||||
model = model.merge_and_unload()
|
||||
model_kwargs = None
|
||||
|
||||
ref_model = model
|
||||
ref_model_kwargs = model_kwargs
|
||||
|
||||
if model_args.use_peft is True:
|
||||
ref_model = None
|
||||
ref_model_kwargs = None
|
||||
|
||||
#########################
|
||||
# Instantiate DPO trainer
|
||||
#########################
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
ref_model,
|
||||
model_init_kwargs=model_kwargs,
|
||||
ref_model_init_kwargs=ref_model_kwargs,
|
||||
args=training_args,
|
||||
beta=training_args.beta,
|
||||
train_dataset=raw_datasets["train"],
|
||||
eval_dataset=raw_datasets["test"],
|
||||
tokenizer=tokenizer,
|
||||
max_length=training_args.max_length,
|
||||
max_prompt_length=training_args.max_prompt_length,
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
###############
|
||||
# Training loop
|
||||
###############
|
||||
train_result = dpo_trainer.train()
|
||||
metrics = train_result.metrics
|
||||
max_train_samples = (
|
||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(raw_datasets["train"])
|
||||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(raw_datasets["train"]))
|
||||
dpo_trainer.log_metrics("train", metrics)
|
||||
dpo_trainer.save_metrics("train", metrics)
|
||||
dpo_trainer.save_state()
|
||||
|
||||
logger.info("*** Training complete ***")
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = dpo_trainer.evaluate()
|
||||
max_eval_samples = (
|
||||
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(raw_datasets["test"])
|
||||
)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(raw_datasets["test"]))
|
||||
dpo_trainer.log_metrics("eval", metrics)
|
||||
dpo_trainer.save_metrics("eval", metrics)
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
dpo_trainer.save_model(training_args.output_dir)
|
||||
# Save everything else on main process
|
||||
if accelerator.is_main_process:
|
||||
kwargs = {
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": list(data_args.dataset_mixer.keys()),
|
||||
"dataset_tags": list(data_args.dataset_mixer.keys()),
|
||||
"tags": ["alignment-handbook"],
|
||||
}
|
||||
dpo_trainer.create_model_card(**kwargs)
|
||||
# Restore k,v cache for fast inference
|
||||
dpo_trainer.model.config.use_cache = True
|
||||
dpo_trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
if training_args.push_to_hub is True:
|
||||
dpo_trainer.push_to_hub()
|
||||
|
||||
# Ensure we don't timeout on model save / push to Hub
|
||||
logger.info("*** Waiting for all processes to finish ***")
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
logger.info("*** Run complete! ***")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,191 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Supervised fine-tuning script for decoder language models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import set_seed
|
||||
|
||||
from accelerate import Accelerator
|
||||
from alignment import (
|
||||
DataArguments,
|
||||
H4ArgumentParser,
|
||||
ModelArguments,
|
||||
SFTConfig,
|
||||
apply_chat_template,
|
||||
get_datasets,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
get_tokenizer,
|
||||
)
|
||||
from trl import SFTTrainer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
|
||||
model_args, data_args, training_args = parser.parse()
|
||||
|
||||
# Set seed for reproducibility
|
||||
set_seed(training_args.seed)
|
||||
|
||||
accelerator = Accelerator()
|
||||
|
||||
###############
|
||||
# Setup logging
|
||||
###############
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process a small summary
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Model parameters {model_args}")
|
||||
logger.info(f"Data parameters {data_args}")
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
###############
|
||||
# Load datasets
|
||||
###############
|
||||
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
|
||||
logger.info(
|
||||
f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
|
||||
)
|
||||
|
||||
################
|
||||
# Load tokenizer
|
||||
################
|
||||
tokenizer = get_tokenizer(model_args, data_args)
|
||||
|
||||
#####################
|
||||
# Apply chat template
|
||||
#####################
|
||||
raw_datasets = raw_datasets.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer, "task": "sft"})
|
||||
train_dataset = raw_datasets["train"]
|
||||
eval_dataset = raw_datasets["test"]
|
||||
|
||||
with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
|
||||
for index in random.sample(range(len(raw_datasets["train"])), 3):
|
||||
logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
|
||||
|
||||
#######################
|
||||
# Load pretrained model
|
||||
#######################
|
||||
logger.info("*** Load pretrained model ***")
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map(),
|
||||
quantization_config=get_quantization_config(model_args),
|
||||
)
|
||||
logger.info("*** Model loaded! ***")
|
||||
|
||||
########################
|
||||
# Initialize the Trainer
|
||||
########################
|
||||
trainer = SFTTrainer(
|
||||
model=model_args.model_name_or_path,
|
||||
model_init_kwargs=model_kwargs,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=training_args.max_seq_length,
|
||||
tokenizer=tokenizer,
|
||||
packing=True,
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
###############
|
||||
# Training loop
|
||||
###############
|
||||
logger.info("*** Train ***")
|
||||
train_result = trainer.train()
|
||||
metrics = train_result.metrics
|
||||
max_train_samples = data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
logger.info("*** Save model ***")
|
||||
trainer.save_model(training_args.output_dir)
|
||||
logger.info(f"Model saved to {training_args.output_dir}")
|
||||
|
||||
# Save everything else on main process
|
||||
if accelerator.is_main_process:
|
||||
kwargs = {
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": list(data_args.dataset_mixer.keys()),
|
||||
"dataset_tags": list(data_args.dataset_mixer.keys()),
|
||||
"tags": ["alignment-handbook"],
|
||||
}
|
||||
trainer.create_model_card(**kwargs)
|
||||
# Restore k,v cache for fast inference
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
|
||||
if training_args.push_to_hub is True:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -42,28 +42,30 @@ if stale_egg_info.exists():
|
||||
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
|
||||
_deps = [
|
||||
"accelerate==0.23.0",
|
||||
"bitsandbytes==0.41.1",
|
||||
"bitsandbytes==0.41.2.post2",
|
||||
"black==23.1.0",
|
||||
"datasets==2.12.0",
|
||||
"deepspeed==0.9.5",
|
||||
"einops==0.6.1",
|
||||
"datasets==2.14.6",
|
||||
"deepspeed==0.12.2",
|
||||
"einops>=0.6.1",
|
||||
"evaluate==0.4.0",
|
||||
"flake8>=6.0.0",
|
||||
"hf-doc-builder>=0.4.0",
|
||||
"huggingface-hub>=0.14.1,<1.0",
|
||||
"isort>=5.12.0",
|
||||
"ninja==1.11.1",
|
||||
"ninja>=1.11.1",
|
||||
"numpy>=1.24.2",
|
||||
"packaging>=23.0",
|
||||
"parameterized>=0.9.0",
|
||||
"peft==0.5.0",
|
||||
"peft==0.6.1",
|
||||
"protobuf<=3.20.2", # Needed to avoid conflicts with `transformers`
|
||||
"pytest",
|
||||
"safetensors==0.3.3",
|
||||
"safetensors>=0.3.3",
|
||||
"scipy",
|
||||
"tensorboard",
|
||||
"torch==2.0.1",
|
||||
"transformers @ git+https://github.com/huggingface/transformers.git@b3961f7291307ee877ef1a4d057949597d805220",
|
||||
"trl @ git+https://github.com/huggingface/trl.git@1e56ff0f166888973d69cd9d56be60a9f8edfedb", # TODO bump to next release, added for NEFTune
|
||||
"torch==2.1.0",
|
||||
"transformers==4.35.0",
|
||||
"trl==0.7.4",
|
||||
"jinja2>=3.0.0",
|
||||
"tqdm>=4.64.1",
|
||||
]
|
||||
|
||||
@@ -96,12 +98,14 @@ install_requires = [
|
||||
deps["datasets"],
|
||||
deps["deepspeed"],
|
||||
deps["huggingface-hub"],
|
||||
deps["jinja2"],
|
||||
deps["ninja"],
|
||||
deps["numpy"],
|
||||
deps["packaging"], # utilities from PyPA to e.g., compare versions
|
||||
deps["peft"],
|
||||
deps["protobuf"],
|
||||
deps["safetensors"],
|
||||
deps["scipy"],
|
||||
deps["tensorboard"],
|
||||
deps["tqdm"], # progress bars in model download and training scripts
|
||||
deps["transformers"],
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
__version__ = "0.2.0.dev0"
|
||||
|
||||
from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig
|
||||
from .data import apply_chat_template, get_datasets
|
||||
from .model_utils import get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer, is_adapter_model
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
# coding=utf-8
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import dataclasses
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, NewType, Optional, Tuple
|
||||
|
||||
import transformers
|
||||
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
|
||||
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
DataClassType = NewType("DataClassType", Any)
|
||||
|
||||
|
||||
class H4ArgumentParser(HfArgumentParser):
|
||||
def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
|
||||
"""
|
||||
Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.
|
||||
|
||||
Args:
|
||||
yaml_arg (`str`):
|
||||
The path to the config file used
|
||||
other_args (`List[str]`, *optional`):
|
||||
A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
|
||||
|
||||
Returns:
|
||||
[`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line
|
||||
"""
|
||||
arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))
|
||||
|
||||
outputs = []
|
||||
# strip other args list into dict of key-value pairs
|
||||
other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args}
|
||||
used_args = {}
|
||||
|
||||
# overwrite the default/loaded value with the value provided to the command line
|
||||
# adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
|
||||
for data_yaml, data_class in zip(arg_list, self.dataclass_types):
|
||||
keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
|
||||
inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
|
||||
for arg, val in other_args.items():
|
||||
# add only if in keys
|
||||
if arg in keys:
|
||||
base_type = data_yaml.__dataclass_fields__[arg].type
|
||||
inputs[arg] = val
|
||||
|
||||
# cast type for ints, floats (default to strings)
|
||||
if base_type in [int, float]:
|
||||
inputs[arg] = base_type(val)
|
||||
|
||||
if base_type == List[str]:
|
||||
inputs[arg] = [str(v) for v in val.split(",")]
|
||||
|
||||
# bool of a non-empty string is True, so we manually check for bools
|
||||
if base_type == bool:
|
||||
if val in ["true", "True"]:
|
||||
inputs[arg] = True
|
||||
else:
|
||||
inputs[arg] = False
|
||||
|
||||
# add to used-args so we can check if double add
|
||||
if arg not in used_args:
|
||||
used_args[arg] = val
|
||||
else:
|
||||
raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")
|
||||
|
||||
obj = data_class(**inputs)
|
||||
outputs.append(obj)
|
||||
|
||||
return outputs
|
||||
|
||||
def parse(self) -> DataClassType | Tuple[DataClassType]:
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
# If we pass only one argument to the script and it's the path to a YAML file,
|
||||
# let's parse it to get our arguments.
|
||||
output = self.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
# parse command line args and yaml file
|
||||
elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
|
||||
output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:])
|
||||
# parse command line args only
|
||||
else:
|
||||
output = self.parse_args_into_dataclasses()
|
||||
|
||||
if len(output) == 1:
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
||||
"""
|
||||
|
||||
base_model_revision: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": ("The base model checkpoint for weights initialization with PEFT adatpers.")},
|
||||
)
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
|
||||
)
|
||||
},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
model_code_revision: str = field(default=None, metadata={"help": "The branch of the IFT model"})
|
||||
torch_dtype: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
||||
"dtype will be automatically derived from the model's weights."
|
||||
),
|
||||
"choices": ["auto", "bfloat16", "float16", "float32"],
|
||||
},
|
||||
)
|
||||
trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
|
||||
use_flash_attention_2: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
"Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`"
|
||||
)
|
||||
},
|
||||
)
|
||||
use_peft: bool = field(
|
||||
default=False,
|
||||
metadata={"help": ("Whether to use PEFT or not for training.")},
|
||||
)
|
||||
lora_r: Optional[int] = field(
|
||||
default=16,
|
||||
metadata={"help": ("LoRA R value.")},
|
||||
)
|
||||
lora_alpha: Optional[int] = field(
|
||||
default=32,
|
||||
metadata={"help": ("LoRA alpha.")},
|
||||
)
|
||||
lora_dropout: Optional[float] = field(
|
||||
default=0.05,
|
||||
metadata={"help": ("LoRA dropout.")},
|
||||
)
|
||||
lora_target_modules: Optional[List[str]] = field(
|
||||
default=None,
|
||||
metadata={"help": ("LoRA target modules.")},
|
||||
)
|
||||
lora_modules_to_save: Optional[List[str]] = field(
|
||||
default=None,
|
||||
metadata={"help": ("Model layers to unfreeze & train")},
|
||||
)
|
||||
load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"})
|
||||
load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"})
|
||||
|
||||
bnb_4bit_quant_type: Optional[str] = field(
|
||||
default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
|
||||
)
|
||||
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
|
||||
|
||||
def __post_init__(self):
|
||||
if self.load_in_8bit and self.load_in_4bit:
|
||||
raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
|
||||
dataset_mixer: Optional[Dict[str, float]] = field(
|
||||
default=None,
|
||||
metadata={"help": ("Datasets and their proportions to be used for training ift/rl.")},
|
||||
)
|
||||
dataset_splits: Optional[List[str]] = field(
|
||||
default_factory=lambda: ["train", "test"],
|
||||
metadata={"help": ("List of train test splits to use in the dataset")},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
truncation_side: Optional[str] = field(
|
||||
default=None, metadata={"help": "Truncation side to use for the tokenizer."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFTConfig(transformers.TrainingArguments):
|
||||
"""
|
||||
Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
|
||||
"""
|
||||
|
||||
max_seq_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
|
||||
)
|
||||
logging_first_step: bool = field(
|
||||
default=True,
|
||||
metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
|
||||
)
|
||||
optim: Optional[str] = field(default="adamw_torch")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DPOConfig(transformers.TrainingArguments):
|
||||
"""
|
||||
Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
|
||||
"""
|
||||
|
||||
beta: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta factor in DPO loss. Higher beta means less divergence from the initial policy."},
|
||||
)
|
||||
hub_model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
metadata={"help": ("The Hub model branch to push the model to.")},
|
||||
)
|
||||
logging_first_step: bool = field(
|
||||
default=True,
|
||||
metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
|
||||
)
|
||||
max_prompt_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": ("For DPO, the maximum length of the prompt to use for conditioning the model.")},
|
||||
)
|
||||
max_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
|
||||
)
|
||||
optim: Optional[str] = field(default="rmsprop")
|
||||
remove_unused_columns: bool = field(default=False)
|
||||
@@ -0,0 +1,187 @@
|
||||
# coding=utf-8
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from datasets import DatasetDict, concatenate_datasets, load_dataset
|
||||
|
||||
from .configs import DataArguments
|
||||
|
||||
|
||||
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n"
|
||||
):
|
||||
def _strip_prefix(s, pattern):
|
||||
# Use re.escape to escape any special characters in the pattern
|
||||
return re.sub(f"^{re.escape(pattern)}", "", s)
|
||||
|
||||
if task in ["sft", "generation"]:
|
||||
messages = example["messages"]
|
||||
# We add an empty system message if there is none
|
||||
if messages[0]["role"] != "system":
|
||||
messages.insert(0, {"role": "system", "content": ""})
|
||||
example["text"] = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True if task == "generation" else False
|
||||
)
|
||||
elif task == "rm":
|
||||
if all(k in example.keys() for k in ("chosen", "rejected")):
|
||||
chosen_messages = example["chosen"]
|
||||
rejected_messages = example["rejected"]
|
||||
# We add an empty system message if there is none
|
||||
if chosen_messages[0]["role"] != "system":
|
||||
chosen_messages.insert(0, {"role": "system", "content": ""})
|
||||
if rejected_messages[0]["role"] != "system":
|
||||
rejected_messages.insert(0, {"role": "system", "content": ""})
|
||||
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
|
||||
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
|
||||
)
|
||||
elif task == "dpo":
|
||||
if all(k in example.keys() for k in ("chosen", "rejected")):
|
||||
# Compared to reward modeling, we filter out the prompt, so the text is everything after the last assistant token
|
||||
prompt_messages = [[msg for msg in example["chosen"] if msg["role"] == "user"][0]]
|
||||
# Insert system message
|
||||
if example["chosen"][0]["role"] != "system":
|
||||
prompt_messages.insert(0, {"role": "system", "content": ""})
|
||||
else:
|
||||
prompt_messages.insert(0, example["chosen"][0])
|
||||
# TODO: handle case where chosen/rejected also have system messages
|
||||
chosen_messages = example["chosen"][1:]
|
||||
rejected_messages = example["rejected"][1:]
|
||||
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
|
||||
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
|
||||
example["text_prompt"] = tokenizer.apply_chat_template(
|
||||
prompt_messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
|
||||
example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
|
||||
)
|
||||
return example
|
||||
|
||||
|
||||
def get_datasets(
|
||||
data_config: DataArguments | dict,
|
||||
splits: List[str] = ["train", "test"],
|
||||
shuffle: bool = True,
|
||||
) -> DatasetDict:
|
||||
"""
|
||||
Loads one or more datasets with varying training set proportions.
|
||||
|
||||
Args:
|
||||
data_config (`DataArguments` or `dict`):
|
||||
Dataset configuration and split proportions.
|
||||
splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
|
||||
Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
|
||||
shuffle (`bool`, *optional*, defaults to `True`):
|
||||
Whether to shuffle the training data.
|
||||
|
||||
Returns
|
||||
[`DatasetDict`]: The dataset dictionary containing the loaded datasets.
|
||||
"""
|
||||
|
||||
if type(data_config) is DataArguments:
|
||||
# Structure of the config to read the datasets and their mix
|
||||
# datasets_mixer:
|
||||
# - 'dataset1': 0.5
|
||||
# - 'dataset2': 0.3
|
||||
# - 'dataset3': 0.2
|
||||
dataset_mixer = data_config.dataset_mixer
|
||||
elif type(data_config) is dict:
|
||||
# Structure of the input is:
|
||||
# dataset_mixer = {
|
||||
# "dataset1": 0.5,
|
||||
# "dataset1": 0.3,
|
||||
# "dataset1": 0.2,
|
||||
# }
|
||||
dataset_mixer = data_config
|
||||
else:
|
||||
raise ValueError(f"Data config {data_config} not recognized.")
|
||||
|
||||
raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
|
||||
return raw_datasets
|
||||
|
||||
|
||||
def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True) -> DatasetDict:
|
||||
"""
|
||||
Loads and mixes datasets according to proportions specified in `dataset_mixer`.
|
||||
|
||||
Args:
|
||||
dataset_mixer (`dict`):
|
||||
Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
|
||||
splits (Optional[List[str]], *optional*, defaults to `None`):
|
||||
Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
|
||||
shuffle (`bool`, *optional*, defaults to `True`):
|
||||
Whether to shuffle the training data.
|
||||
"""
|
||||
raw_datasets = DatasetDict()
|
||||
raw_train_datasets = []
|
||||
raw_val_datasets = []
|
||||
fracs = []
|
||||
for ds, frac in dataset_mixer.items():
|
||||
fracs.append(frac)
|
||||
for split in splits:
|
||||
if "train" in split:
|
||||
raw_train_datasets.append(
|
||||
load_dataset(
|
||||
ds,
|
||||
split=split,
|
||||
)
|
||||
)
|
||||
elif "test" in split:
|
||||
raw_val_datasets.append(
|
||||
load_dataset(
|
||||
ds,
|
||||
split=split,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Split type {split} not recognized as one of test or train.")
|
||||
|
||||
if any(frac < 0 for frac in fracs):
|
||||
raise ValueError("Dataset fractions cannot be negative.")
|
||||
|
||||
if len(raw_train_datasets) > 0:
|
||||
train_subsets = []
|
||||
for dataset, frac in zip(raw_train_datasets, fracs):
|
||||
train_subset = dataset.select(range(int(frac * len(dataset))))
|
||||
train_subsets.append(train_subset)
|
||||
if shuffle:
|
||||
raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42)
|
||||
else:
|
||||
raw_datasets["train"] = concatenate_datasets(train_subsets)
|
||||
# No subsampling for test datasets to enable fair comparison across models
|
||||
if len(raw_val_datasets) > 0:
|
||||
if shuffle:
|
||||
raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42)
|
||||
else:
|
||||
raw_datasets["test"] = concatenate_datasets(raw_val_datasets)
|
||||
|
||||
if len(raw_datasets) == 0:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_mixer} not recognized with split {split}. Check the dataset has been correctly formatted."
|
||||
)
|
||||
|
||||
return raw_datasets
|
||||
@@ -0,0 +1,101 @@
|
||||
# coding=utf-8
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
|
||||
|
||||
from accelerate import Accelerator
|
||||
from huggingface_hub import list_repo_files
|
||||
from peft import LoraConfig, PeftConfig
|
||||
|
||||
from .configs import DataArguments, ModelArguments
|
||||
from .data import DEFAULT_CHAT_TEMPLATE
|
||||
|
||||
|
||||
def get_current_device() -> int:
|
||||
"""Get the current device. For GPU we return the local process index to enable multiple GPU training."""
|
||||
return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def get_kbit_device_map() -> Dict[str, int] | None:
|
||||
"""Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`"""
|
||||
return {"": get_current_device()} if torch.cuda.is_available() else None
|
||||
|
||||
|
||||
def get_quantization_config(model_args) -> BitsAndBytesConfig | None:
|
||||
if model_args.load_in_4bit:
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.float16, # For consistency with model weights, we use the same value as `torch_dtype` which is float16 for PEFT models
|
||||
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
|
||||
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
|
||||
)
|
||||
elif model_args.load_in_8bit:
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
)
|
||||
else:
|
||||
quantization_config = None
|
||||
|
||||
return quantization_config
|
||||
|
||||
|
||||
def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
|
||||
"""Get the tokenizer for the model."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
revision=model_args.model_revision,
|
||||
)
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
if data_args.truncation_side is not None:
|
||||
tokenizer.truncation_side = data_args.truncation_side
|
||||
|
||||
# Set reasonable default for models without max length
|
||||
if tokenizer.model_max_length > 100_000:
|
||||
tokenizer.model_max_length = 2048
|
||||
|
||||
if data_args.chat_template is not None:
|
||||
tokenizer.chat_template = data_args.chat_template
|
||||
elif tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
|
||||
if model_args.use_peft is False:
|
||||
return None
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=model_args.lora_r,
|
||||
lora_alpha=model_args.lora_alpha,
|
||||
lora_dropout=model_args.lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=model_args.lora_target_modules,
|
||||
modules_to_save=model_args.lora_modules_to_save,
|
||||
)
|
||||
|
||||
return peft_config
|
||||
|
||||
|
||||
def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
|
||||
repo_files = list_repo_files(model_name_or_path, revision=revision)
|
||||
return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
|
||||
Vendored
+37
@@ -0,0 +1,37 @@
|
||||
# Model arguments
|
||||
model_name_or_path: alignment-handbook/zephyr-7b-sft-full
|
||||
|
||||
# Data training arguments
|
||||
# For definitions, see: src/h4/training/config.py
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 1.0
|
||||
dataset_splits:
|
||||
- train_prefs
|
||||
- test_prefs
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# DPOTrainer arguments
|
||||
bf16: true
|
||||
beta: 0.1
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_checkpointing: true
|
||||
hub_model_id: zephyr-7b-dpo-full
|
||||
learning_rate: 5.0e-7
|
||||
log_level: info
|
||||
logging_steps: 10
|
||||
lr_scheduler_type: linear
|
||||
max_length: 1024
|
||||
max_prompt_length: 512
|
||||
num_train_epochs: 3
|
||||
optim: rmsprop
|
||||
output_dir: data/zephyr-7b-dpo-full
|
||||
per_device_train_batch_size: 8
|
||||
per_device_eval_batch_size: 4
|
||||
push_to_hub: true
|
||||
save_strategy: "no"
|
||||
save_total_limit: null
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
Vendored
+41
@@ -0,0 +1,41 @@
|
||||
# Model arguments
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
|
||||
# Data training arguments
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrachat_200k: 1.0
|
||||
dataset_splits:
|
||||
- train_sft
|
||||
- test_sft
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: true
|
||||
evaluation_strategy: epoch
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_checkpointing: true
|
||||
hub_model_id: zephyr-7b-sft-full
|
||||
hub_strategy: every_save
|
||||
learning_rate: 2.0e-05
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_seq_length: 2048
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
output_dir: data/zephyr-7b-sft-full
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 16
|
||||
per_device_train_batch_size: 32
|
||||
push_to_hub: true
|
||||
remove_unused_columns: true
|
||||
report_to:
|
||||
- tensorboard
|
||||
save_strategy: "no"
|
||||
save_total_limit: null
|
||||
seed: 42
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from alignment import DataArguments, H4ArgumentParser, ModelArguments, SFTConfig
|
||||
|
||||
|
||||
class H4ArgumentParserTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
|
||||
self.yaml_file_path = "tests/fixtures/config_sft_full.yaml"
|
||||
|
||||
def test_load_yaml(self):
|
||||
model_args, data_args, training_args = self.parser.parse_yaml_file(os.path.abspath(self.yaml_file_path))
|
||||
self.assertEqual(model_args.model_name_or_path, "mistralai/Mistral-7B-v0.1")
|
||||
|
||||
def test_load_yaml_and_args(self):
|
||||
command_line_args = [
|
||||
"--model_name_or_path=test",
|
||||
"--use_peft=true",
|
||||
"--lora_r=16",
|
||||
"--lora_dropout=0.5",
|
||||
]
|
||||
model_args, data_args, training_args = self.parser.parse_yaml_and_args(
|
||||
os.path.abspath(self.yaml_file_path), command_line_args
|
||||
)
|
||||
self.assertEqual(model_args.model_name_or_path, "test")
|
||||
self.assertEqual(model_args.use_peft, True)
|
||||
self.assertEqual(model_args.lora_r, 16)
|
||||
self.assertEqual(model_args.lora_dropout, 0.5)
|
||||
@@ -0,0 +1,148 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
from alignment import DataArguments, ModelArguments, apply_chat_template, get_datasets, get_tokenizer
|
||||
|
||||
|
||||
class GetDatasetsTest(unittest.TestCase):
|
||||
"""Each of these test datasets has 100 examples"""
|
||||
|
||||
def test_loading_data_args(self):
|
||||
dataset_mixer = {
|
||||
"HuggingFaceH4/testing_alpaca_small": 0.5,
|
||||
"HuggingFaceH4/testing_self_instruct_small": 0.3,
|
||||
"HuggingFaceH4/testing_codealpaca_small": 0.2,
|
||||
}
|
||||
data_args = DataArguments(dataset_mixer=dataset_mixer)
|
||||
datasets = get_datasets(data_args)
|
||||
self.assertEqual(len(datasets["train"]), 100)
|
||||
self.assertEqual(len(datasets["test"]), 300)
|
||||
|
||||
def test_loading_data_dict(self):
|
||||
dataset_mixer = {
|
||||
"HuggingFaceH4/testing_alpaca_small": 0.5,
|
||||
"HuggingFaceH4/testing_self_instruct_small": 0.3,
|
||||
"HuggingFaceH4/testing_codealpaca_small": 0.2,
|
||||
}
|
||||
datasets = get_datasets(dataset_mixer)
|
||||
self.assertEqual(len(datasets["train"]), 100)
|
||||
self.assertEqual(len(datasets["test"]), 300)
|
||||
|
||||
def test_loading_with_unit_fractions(self):
|
||||
dataset_mixer = {
|
||||
"HuggingFaceH4/testing_alpaca_small": 1.0,
|
||||
"HuggingFaceH4/testing_self_instruct_small": 1.0,
|
||||
"HuggingFaceH4/testing_codealpaca_small": 1.0,
|
||||
}
|
||||
datasets = get_datasets(dataset_mixer)
|
||||
self.assertEqual(len(datasets["train"]), 300)
|
||||
self.assertEqual(len(datasets["test"]), 300)
|
||||
|
||||
def test_loading_with_fractions_greater_than_unity(self):
|
||||
dataset_mixer = {
|
||||
"HuggingFaceH4/testing_alpaca_small": 0.7,
|
||||
"HuggingFaceH4/testing_self_instruct_small": 0.4,
|
||||
}
|
||||
datasets = get_datasets(dataset_mixer)
|
||||
self.assertEqual(len(datasets["train"]), 70 + 40)
|
||||
self.assertEqual(len(datasets["test"]), 200)
|
||||
|
||||
def test_loading_fails_with_negative_fractions(self):
|
||||
dataset_mixer = {
|
||||
"HuggingFaceH4/testing_alpaca_small": 0.7,
|
||||
"HuggingFaceH4/testing_self_instruct_small": -0.3,
|
||||
}
|
||||
with pytest.raises(ValueError, match=r"Dataset fractions cannot be negative."):
|
||||
get_datasets(dataset_mixer)
|
||||
|
||||
def test_loading_single_split_with_unit_fractions(self):
|
||||
dataset_mixer = {
|
||||
"HuggingFaceH4/testing_alpaca_small": 1.0,
|
||||
}
|
||||
datasets = get_datasets(dataset_mixer, splits=["test"])
|
||||
self.assertEqual(len(datasets["test"]), 100)
|
||||
self.assertRaises(KeyError, lambda: datasets["train"])
|
||||
|
||||
|
||||
class ApplyChatTemplateTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
|
||||
data_args = DataArguments()
|
||||
self.tokenizer = get_tokenizer(model_args, data_args)
|
||||
self.dataset = Dataset.from_dict(
|
||||
{
|
||||
"prompt": ["Hello!"],
|
||||
"messages": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
|
||||
"chosen": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
|
||||
"rejected": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hola!"}]],
|
||||
}
|
||||
)
|
||||
|
||||
def test_sft(self):
|
||||
dataset = self.dataset.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={"tokenizer": self.tokenizer, "task": "sft"},
|
||||
remove_columns=self.dataset.column_names,
|
||||
)
|
||||
self.assertDictEqual(
|
||||
dataset[0],
|
||||
{"text": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n"},
|
||||
)
|
||||
|
||||
def test_generation(self):
|
||||
# Remove last turn from messages
|
||||
dataset = self.dataset.map(lambda x: {"messages": x["messages"][:-1]})
|
||||
dataset = dataset.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={"tokenizer": self.tokenizer, "task": "generation"},
|
||||
remove_columns=self.dataset.column_names,
|
||||
)
|
||||
self.assertDictEqual(
|
||||
dataset[0],
|
||||
{"text": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\n"},
|
||||
)
|
||||
|
||||
def test_rm(self):
|
||||
dataset = self.dataset.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={"tokenizer": self.tokenizer, "task": "rm"},
|
||||
remove_columns=self.dataset.column_names,
|
||||
)
|
||||
self.assertDictEqual(
|
||||
dataset[0],
|
||||
{
|
||||
"text_chosen": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n",
|
||||
"text_rejected": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nHola!</s>\n",
|
||||
},
|
||||
)
|
||||
|
||||
def test_dpo(self):
|
||||
dataset = self.dataset.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={"tokenizer": self.tokenizer, "task": "dpo"},
|
||||
remove_columns=self.dataset.column_names,
|
||||
)
|
||||
self.assertDictEqual(
|
||||
dataset[0],
|
||||
{
|
||||
"text_prompt": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\n",
|
||||
"text_chosen": "Bonjour!</s>\n",
|
||||
"text_rejected": "Hola!</s>\n",
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,76 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer
|
||||
from alignment.data import DEFAULT_CHAT_TEMPLATE
|
||||
|
||||
|
||||
class GetQuantizationConfigTest(unittest.TestCase):
|
||||
def test_4bit(self):
|
||||
model_args = ModelArguments(load_in_4bit=True)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
self.assertTrue(quantization_config.load_in_4bit)
|
||||
self.assertEqual(quantization_config.bnb_4bit_compute_dtype, torch.float16)
|
||||
self.assertEqual(quantization_config.bnb_4bit_quant_type, "nf4")
|
||||
self.assertFalse(quantization_config.bnb_4bit_use_double_quant)
|
||||
|
||||
def test_8bit(self):
|
||||
model_args = ModelArguments(load_in_8bit=True)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
self.assertTrue(quantization_config.load_in_8bit)
|
||||
|
||||
def test_no_quantization(self):
|
||||
model_args = ModelArguments()
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
self.assertIsNone(quantization_config)
|
||||
|
||||
|
||||
class GetTokenizerTest(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
|
||||
|
||||
def test_right_truncation_side(self):
|
||||
tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="right"))
|
||||
self.assertEqual(tokenizer.truncation_side, "right")
|
||||
|
||||
def test_left_truncation_side(self):
|
||||
tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="left"))
|
||||
self.assertEqual(tokenizer.truncation_side, "left")
|
||||
|
||||
def test_default_chat_template(self):
|
||||
tokenizer = get_tokenizer(self.model_args, DataArguments())
|
||||
self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE)
|
||||
|
||||
def test_chatml_chat_template(self):
|
||||
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||
tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))
|
||||
self.assertEqual(tokenizer.chat_template, chat_template)
|
||||
|
||||
|
||||
class GetPeftConfigTest(unittest.TestCase):
|
||||
def test_peft_config(self):
|
||||
model_args = ModelArguments(use_peft=True, lora_r=42, lora_alpha=0.66, lora_dropout=0.99)
|
||||
peft_config = get_peft_config(model_args)
|
||||
self.assertEqual(peft_config.r, 42)
|
||||
self.assertEqual(peft_config.lora_alpha, 0.66)
|
||||
self.assertEqual(peft_config.lora_dropout, 0.99)
|
||||
|
||||
def test_no_peft_config(self):
|
||||
model_args = ModelArguments(use_peft=False)
|
||||
peft_config = get_peft_config(model_args)
|
||||
self.assertIsNone(peft_config)
|
||||
Reference in New Issue
Block a user