* Add StarChat2

* Add DPO

* Fix unit test

* Typos

* Typo
This commit is contained in:
lewtun
2024-03-12 17:22:21 +01:00
committed by GitHub
parent ff618a4d13
commit a9b8a50a27
14 changed files with 324 additions and 29 deletions
+1
View File
@@ -19,6 +19,7 @@ However, we know from the [InstructGPT](https://huggingface.co/papers/2203.02155
The Alignment Handbook aims to fill that gap by providing the community with a series of robust training recipes that span the whole pipeline.
## News 🗞️
* **March 12, 2024:** We release StarChat2 15B, along with the recipe to train capable coding assistants 🌟
* **March 1, 2024:** We release Zephyr 7B Gemma, which is a new recipe to align Gemma 7B with RLAIF 🔥
* **February 1, 2024:** We release a recipe to align open LLMs with Constitutional AI 📜! See the [recipe](https://github.com/huggingface/alignment-handbook/tree/main/recipes/constitutional-ai) and the [blog post](https://huggingface.co/blog/constitutional_ai) for details.
* **January 18, 2024:** We release a suite of evaluations of DPO vs KTO vs IPO, see the [recipe](recipes/pref_align_scan/README.md) and the [blog post](https://huggingface.co/blog/pref-tuning) for details.
+21
View File
@@ -0,0 +1,21 @@
# Instructions to StarChat2
Similar to how we trained Zephyr 7B Beta in our [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps:
1. Apply SFT to fine-tune [StarCoder2 15B](https://huggingface.co/bigcode/starcoder2-15b) on a blend of chat, code, and math datastets. The result is an SFT model like [`starchat2-15b-sft-v0.1`](https://huggingface.co/HuggingFaceH4/starchat2-15b-sft-v0.1).
2. Align the SFT model to AI feedback via DPO on the UltraFeedback and Orca DPO Pairs datasets. The result is a DPO model like [`starchat2-15b-v0.1`](https://huggingface.co/HuggingFaceH4/starchat2-15b-v0.1).
See below for commands to train these models using DeepSpeed ZeRO-3.
## Full training examples
You will require 8 GPUs (80GB of VRAM) to train the full model - alternatively, you can train on 1 GPU by adjusting `per_device_train_batch_size` and `gradient_accumulation_steps` to keep the global batch size constant. A recipe involving QLoRA will come later 🤗.
```shell
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/starchat2-15b/sft/config_v0.1.yaml
# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/starchat2-15b/dpo/config_v0.1.yaml
```
@@ -0,0 +1,43 @@
# Model arguments
model_name_or_path: HuggingFaceH4/starchat2-15b-sft-v0.1
torch_dtype: bfloat16
# Data training arguments
# For definitions, see: src/h4/training/config.py
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 1.0
HuggingFaceH4/orca_dpo_pairs: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 12
# DPOTrainer arguments
bf16: true
beta: 0.05
do_eval: true
evaluation_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: False
hub_model_id: starchat2-15b-dpo-v0.1
learning_rate: 5.0e-7
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_length: 1024
max_prompt_length: 512
num_train_epochs: 2
optim: adamw_torch
output_dir: data/starchat2-15b-dpo-v0.1
per_device_train_batch_size: 2
per_device_eval_batch_size: 4
push_to_hub: true
report_to:
- tensorboard
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
@@ -0,0 +1,49 @@
# Model arguments
model_name_or_path: bigcode/starcoder2-15b
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true
# Data training arguments
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
dataset_mixer:
HuggingFaceH4/airoboros-3.2: 1.0
HuggingFaceH4/Code-Feedback: 1.0
HuggingFaceH4/orca-math-word-problems-200k: 1.0
HuggingFaceH4/SystemChat: 1.0
HuggingFaceH4/capybara: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 24
# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 2
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: starchat2-15b-v0.1
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 3
output_dir: data/starchat2-15b-v0.1
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 8
push_to_hub: true
remove_unused_columns: true
report_to:
- tensorboard
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
+1 -1
View File
@@ -39,4 +39,4 @@ report_to:
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
warmup_ratio: 0.1
+1 -1
View File
@@ -28,7 +28,7 @@ hub_model_id: zephyr-7b-gemma-sft
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
+18
View File
@@ -27,6 +27,7 @@ from alignment import (
H4ArgumentParser,
ModelArguments,
apply_chat_template,
decontaminate_humaneval,
get_checkpoint,
get_datasets,
get_kbit_device_map,
@@ -103,6 +104,23 @@ def main():
desc="Formatting comparisons with prompt template",
)
##########################
# Decontaminate benchmarks
##########################
num_raw_train_samples = len(raw_datasets["train"])
raw_datasets = raw_datasets.filter(
decontaminate_humaneval,
fn_kwargs={"text_column": "text_chosen"},
batched=True,
batch_size=10_000,
num_proc=1,
desc="Decontaminating HumanEval samples",
)
num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"])
logger.info(
f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set."
)
# Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
for split in ["train", "test"]:
raw_datasets[split] = raw_datasets[split].rename_columns(
+43 -24
View File
@@ -24,7 +24,7 @@ import sys
import datasets
import torch
import transformers
from transformers import set_seed
from transformers import AutoModelForCausalLM, set_seed
from alignment import (
DataArguments,
@@ -32,6 +32,7 @@ from alignment import (
ModelArguments,
SFTConfig,
apply_chat_template,
decontaminate_humaneval,
get_checkpoint,
get_datasets,
get_kbit_device_map,
@@ -39,7 +40,7 @@ from alignment import (
get_quantization_config,
get_tokenizer,
)
from trl import SFTTrainer
from trl import SFTTrainer, setup_chat_format
logger = logging.getLogger(__name__)
@@ -95,27 +96,6 @@ def main():
################
tokenizer = get_tokenizer(model_args, data_args)
#####################
# Apply chat template
#####################
raw_datasets = raw_datasets.map(
apply_chat_template,
fn_kwargs={
"tokenizer": tokenizer,
"task": "sft",
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
},
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Applying chat template",
)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
for index in random.sample(range(len(raw_datasets["train"])), 3):
logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
#######################
# Load pretrained model
#######################
@@ -135,11 +115,50 @@ def main():
quantization_config=quantization_config,
)
model = model_args.model_name_or_path
# For ChatML we need to add special tokens and resize the embedding layer
if "<|im_start|>" in tokenizer.chat_template:
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
model, tokenizer = setup_chat_format(model, tokenizer)
model_kwargs = None
#####################
# Apply chat template
#####################
raw_datasets = raw_datasets.map(
apply_chat_template,
fn_kwargs={
"tokenizer": tokenizer,
"task": "sft",
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
},
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Applying chat template",
)
##########################
# Decontaminate benchmarks
##########################
num_raw_train_samples = len(raw_datasets["train"])
raw_datasets = raw_datasets.filter(decontaminate_humaneval, batched=True, batch_size=10_000, num_proc=1)
num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"])
logger.info(
f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set."
)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
for index in random.sample(range(len(raw_datasets["train"])), 3):
logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
########################
# Initialize the Trainer
########################
trainer = SFTTrainer(
model=model_args.model_name_or_path,
model=model,
model_init_kwargs=model_kwargs,
args=training_args,
train_dataset=train_dataset,
+1 -1
View File
@@ -65,7 +65,7 @@ _deps = [
"scipy",
"tensorboard",
"torch==2.1.2",
"transformers>=4.38.2", # Fixes RoPE computation
"transformers @ git+https://github.com/huggingface/transformers.git@831bc25d8fdb85768402f772cf65cc3d7872b211", # Enable StarCoder2
"trl==0.7.10",
"jinja2>=3.0.0",
"tqdm>=4.64.1",
+1
View File
@@ -2,6 +2,7 @@ __version__ = "0.3.0.dev0"
from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig
from .data import apply_chat_template, get_datasets
from .decontaminate import decontaminate_humaneval
from .model_utils import (
get_checkpoint,
get_kbit_device_map,
+4
View File
@@ -23,6 +23,8 @@ from .configs import DataArguments
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
COLUMNS_TO_KEEP = ["messages", "chosen", "rejected", "prompt", "completion", "label"]
def maybe_insert_system_message(messages, tokenizer):
if messages[0]["role"] == "system":
@@ -161,6 +163,8 @@ def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffl
# If not, check local dataset
dataset = load_from_disk(os.path.join(ds, split))
# Remove redundant columns to avoid schema conflicts on load
dataset = dataset.remove_columns([col for col in dataset.column_names if col not in COLUMNS_TO_KEEP])
if "train" in split:
raw_train_datasets.append(dataset)
elif "test" in split:
+91
View File
@@ -0,0 +1,91 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List
from datasets import load_dataset
# HumanEval solutions that are considered simple/generic enough to be kept in the training dataset
HUMAN_EVAL_STRINGS_OK = ["return x + y", "return len(string)", "return n**2", "return " ".join(strings)"]
def extract_docstring(prompt: str) -> str:
if '"""' in prompt:
if prompt.count('"""') == 2:
return prompt.split('"""')[1].strip()
elif prompt.count('"""') == 4:
return prompt.split('"""')[3].strip()
else:
raise ValueError()
elif "'''" in prompt:
assert prompt.count("'''") == 2
return prompt.split("'''")[1].strip()
else:
raise ValueError()
def human_eval_docstrings() -> List[str]:
ds = load_dataset("openai_humaneval", split="test")
docstrings = [extract_docstring(v["prompt"]) for v in ds]
return docstrings
def load_dataset_column(dataset: str, column: str, split: str, name=None) -> List[str]:
ds = load_dataset(dataset, split=split, name=name)
res = [sample[column].strip() for sample in ds]
# Only return non-empty strings
return [sample for sample in res if len(sample) > 0]
FILTER_OUT = {
"human_eval_docstrings": human_eval_docstrings(),
"human_eval_solutions": [
s
for s in load_dataset_column("openai_humaneval", "canonical_solution", "test")
if s not in HUMAN_EVAL_STRINGS_OK
],
}
def normalize_whitespace(text: str) -> str:
return " ".join(text.split())
def decontaminate_humaneval(
samples: List[Dict[str, Any]], text_column: str = "text", filter_out: Dict[str, List[str]] = FILTER_OUT
) -> List[Dict[str, Any]]:
"""
filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be
filtered-out.
Return a list where each element is True if the corresponding file should be included in the dataset.
Otherwise, the element is False.
"""
output = []
for content in samples[text_column]:
content = normalize_whitespace(content.lower())
matched = False
for _, substrings in filter_out.items():
for substring in substrings:
if normalize_whitespace(substring.lower()) in content:
matched = True
break
if matched:
break
# we keep files that are not matched
output.append(not matched)
return output
+48
View File
@@ -0,0 +1,48 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase
from datasets import Dataset
from transformers import AutoTokenizer
from alignment import apply_chat_template, decontaminate_humaneval
class DecontamintateHumanEvalTest(TestCase):
"""Test we decontaminate HumanEval samples correctly"""
def setUp(self) -> None:
# Create a dataset with a HumanEval sample wrapped in some fake text
dataset = Dataset.from_dict(
{
"messages": [
[{"content": "Hello", "role": "user"}],
[
{
"content": 'Hello, I am\nfrom\n\n typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n """ Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n """\n',
"role": "assistant",
}
],
]
}
)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
self.dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer, "task": "sft"})
def test_decontamination(self):
"""Test we decontaminate HumanEval samples correctly"""
decontaminated_dataset = self.dataset.filter(decontaminate_humaneval, batched=True)
# Check we recover just the first message
self.assertEqual(decontaminated_dataset[0]["text"], self.dataset[0]["text"])
+2 -2
View File
@@ -70,8 +70,8 @@ class GetTokenizerTest(unittest.TestCase):
`default_chat_template` but no `chat_template` we do not set a `chat_template`,
and that we do not change `default_chat_template`
"""
model_args = ModelArguments(model_name_or_path="codellama/CodeLlama-7b-Instruct-hf")
base_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
model_args = ModelArguments(model_name_or_path="m-a-p/OpenCodeInterpreter-SC2-7B")
base_tokenizer = AutoTokenizer.from_pretrained("m-a-p/OpenCodeInterpreter-SC2-7B")
processed_tokenizer = get_tokenizer(model_args, DataArguments())
assert getattr(processed_tokenizer, "chat_template") is None