Add skeleton

This commit is contained in:
Lewis Tunstall
2023-11-08 13:21:57 +00:00
parent b9d9aa0a29
commit 967eab4cfb
13 changed files with 971 additions and 12 deletions
+5 -5
View File
@@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src
check_dirs := src tests
check_dirs := src tests scripts
style:
python -m black --line-length 119 --target-version py310 $(check_dirs) setup.py
@@ -18,16 +18,16 @@ quality:
# Release stuff
pre-release:
python src/alignment/utils/release.py
python src/alignment/release.py
pre-patch:
python src/alignment/utils/release.py --patch
python src/alignment/release.py --patch
post-release:
python src/alignment/utils/release.py --post_release
python src/alignment/release.py --post_release
post-patch:
python src/alignment/utils/release.py --post_release --patch
python src/alignment/release.py --post_release --patch
wheels:
python setup.py bdist_wheel && python setup.py sdist
View File
View File
View File
View File
+235
View File
@@ -0,0 +1,235 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import subprocess
import sys
from datetime import timedelta
import torch
import transformers
from transformers import set_seed
import wandb
from accelerate import Accelerator, InitProcessGroupKwargs
from h4.data import get_datasets
from h4.training import DataArguments, DPOTrainingArguments, ModelArguments, init_wandb_training
from h4.utils import (
H4ArgumentParser,
apply_chat_template,
convert_to_safetensors,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
get_tokenizer,
hf_login,
is_slurm_available,
push_to_hub_revision,
run_mt_bench_job,
)
from trl import DPOTrainer
logger = logging.getLogger(__name__)


