From ff618a4d13a2c77cf97479fac8af2c576619062a Mon Sep 17 00:00:00 2001 From: lewtun Date: Fri, 1 Mar 2024 17:29:42 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=81=20(#129)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Gemma 7B recipe * Use Gemma template * Make it work for dolly lol * Enable cahce * Clean up * DPO to the max * DPO, DPO, DPO * Add openhermes * Add custom configs * Add kwargs * Fix config * Bump deps * Move old recipes * Add doc * Add norte * Renable cache * Nuke * Clean * Apply suggestions from code review Co-authored-by: Alvaro Bartolome * Fix isort * Update README.md * Update config_full.yaml --------- Co-authored-by: Alvaro Bartolome Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com> --- README.md | 1 + recipes/zephyr-7b-gemma/README.md | 21 +++++++++ recipes/zephyr-7b-gemma/dpo/config_full.yaml | 42 +++++++++++++++++ recipes/zephyr-7b-gemma/sft/config_full.yaml | 48 ++++++++++++++++++++ scripts/run_dpo.py | 20 ++++---- scripts/run_sft.py | 22 ++++----- setup.py | 4 +- src/alignment/configs.py | 11 +++++ src/alignment/data.py | 2 +- src/alignment/model_utils.py | 4 +- 10 files changed, 150 insertions(+), 25 deletions(-) create mode 100644 recipes/zephyr-7b-gemma/README.md create mode 100644 recipes/zephyr-7b-gemma/dpo/config_full.yaml create mode 100644 recipes/zephyr-7b-gemma/sft/config_full.yaml diff --git a/README.md b/README.md index e9810a6..9d65808 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ However, we know from the [InstructGPT](https://huggingface.co/papers/2203.02155 The Alignment Handbook aims to fill that gap by providing the community with a series of robust training recipes that span the whole pipeline. ## News 🗞️ +* **March 1, 2024:** We release Zephyr 7B Gemma, which is a new recipe to align Gemma 7B with RLAIF 🔥 * **February 1, 2024:** We release a recipe to align open LLMs with Constitutional AI 📜! See the [recipe](https://github.com/huggingface/alignment-handbook/tree/main/recipes/constitutional-ai) and the [blog post](https://huggingface.co/blog/constitutional_ai) for details. * **January 18, 2024:** We release a suite of evaluations of DPO vs KTO vs IPO, see the [recipe](recipes/pref_align_scan/README.md) and the [blog post](https://huggingface.co/blog/pref-tuning) for details. * **November 10, 2023:** We release all the training code to replicate Zephyr-7b-β 🪁! We also release [No Robots](https://huggingface.co/datasets/HuggingFaceH4/no_robots), a brand new dataset of 10,000 instructions and demonstrations written entirely by skilled human annotators. diff --git a/recipes/zephyr-7b-gemma/README.md b/recipes/zephyr-7b-gemma/README.md new file mode 100644 index 0000000..416462e --- /dev/null +++ b/recipes/zephyr-7b-gemma/README.md @@ -0,0 +1,21 @@ + +# Instructions to Replicate Zephyr 7B Gemma + +Similar to how we trained Zephyr 7B Beta in our [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps: + +1. Apply SFT to fine-tune Gemma 7B on the Deita 10k dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/deita-10k-v0-sft)). The result is an SFT model like [`zephyr-7b-gemma-sft`](https://huggingface.co/HuggingFaceH4/zephyr-7b-gemma-sft-v0.1). +2. Align the SFT model to AI feedback via DPO on a curated mix of 7k examples by Argilla ([link](https://huggingface.co/datasets/argilla/dpo-mix-7k)). The result is a DPO model like [`zephyr-7b-gemma`](HuggingFaceH4/zephyr-7b-gemma-v0.1). + +See below for commands to train these models using either DeepSpeed ZeRO-3 or LoRA. + +## Full training examples + +You will require 8 GPUs (80GB of VRAM) to train the full model - alternatively, you can train on 1 GPU by adjusting the micro batch size and gradient accumulation steps to keep the global batch size constant. A recipe involving QLoRA will come later 🤗. + +```shell +# Step 1 - SFT +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-gemma/sft/config_full.yaml + +# Step 2 - DPO +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-gemma/dpo/config_full.yaml +``` diff --git a/recipes/zephyr-7b-gemma/dpo/config_full.yaml b/recipes/zephyr-7b-gemma/dpo/config_full.yaml new file mode 100644 index 0000000..65b6b96 --- /dev/null +++ b/recipes/zephyr-7b-gemma/dpo/config_full.yaml @@ -0,0 +1,42 @@ +# Model arguments +model_name_or_path: HuggingFaceH4/zephyr-7b-gemma-sft-v0.1 +torch_dtype: bfloat16 + +# Data training arguments +# For definitions, see: src/h4/training/config.py +dataset_mixer: + argilla/dpo-mix-7k: 1.0 +dataset_splits: +- train +- test +preprocessing_num_workers: 12 + +# DPOTrainer arguments +bf16: true +beta: 0.05 +do_eval: true +evaluation_strategy: steps +eval_steps: 100 +gradient_accumulation_steps: 8 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: False +hub_model_id: zephyr-7b-gemma-dpo +learning_rate: 5.0e-7 +log_level: info +logging_steps: 10 +lr_scheduler_type: cosine +max_length: 1024 +max_prompt_length: 512 +num_train_epochs: 2 +optim: adamw_torch +output_dir: data/zephyr-7b-gemma-dpo +per_device_train_batch_size: 2 +per_device_eval_batch_size: 4 +push_to_hub: true +report_to: +- tensorboard +- wandb +save_strategy: "no" +seed: 42 +warmup_ratio: 0.1 diff --git a/recipes/zephyr-7b-gemma/sft/config_full.yaml b/recipes/zephyr-7b-gemma/sft/config_full.yaml new file mode 100644 index 0000000..5ac239b --- /dev/null +++ b/recipes/zephyr-7b-gemma/sft/config_full.yaml @@ -0,0 +1,48 @@ +# Model arguments +model_name_or_path: google/gemma-7b +model_revision: main +tokenizer_name_or_path: philschmid/gemma-tokenizer-chatml # Custom tokenizer with <|im_start|> and <|im_end|> tokens +torch_dtype: bfloat16 +use_flash_attention_2: true + +# Data training arguments +dataset_mixer: + HuggingFaceH4/deita-10k-v0-sft: 1.0 +dataset_splits: +- train_sft +- test_sft +preprocessing_num_workers: 12 + +# SFT trainer config +bf16: true +dataset_kwargs: + add_special_tokens: false # We already wrap and in the chat template + append_concat_token: false # No need to add across samples +do_eval: true +evaluation_strategy: epoch +gradient_accumulation_steps: 4 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +hub_model_id: zephyr-7b-gemma-sft +hub_strategy: every_save +learning_rate: 2.0e-05 +log_level: info +logging_steps: 5 +logging_strategy: steps +lr_scheduler_type: cosine +max_seq_length: 2048 +max_steps: -1 +num_train_epochs: 3 +output_dir: data/zephyr-7b-gemma-sft +overwrite_output_dir: true +per_device_eval_batch_size: 4 +per_device_train_batch_size: 4 +push_to_hub: true +remove_unused_columns: true +report_to: +- tensorboard +- wandb +save_strategy: "no" +seed: 42 +warmup_ratio: 0.1 \ No newline at end of file diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py index c20b184..ee9453a 100644 --- a/scripts/run_dpo.py +++ b/scripts/run_dpo.py @@ -197,16 +197,6 @@ def main(): logger.info("*** Training complete ***") - ########## - # Evaluate - ########## - if training_args.do_eval: - logger.info("*** Evaluate ***") - metrics = trainer.evaluate() - metrics["eval_samples"] = len(raw_datasets["test"]) - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - ################################## # Save model and create model card ################################## @@ -227,6 +217,16 @@ def main(): trainer.model.config.use_cache = True trainer.model.config.save_pretrained(training_args.output_dir) + ########## + # Evaluate + ########## + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate() + metrics["eval_samples"] = len(raw_datasets["test"]) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + if training_args.push_to_hub is True: logger.info("Pushing to hub...") trainer.push_to_hub(**kwargs) diff --git a/scripts/run_sft.py b/scripts/run_sft.py index e6fa4e0..ce100b5 100644 --- a/scripts/run_sft.py +++ b/scripts/run_sft.py @@ -134,7 +134,6 @@ def main(): device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) - logger.info("*** Model loaded! ***") ######################## # Initialize the Trainer @@ -150,6 +149,7 @@ def main(): tokenizer=tokenizer, packing=True, peft_config=get_peft_config(model_args), + dataset_kwargs=training_args.dataset_kwargs, ) ############### @@ -168,16 +168,6 @@ def main(): trainer.save_metrics("train", metrics) trainer.save_state() - ########## - # Evaluate - ########## - if training_args.do_eval: - logger.info("*** Evaluate ***") - metrics = trainer.evaluate() - metrics["eval_samples"] = len(eval_dataset) - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - ################################## # Save model and create model card ################################## @@ -198,6 +188,16 @@ def main(): trainer.model.config.use_cache = True trainer.model.config.save_pretrained(training_args.output_dir) + ########## + # Evaluate + ########## + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate() + metrics["eval_samples"] = len(eval_dataset) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + if training_args.push_to_hub is True: logger.info("Pushing to hub...") trainer.push_to_hub(**kwargs) diff --git a/setup.py b/setup.py index 28ff9f3..65fe730 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ if stale_egg_info.exists(): # IMPORTANT: all dependencies should be listed here with their version requirements, if any. # * If a dependency is fast-moving (e.g. transformers), pin to the exact version _deps = [ - "accelerate==0.23.0", + "accelerate==0.27.2", "bitsandbytes==0.41.2.post2", "black==23.1.0", "datasets==2.14.6", @@ -65,7 +65,7 @@ _deps = [ "scipy", "tensorboard", "torch==2.1.2", - "transformers==4.36.2", + "transformers>=4.38.2", # Fixes RoPE computation "trl==0.7.10", "jinja2>=3.0.0", "tqdm>=4.64.1", diff --git a/src/alignment/configs.py b/src/alignment/configs.py index ba96178..4a618ad 100644 --- a/src/alignment/configs.py +++ b/src/alignment/configs.py @@ -136,6 +136,14 @@ class ModelArguments: "choices": ["auto", "bfloat16", "float16", "float32"], }, ) + tokenizer_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The path to the tokenizer. Useful if you want to use a different tokenizer to the one stored in `model_name_or_path`." + ) + }, + ) trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) use_flash_attention_2: bool = field( default=False, @@ -220,6 +228,9 @@ class SFTConfig(transformers.TrainingArguments): Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments """ + dataset_kwargs: Optional[Dict[str, Any]] = field( + default=None, metadata={"help": "Dataset kwargs for the SFTTrainer"} + ) max_seq_length: Optional[int] = field( default=None, metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")}, diff --git a/src/alignment/data.py b/src/alignment/data.py index 16b095f..823a3bb 100644 --- a/src/alignment/data.py +++ b/src/alignment/data.py @@ -34,7 +34,7 @@ def maybe_insert_system_message(messages, tokenizer): chat_template = tokenizer.default_chat_template # confirm the jinja template refers to a system message before inserting - if "system" in chat_template: + if "system" in chat_template or "<|im_start|>" in chat_template: messages.insert(0, {"role": "system", "content": ""}) diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index 7126bdc..dd756c2 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -65,7 +65,9 @@ def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer: """Get the tokenizer for the model.""" tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, + model_args.model_name_or_path + if model_args.tokenizer_name_or_path is None + else model_args.tokenizer_name_or_path, revision=model_args.model_revision, ) if tokenizer.pad_token_id is None: