From e2e8ab945db9ca680e833e169d7d8ba00923cd33 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 10 Nov 2023 13:38:45 +0000 Subject: [PATCH] Refactor imports --- recipes/zephyr-7b-beta/dpo/config_lora.yaml | 2 +- recipes/zephyr-7b-beta/sft/config_lora.yaml | 2 +- scripts/README.md | 43 +++++++++++++++++++-- scripts/run_dpo.py | 21 +++++----- src/alignment/__init__.py | 2 +- src/alignment/model_utils.py | 6 ++- 6 files changed, 58 insertions(+), 18 deletions(-) diff --git a/recipes/zephyr-7b-beta/dpo/config_lora.yaml b/recipes/zephyr-7b-beta/dpo/config_lora.yaml index 6d04714..afeb8b4 100644 --- a/recipes/zephyr-7b-beta/dpo/config_lora.yaml +++ b/recipes/zephyr-7b-beta/dpo/config_lora.yaml @@ -31,7 +31,7 @@ eval_steps: 100 gradient_accumulation_steps: 32 gradient_checkpointing: true gradient_checkpointing_kwargs: - use_reentrant: False + use_reentrant: false hub_model_id: zephyr-7b-dpo-lora learning_rate: 5.0e-7 log_level: info diff --git a/recipes/zephyr-7b-beta/sft/config_lora.yaml b/recipes/zephyr-7b-beta/sft/config_lora.yaml index 3eb2d0e..286166a 100644 --- a/recipes/zephyr-7b-beta/sft/config_lora.yaml +++ b/recipes/zephyr-7b-beta/sft/config_lora.yaml @@ -29,7 +29,7 @@ evaluation_strategy: epoch gradient_accumulation_steps: 128 gradient_checkpointing: true gradient_checkpointing_kwargs: - use_reentrant: False + use_reentrant: false hub_model_id: zephyr-7b-sft-lora hub_strategy: every_save learning_rate: 2.0e-05 diff --git a/scripts/README.md b/scripts/README.md index 62c2be9..10dc3fb 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -1,7 +1,7 @@ -## Scripts to Train and Evaluate Chat Models +# Scripts to Train and Evaluate Chat Models -### Fine-tuning +## Fine-tuning In the handbook, we provide three main ways to align LLMs for chat: @@ -47,7 +47,7 @@ By default all training metrics are logged with TensorBoard. 
If you have a [Weig ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml --report_to=wandb ``` -### Launching jobs on a Slurm cluster +## Launching jobs on a Slurm cluster If you have access to a Slurm cluster, we provide a `recipes/launch.slurm` script that will automatically queue training jobs for you. Here's how you can use it: @@ -63,4 +63,39 @@ sbatch --job-name=handbook_sft --nodes=1 recipes/launch.slurm zephyr-7b-beta sft You can scale the number of nodes by increasing the `--nodes` flag. -**⚠️ Note:** the configuration in `recipes/launch.slurm` is optimised for the Hugging Face Compute Cluster and may require tweaking to be adapted to your own compute nodes. \ No newline at end of file +**⚠️ Note:** the configuration in `recipes/launch.slurm` is optimised for the Hugging Face Compute Cluster and may require tweaking to be adapted to your own compute nodes. + +## Fine-tuning on custom datasets + +Under the hood, each training script uses the `get_datasets()` function which allows one to easily combine multiple datasets with varying proportions. For instance, this is how one can specify multiple datasets and which splits to combine in one of the YAML configs: + +```yaml +datasets_mixer: + dataset_1: 0.5 # Use 50% of the training examples + dataset_2: 0.66 # Use 66% of the training examples + dataset_3: 0.10 # Use 10% of the training examples +dataset_splits: +- train_x # Samples from each train split +- test_x # Test splits aren't sampled +``` + +If you want to fine-tune on your own datasets, the main thing to keep in mind is how the chat templates are applied to the dataset blend. Since each task (SFT, DPO, etc.) requires a different format, we assume the datasets have the following columns: + +**SFT** + +* `messages`: A list of `dicts` in the form `{"role": "{role}", "content": {content}}`. 
+* See [ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) for an example. + +**DPO** + +* `chosen`: A list of `dicts` in the form `{"role": "{role}", "content": {content}}` corresponding to the preferred dialogue. +* `rejected`: A list of `dicts` in the form `{"role": "{role}", "content": {content}}` corresponding to the dispreferred dialogue. +* See [ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) for an example. + +We also find it useful to include dedicated splits per task in our datasets, so e.g. we have: + +* `{train,test}_sft`: Splits for SFT training. +* `{train,test}_gen`: Splits for generation ranking like rejection sampling or PPO. +* `{train,test}_prefs`: Splits for preference modelling, like reward modelling or DPO. + +If you format your dataset in the same way, our training scripts should work out of the box! \ No newline at end of file diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py index 2a428ac..11f9a2f 100644 --- a/scripts/run_dpo.py +++ b/scripts/run_dpo.py @@ -18,7 +18,7 @@ import sys import torch import transformers -from transformers import set_seed +from transformers import AutoModelForCausalLM, set_seed from accelerate import Accelerator from alignment import ( @@ -32,11 +32,11 @@ from alignment import ( get_peft_config, get_quantization_config, get_tokenizer, + is_adapter_model, ) -from trl import DPOTrainer -from transformers import AutoModelForCausalLM -from alignment.model_utils import is_adapter_model from peft import PeftConfig, PeftModel +from trl import DPOTrainer + logger = logging.getLogger(__name__) @@ -114,15 +114,15 @@ def main(): device_map=get_kbit_device_map(), quantization_config=get_quantization_config(model_args), ) - + model = model_args.model_name_or_path if is_adapter_model(model, model_args.model_revision): # load the model, merge the adapter weights and unload the adapter # Note: to run QLora, you will need to merge the based model 
separately as the merged model in 16bit logger.info(f"Merging peft adapters for {model_args.model_name_or_path=}") - + peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision) - + model_kwargs = dict( revision=model_args.base_model_revision, trust_remote_code=model_args.trust_remote_code, @@ -131,9 +131,12 @@ def main(): use_cache=False if training_args.gradient_checkpointing else True, ) base_model = AutoModelForCausalLM.from_pretrained( - peft_config.base_model_name_or_path, **model_kwargs, + peft_config.base_model_name_or_path, + **model_kwargs, + ) + model = PeftModel.from_pretrained( + base_model, model_args.model_name_or_path, revision=model_args.model_revision ) - model = PeftModel.from_pretrained(base_model, model_args.model_name_or_path, revision=model_args.model_revision) model.eval() model = model.merge_and_unload() model_kwargs = None diff --git a/src/alignment/__init__.py b/src/alignment/__init__.py index 3080b6a..17f4767 100644 --- a/src/alignment/__init__.py +++ b/src/alignment/__init__.py @@ -2,4 +2,4 @@ __version__ = "0.2.0.dev0" from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig from .data import apply_chat_template, get_datasets -from .model_utils import get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer +from .model_utils import get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer, is_adapter_model diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index d35e037..cbaad69 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -4,8 +4,9 @@ import torch from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer from accelerate import Accelerator -from peft import LoraConfig, PeftConfig from huggingface_hub import list_repo_files +from peft import LoraConfig, PeftConfig + from .configs import DataArguments, ModelArguments from .data import 
DEFAULT_CHAT_TEMPLATE @@ -78,6 +79,7 @@ def get_peft_config(model_args: ModelArguments) -> PeftConfig | None: return peft_config + def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool: repo_files = list_repo_files(model_name_or_path, revision=revision) - return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files \ No newline at end of file + return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files