def main():
    """
    Run Direct Preference Optimization (DPO) training end to end.

    Steps, in order: parse the model/data/training arguments, configure logging, optionally initialise
    WandB, load and template the preference datasets, build the `DPOTrainer`, train, optionally
    evaluate, then save the model and (on the main process only) write the model card, shard the
    checkpoint, convert it to safetensors and optionally push it to the Hub.

    NOTE(review): the block nesting below was reconstructed from a listing that had lost its
    indentation; confirm the extent of the `if accelerator.is_main_process:` section.
    """
    parser = H4ArgumentParser((ModelArguments, DataArguments, DPOTrainingArguments))
    model_args, data_args, training_args = parser.parse()

    #######
    # Setup
    #######
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    # Setup WandB
    if training_args.wandb_enabled:
        init_wandb_training(training_args)

    # Login to HuggingFace Hub if needed
    hf_login()

    # Set seed for reproducibility
    set_seed(training_args.seed)

    # Increase distributed timeout to 3h (6 * 1800 s) to enable push to Hub to complete
    accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=6 * 1800))])

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
    logger.info(
        f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )
    # Remember the raw columns so that only the templated columns survive the `map` below
    column_names = list(raw_datasets["train"].features)

    #####################################
    # Load tokenizer and process datasets
    #####################################
    data_args.truncation_side = "left"  # Truncate from left to ensure we don't lose labels in final turn
    tokenizer = get_tokenizer(model_args, data_args)

    #####################
    # Apply chat template
    #####################
    raw_datasets = raw_datasets.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer, "task": "dpo"},
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        desc="Formatting comparisons with prompt template",
    )

    # Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
    for split in ["train", "test"]:
        raw_datasets[split] = raw_datasets[split].rename_columns(
            {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
        )

    # Log a few random samples from the training set. The sample size is clamped because
    # `random.sample` raises a ValueError when asked for more items than the population holds.
    num_train_rows = len(raw_datasets["train"])
    for index in random.sample(range(num_train_rows), min(3, num_train_rows)):
        logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
        logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")
        logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")

    # "auto" and None are passed through as-is; any other value is resolved to the `torch` dtype of that name
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_flash_attention_2=model_args.use_flash_attention_2,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map(),
        quantization_config=get_quantization_config(model_args),
    )

    # Without PEFT the reference model is loaded from the same checkpoint with the same kwargs;
    # with PEFT no separate reference model is passed to the trainer.
    ref_model = model_args.model_name_or_path
    ref_model_kwargs = model_kwargs
    if model_args.use_peft:
        ref_model = None
        ref_model_kwargs = None

    #########################
    # Instantiate DPO trainer
    #########################
    dpo_trainer = DPOTrainer(
        model_args.model_name_or_path,
        ref_model,
        model_init_kwargs=model_kwargs,
        ref_model_init_kwargs=ref_model_kwargs,
        args=training_args,
        beta=training_args.beta,
        train_dataset=raw_datasets["train"],
        eval_dataset=raw_datasets["test"],
        tokenizer=tokenizer,
        max_length=training_args.max_seq_length,
        max_prompt_length=training_args.max_prompt_length,
        peft_config=get_peft_config(model_args),
    )

    ###############
    # Training loop
    ###############
    train_result = dpo_trainer.train()
    metrics = train_result.metrics
    max_train_samples = (
        data_args.max_train_samples if data_args.max_train_samples is not None else len(raw_datasets["train"])
    )
    metrics["train_samples"] = min(max_train_samples, len(raw_datasets["train"]))
    dpo_trainer.log_metrics("train", metrics)
    dpo_trainer.save_metrics("train", metrics)
    dpo_trainer.save_state()

    logger.info("*** Training complete ***")

    ##########
    # Evaluate
    ##########
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = dpo_trainer.evaluate(eval_dataset=raw_datasets["test"])
        max_eval_samples = (
            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(raw_datasets["test"])
        )
        metrics["eval_samples"] = min(max_eval_samples, len(raw_datasets["test"]))
        dpo_trainer.log_metrics("eval", metrics)
        dpo_trainer.save_metrics("eval", metrics)

    ##################################
    # Save model and create model card
    ##################################
    dpo_trainer.save_model(training_args.output_dir)

    # Save everything else on main process
    if accelerator.is_main_process:
        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
        kwargs["dataset"] = list(data_args.dataset_mixer.keys())
        dpo_trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        dpo_trainer.model.config.use_cache = True
        # Fix custom code paths: keep only the part after the last "--" of every auto_map entry
        if model_args.trust_remote_code is True:
            auto_map = dpo_trainer.model.config.auto_map
            dpo_trainer.model.config.auto_map = {k: v.split("--")[-1] for k, v in auto_map.items()}
        dpo_trainer.model.config.save_pretrained(training_args.output_dir)

        # FSDP/DeepSpeed save the model as a single `pytorch_model.bin` file, so we need to shard it.
        # We run this in a subprocess to avoid interference from the accelerators.
        subprocess.run(
            [
                "python",
                "scripts/training/shard_checkpoint.py",
                f"--output_dir={training_args.output_dir}",
                f"--trust_remote_code={model_args.trust_remote_code}",
            ],
            check=True,
        )

        # Convert torch weights to safetensors for deployment with TGI
        convert_to_safetensors(training_args.output_dir)

        if training_args.push_to_hub_revision:
            is_model_on_hub = push_to_hub_revision(training_args, model_args)
            # Run automatic evaluation once the model is pushed to the Hub
            if is_slurm_available() and is_model_on_hub is True and training_args.do_eval is True:
                logger.info("*** Launching MT Bench ***")
                run_mt_bench_job(training_args, model_args)

    # Ensure we don't timeout on model save / push to Hub
    logger.info("*** Waiting for all processes to finish ***")
    accelerator.wait_for_everyone()
    wandb.finish()

    logger.info("*** Run complete! ***")


if __name__ == "__main__":
    main()
+198
View File
@@ -0,0 +1,198 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Supervised fine-tuning script for decoder language models.
"""
import logging
import math
import random
import sys
import datasets
import torch
import transformers
from transformers import set_seed
from accelerate import Accelerator
from alignment import (
DataArguments,
H4ArgumentParser,
ModelArguments,
SFTConfig,
apply_chat_template,
get_datasets,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
get_tokenizer,
)
from trl import SFTTrainer
logger = logging.getLogger(__name__)


def main():
    """
    Run supervised fine-tuning (SFT) end to end.

    Steps, in order: parse the model/data/training arguments, configure logging, load the datasets,
    apply the chat template, build the `SFTTrainer`, optionally train and evaluate (reporting
    perplexity), save the model, and on the main process write the model card and config and
    optionally push to the Hub.

    NOTE(review): the block nesting below was reconstructed from a listing that had lost its
    indentation; confirm the extent of the `if accelerator.is_main_process:` section.
    """
    parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
    model_args, data_args, training_args = parser.parse()

    # Set seed for reproducibility
    set_seed(training_args.seed)

    accelerator = Accelerator()

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process a small summary
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.bf16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
    logger.info(
        f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )

    # The sample size is clamped because `random.sample` raises a ValueError when asked for more
    # items than the population holds (e.g. tiny debugging datasets).
    num_train_rows = len(raw_datasets["train"])
    with training_args.main_process_first(desc="Log a few random samples from the raw training set"):
        for index in random.sample(range(num_train_rows), min(3, num_train_rows)):
            logger.info(f"Sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['messages']}")

    #####################################
    # Load tokenizer and process datasets
    #####################################
    tokenizer = get_tokenizer(model_args, data_args)

    #####################
    # Apply chat template
    #####################
    raw_datasets = raw_datasets.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer, "task": "sft"})
    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["test"]

    num_train_rows = len(raw_datasets["train"])
    with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
        for index in random.sample(range(num_train_rows), min(3, num_train_rows)):
            logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")

    #######################
    # Load pretrained model
    #######################
    # Only the loading kwargs are assembled here; the model name and these kwargs are handed to
    # `SFTTrainer` below rather than a model object being instantiated in this function.
    logger.info("*** Load pretrained model ***")
    # "auto" and None are passed through as-is; any other value is resolved to the `torch` dtype of that name
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )

    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_flash_attention_2=model_args.use_flash_attention_2,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map(),
        quantization_config=get_quantization_config(model_args),
    )
    logger.info("*** Model loaded! ***")

    ########################
    # Initialize the Trainer
    ########################
    trainer = SFTTrainer(
        model=model_args.model_name_or_path,
        model_init_kwargs=model_kwargs,
        args=training_args,
        train_dataset=raw_datasets["train"] if training_args.do_train else None,
        eval_dataset=raw_datasets["test"] if training_args.do_eval else None,
        dataset_text_field="text",
        max_seq_length=training_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        peft_config=get_peft_config(model_args),
    )

    ###############
    # Training loop
    ###############
    if training_args.do_train:
        logger.info("*** Train ***")
        train_result = trainer.train()
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    ##########
    # Evaluate
    ##########
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        # exp() overflows for very large losses; report infinity instead of crashing
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on main process
    if accelerator.is_main_process:
        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
        kwargs["dataset"] = list(data_args.dataset_mixer.keys())
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

        if training_args.push_to_hub:
            trainer.push_to_hub()

    accelerator.wait_for_everyone()


if __name__ == "__main__":
    main()
+7 -7
View File
@@ -45,25 +45,25 @@ _deps = [
"bitsandbytes==0.41.1",
"black==23.1.0",
"datasets==2.12.0",
"deepspeed==0.9.5",
"einops==0.6.1",
"deepspeed==0.12.2",
"einops>=0.6.1",
"evaluate==0.4.0",
"flake8>=6.0.0",
"hf-doc-builder>=0.4.0",
"huggingface-hub>=0.14.1,<1.0",
"isort>=5.12.0",
"ninja==1.11.1",
"ninja>=1.11.1",
"numpy>=1.24.2",
"packaging>=23.0",
"parameterized>=0.9.0",
"peft==0.5.0",
"peft==0.6.0",
"protobuf<=3.20.2", # Needed to avoid conflicts with `transformers`
"pytest",
"safetensors==0.3.3",
"safetensors>=0.3.3",
"tensorboard",
"torch==2.0.1",
"transformers @ git+https://github.com/huggingface/transformers.git@b3961f7291307ee877ef1a4d057949597d805220",
"trl @ git+https://github.com/huggingface/trl.git@1e56ff0f166888973d69cd9d56be60a9f8edfedb", # TODO bump to next release, added for NEFTune
"transformers==4.35.0",
"trl==0.7.4", # TODO bump to next release, added for NEFTune
"tqdm>=4.64.1",
]
+4
View File
@@ -1 +1,5 @@
__version__ = "0.2.0.dev0"
from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig
from .data import apply_chat_template, get_datasets
from .model_utils import get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer
+272
View File
@@ -0,0 +1,272 @@
# coding=utf-8
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import os
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, NewType, Optional, Tuple, Union
import transformers
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
# Keys of transformers' causal-LM mapping (presumably the model config classes; confirm against transformers).
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
# The `model_type` attribute of every entry above, frozen into a tuple.
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


# Nominal alias used in return annotations for "an instance of some argument dataclass".
DataClassType = NewType("DataClassType", Any)
class H4ArgumentParser(HfArgumentParser):
    """
    `HfArgumentParser` that can additionally read a YAML config file and let `--arg=value`
    command line overrides take precedence over the values loaded from that file.
    """

    @staticmethod
    def _resolve_base_type(field_type):
        """Return `T` for an `Optional[T]` annotation, otherwise return `field_type` unchanged."""
        if getattr(field_type, "__origin__", None) is Union:
            non_none = [t for t in field_type.__args__ if t is not type(None)]
            # Only unwrap plain Optional[T]; a real Union[A, B] has no single target type to cast to.
            if len(non_none) == 1:
                return non_none[0]
        return field_type

    def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
        """
        Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.

        Args:
            yaml_arg (`str`):
                The path to the config file used
            other_args (`List[str]`, *optional*):
                A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
                `None` is treated as "no overrides".

        Returns:
            [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line

        Raises:
            ValueError: if an override is not of the form `--arg=value`, or if the same override
                name matches a field in more than one dataclass.
        """
        arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))

        outputs = []
        # strip other args list into dict of key-value pairs
        overrides = {}
        for raw_arg in other_args or []:
            # Split on the first "=" only so that values may themselves contain "="
            name, separator, value = raw_arg.partition("=")
            if not separator:
                raise ValueError(f"Expected a command line override of the form `--arg=value`, got: {raw_arg}")
            overrides[name.lstrip("-")] = value
        used_args = {}

        # overwrite the default/loaded value with the value provided to the command line
        # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
        for data_yaml, data_class in zip(arg_list, self.dataclass_types):
            keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
            inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
            for arg, val in overrides.items():
                # add only if in keys
                if arg not in keys:
                    continue

                # Unwrap Optional[T] so that e.g. Optional[int] / Optional[float] fields are cast as well
                base_type = self._resolve_base_type(data_yaml.__dataclass_fields__[arg].type)
                inputs[arg] = val

                # cast type for ints, floats (default to strings)
                if base_type in [int, float]:
                    inputs[arg] = base_type(val)

                if base_type == List[str]:
                    inputs[arg] = [str(v) for v in val.split(",")]

                # bool of a non-empty string is True, so we manually check for bools
                if base_type == bool:
                    inputs[arg] = val in ["true", "True"]

                # add to used-args so we can check if double add
                if arg not in used_args:
                    used_args[arg] = val
                else:
                    raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")

            obj = data_class(**inputs)
            outputs.append(obj)

        return outputs

    def parse(self) -> Union[DataClassType, Tuple[DataClassType]]:
        """
        Parse `sys.argv` into the configured dataclasses.

        Three invocation styles are supported:
            * `script.py config.yaml`: values come from the YAML file only
            * `script.py config.yaml --arg=value ...`: YAML values overridden by the command line
            * anything else: regular command line parsing

        Returns:
            A single dataclass instance when exactly one was parsed, otherwise all parsed dataclasses.
        """
        if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
            # If we pass only one argument to the script and it's the path to a YAML file,
            # let's parse it to get our arguments.
            output = self.parse_yaml_file(os.path.abspath(sys.argv[1]))
        # parse command line args and yaml file
        elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
            output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:])
        # parse command line args only
        else:
            output = self.parse_args_into_dataclasses()

        if len(output) == 1:
            output = output[0]
        return output
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.

    Groups the checkpoint selection (name, revisions, dtype), the LoRA/PEFT hyperparameters and the
    bitsandbytes quantization switches. `load_in_8bit` and `load_in_4bit` are mutually exclusive,
    which is enforced in `__post_init__`.
    """

    base_model_revision: Optional[str] = field(
        default=None,
        metadata={"help": ("The base model checkpoint for weights initialization with PEFT adatpers.")},
    )
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Annotated as Optional because the default is None
    model_code_revision: Optional[str] = field(default=None, metadata={"help": "The branch of the IFT model"})
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
    use_flash_attention_2: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`"
            )
        },
    )
    use_peft: bool = field(
        default=False,
        metadata={"help": ("Whether to use PEFT or not for training.")},
    )
    lora_r: Optional[int] = field(
        default=16,
        metadata={"help": ("LoRA R value.")},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": ("LoRA alpha.")},
    )
    lora_dropout: Optional[float] = field(
        default=0.05,
        metadata={"help": ("LoRA dropout.")},
    )
    lora_target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": ("LoRA target modules.")},
    )
    lora_modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": ("Model layers to unfreeze & train")},
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"})
    load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"})

    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
    )
    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})

    def __post_init__(self):
        """Reject configurations that request 8-bit and 4-bit loading at the same time."""
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Covers the chat template, the dataset mixture and splits, optional sample caps for
    debugging, and tokenizer/preprocessing knobs.
    """

    chat_template: Optional[str] = field(
        default=None,
        metadata={"help": "The chat template to use."},
    )
    dataset_mixer: Optional[Dict[str, float]] = field(
        default=None,
        metadata={"help": "Datasets and their proportions to be used for training ift/rl."},
    )
    # A factory is required here: dataclasses reject a mutable list as a plain default
    dataset_splits: Optional[List[str]] = field(
        default_factory=lambda: ["train", "test"],
        metadata={"help": "List of train test splits to use in the dataset"},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    truncation_side: Optional[str] = field(
        default=None,
        metadata={"help": "Truncation side to use for the tokenizer."},
    )
@dataclass
class SFTConfig(transformers.TrainingArguments):
    """
    Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments

    Extends `transformers.TrainingArguments` with `max_seq_length` and changes the defaults of
    `logging_first_step` and `optim`.
    """

    max_seq_length: Optional[int] = field(
        default=None,
        metadata={"help": "Used by TRL for reward model training, which tries to read this parameter in init."},
    )
    logging_first_step: bool = field(
        default=True,
        metadata={"help": "Whether to log and evaluate the first global_step or not."},
    )
    optim: Optional[str] = field(default="adamw_torch")
@dataclass
class DPOConfig(transformers.TrainingArguments):
    """
    Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments

    Extends `transformers.TrainingArguments` with the DPO-specific `beta`, `hub_model_revision`,
    `max_prompt_length` and `max_length` options and changes the defaults of `logging_first_step`,
    `optim` and `remove_unused_columns`.
    """

    beta: Optional[float] = field(
        default=0.1,
        metadata={"help": "The beta factor in DPO loss. Higher beta means less divergence from the initial policy."},
    )
    hub_model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The Hub model branch to push the model to."},
    )
    logging_first_step: bool = field(
        default=True,
        metadata={"help": "Whether to log and evaluate the first global_step or not."},
    )
    max_prompt_length: Optional[int] = field(
        default=None,
        metadata={"help": "For DPO, the maximum length of the prompt to use for conditioning the model."},
    )
    max_length: Optional[int] = field(
        default=None,
        metadata={"help": "Used by TRL for reward model training, which tries to read this parameter in init."},
    )
    optim: Optional[str] = field(default="rmsprop")
    remove_unused_columns: bool = field(default=False)
+171
View File
@@ -0,0 +1,171 @@
import re
from typing import List, Literal, Optional, Union
from datasets import DatasetDict, concatenate_datasets, load_dataset
from .configs import DataArguments
# Fallback chat template, written in Jinja syntax. Every message with role `user`, `system` or
# `assistant` is rendered as "<|role|>\n" + content + eos_token; when `add_generation_prompt` is
# set, a bare "<|assistant|>" marker is emitted after the last message.
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
def apply_chat_template(
    example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n"
):
    """
    Render the conversation(s) of a single example to text with the tokenizer's chat template.

    Args:
        example (`dict`-like):
            One dataset row. Needs a `messages` key for the `sft`/`generation` tasks and
            `chosen` + `rejected` keys for the `rm`/`dpo` tasks. Each conversation is a list of
            `{"role": ..., "content": ...}` dicts.
        tokenizer:
            Object exposing `apply_chat_template(messages, tokenize=False, add_generation_prompt=...)`.
        task (`str`, *optional*, defaults to `"sft"`):
            Which formatting to apply:
              * `sft` / `generation`: writes `text` (generation additionally adds the generation prompt)
              * `rm`: writes `text_chosen` and `text_rejected` for the full conversations
              * `dpo`: writes `text_prompt` (system + first user message) and the
                `text_chosen` / `text_rejected` continuations with the leading assistant prefix removed
        assistant_prefix (`str`, *optional*, defaults to `"<|assistant|>\\n"`):
            Prefix stripped from the start of the DPO completions.

    Returns:
        The same `example` object with the new text column(s) added. The message lists inside the
        example are left untouched.

    Raises:
        ValueError: if required keys are missing for the `rm`/`dpo` tasks or `task` is not supported.
    """

    def _strip_prefix(s, pattern):
        # Remove `pattern` only when it is a literal prefix of `s`
        return s[len(pattern) :] if s.startswith(pattern) else s

    def _with_system_message(messages):
        # Return a new list so that the conversation stored in `example` is never mutated
        if messages[0]["role"] != "system":
            return [{"role": "system", "content": ""}, *messages]
        return list(messages)

    def _without_prompt(messages):
        # Drop the optional leading system message plus the prompt message that follows it
        start = 2 if messages[0]["role"] == "system" else 1
        return messages[start:]

    if task in ["sft", "generation"]:
        # We add an empty system message if there is none
        messages = _with_system_message(example["messages"])
        example["text"] = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True if task == "generation" else False
        )
    elif task == "rm":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            # We add an empty system message if there is none
            chosen_messages = _with_system_message(example["chosen"])
            rejected_messages = _with_system_message(example["rejected"])
            example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    elif task == "dpo":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            # Compared to reward modeling, we filter out the prompt, so the text is everything after the last assistant token
            first_user_message = [msg for msg in example["chosen"] if msg["role"] == "user"][0]
            # The prompt always starts with a system message: the provided one or an empty placeholder
            if example["chosen"][0]["role"] != "system":
                system_message = {"role": "system", "content": ""}
            else:
                system_message = example["chosen"][0]
            prompt_messages = [system_message, first_user_message]

            # Skip the system message (when present) together with the prompt, otherwise the user
            # prompt would be repeated inside the completions.
            chosen_messages = _without_prompt(example["chosen"])
            rejected_messages = _without_prompt(example["rejected"])
            example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
            example["text_prompt"] = tokenizer.apply_chat_template(
                prompt_messages, tokenize=False, add_generation_prompt=True
            )
            example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
            example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    else:
        # Fail loudly instead of silently returning an unformatted example
        raise ValueError(
            f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}"
        )
    return example
def get_datasets(
    data_config: Union[DataArguments, dict],
    splits: Optional[List[str]] = None,
    shuffle: bool = True,
) -> DatasetDict:
    """
    Loads one or more datasets with varying training set proportions.

    Args:
        data_config (`DataArguments` or `dict`):
            Dataset configuration and split proportions.
        splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training data.

    Returns
        [`DatasetDict`]: The dataset dictionary containing the loaded datasets.

    Raises:
        ValueError: if `data_config` is neither a `DataArguments` nor a `dict`, or if no dataset
            mixer is configured.
    """
    if isinstance(data_config, DataArguments):
        # Structure of the config to read the datasets and their mix
        # datasets_mixer:
        #     - 'dataset1': 0.5
        #     - 'dataset2': 0.3
        #     - 'dataset3': 0.2
        dataset_mixer = data_config.dataset_mixer
    elif isinstance(data_config, dict):
        # Structure of the input is:
        #     dataset_mixer = {
        #             "dataset1": 0.5,
        #             "dataset2": 0.3,
        #             "dataset3": 0.2,
        #         }
        dataset_mixer = data_config
    else:
        raise ValueError(f"Data config {data_config} not recognized.")

    # `DataArguments.dataset_mixer` defaults to None; fail with a clear message instead of an AttributeError later on
    if dataset_mixer is None:
        raise ValueError("No `dataset_mixer` provided: specify the datasets to load and their proportions.")

    # Resolve the default here rather than in the signature to avoid a shared mutable default argument
    if splits is None:
        splits = ["train", "test"]

    raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
    return raw_datasets
def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True) -> DatasetDict:
    """
    Loads and mixes datasets according to proportions specified in `dataset_mixer`.

    Args:
        dataset_mixer (`dict`):
            Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
        splits (Optional[List[str]], *optional*, defaults to `None`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
            `None` is treated as `['train', 'test']`.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training data.

    Returns:
        [`DatasetDict`]: with a `train` entry (subsampled per dataset) and/or a `test` entry (never subsampled).

    Raises:
        ValueError: if a fraction is negative, a split name contains neither `train` nor `test`,
            or nothing could be loaded.
    """
    if splits is None:
        splits = ["train", "test"]

    # Validate up front so that no dataset is downloaded for an invalid configuration
    if any(frac < 0 for frac in dataset_mixer.values()):
        raise ValueError("Dataset fractions cannot be negative.")

    raw_datasets = DatasetDict()
    train_subsets = []
    raw_val_datasets = []
    for ds, frac in dataset_mixer.items():
        for split in splits:
            if "train" in split:
                dataset = load_dataset(ds, split=split)
                # Subsample right away so that every train split is paired with its own dataset's
                # fraction, even when several train splits are requested per dataset.
                train_subsets.append(dataset.select(range(int(frac * len(dataset)))))
            elif "test" in split:
                # No subsampling for test datasets to enable fair comparison across models
                raw_val_datasets.append(load_dataset(ds, split=split))
            else:
                raise ValueError(f"Split type {split} not recognized as one of test or train.")

    if len(train_subsets) > 0:
        if shuffle:
            raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42)
        else:
            raw_datasets["train"] = concatenate_datasets(train_subsets)

    if len(raw_val_datasets) > 0:
        if shuffle:
            raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42)
        else:
            raw_datasets["test"] = concatenate_datasets(raw_val_datasets)

    if len(raw_datasets) == 0:
        # Report the requested `splits`: the loop variable is unbound when nothing was iterated
        raise ValueError(
            f"Dataset {dataset_mixer} not recognized with split {splits}. Check the dataset has been correctly formatted."
        )

    return raw_datasets
+79
View File
@@ -0,0 +1,79 @@
from typing import Dict, Union
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
from accelerate import Accelerator
from peft import LoraConfig, PeftConfig
from .configs import DataArguments, ModelArguments
from .data import DEFAULT_CHAT_TEMPLATE
def get_current_device() -> Union[int, str]:
    """
    Get the current device. For GPU we return the local process index to enable multiple GPU training.

    Returns:
        The `local_process_index` of a fresh `Accelerator` when CUDA is available, otherwise the string `"cpu"`.
    """
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
def get_kbit_device_map() -> Dict[str, int] | None:
"""Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`"""
return {"": get_current_device()} if torch.cuda.is_available() else None
def get_quantization_config(model_args) -> BitsAndBytesConfig | None:
    """
    Build the bitsandbytes quantization config requested by `model_args`.

    4-bit loading is checked before 8-bit loading; `None` is returned when neither is requested.
    """
    if model_args.load_in_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,  # For consistency with model weights, we use the same value as `torch_dtype` which is float16 for PEFT models
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )

    if model_args.load_in_8bit:
        return BitsAndBytesConfig(
            load_in_8bit=True,
        )

    return None
def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
    """
    Get the tokenizer for the model.

    Loads the tokenizer of `model_args.model_name_or_path` at `model_args.model_revision`, then fills
    in a padding token, truncation side, maximum length and chat template where needed.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)

    # Reuse the EOS token for padding when the tokenizer defines no pad token
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    requested_side = data_args.truncation_side
    if requested_side is not None:
        tokenizer.truncation_side = requested_side

    # Set reasonable default for models without max length
    if tokenizer.model_max_length > 100_000:
        tokenizer.model_max_length = 2048

    # An explicitly configured template wins; the default is used only when the tokenizer has none
    requested_template = data_args.chat_template
    if requested_template is not None:
        tokenizer.chat_template = requested_template
    elif tokenizer.chat_template is None:
        tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE

    return tokenizer
def get_peft_config(model_args: ModelArguments) -> Union[PeftConfig, None]:
    """
    Build the LoRA configuration from `model_args`.

    Returns `None` when `model_args.use_peft` is `False`; otherwise a causal-LM `LoraConfig` with
    no bias training and the rank, alpha, dropout and module lists taken from `model_args`.
    """
    if model_args.use_peft is False:
        return None

    return LoraConfig(
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=model_args.lora_target_modules,
        modules_to_save=model_args.lora_modules_to_save,
    )