From 967eab4cfb71ac41e51f77b30123d6858c880328 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 8 Nov 2023 13:21:57 +0000 Subject: [PATCH] Add skeleton --- Makefile | 10 +- recipes/dpo/.gitkeep | 0 recipes/ppo/.gitkeep | 0 recipes/reward_modeling/.gitkeep | 0 recipes/sft/.gitkeep | 0 scripts/run_dpo.py | 235 +++++++++++++++++++++++ scripts/run_sft.py | 198 +++++++++++++++++++ setup.py | 14 +- src/alignment/__init__.py | 4 + src/alignment/configs.py | 272 +++++++++++++++++++++++++++ src/alignment/data.py | 171 +++++++++++++++++ src/alignment/model_utils.py | 79 ++++++++ src/alignment/{utils => }/release.py | 0 13 files changed, 971 insertions(+), 12 deletions(-) delete mode 100644 recipes/dpo/.gitkeep delete mode 100644 recipes/ppo/.gitkeep delete mode 100644 recipes/reward_modeling/.gitkeep delete mode 100644 recipes/sft/.gitkeep create mode 100644 scripts/run_dpo.py create mode 100644 scripts/run_sft.py create mode 100644 src/alignment/configs.py create mode 100644 src/alignment/data.py create mode 100644 src/alignment/model_utils.py rename src/alignment/{utils => }/release.py (100%) diff --git a/Makefile b/Makefile index 553edb1..2d82400 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) 
export PYTHONPATH = src -check_dirs := src tests +check_dirs := src tests scripts style: python -m black --line-length 119 --target-version py310 $(check_dirs) setup.py @@ -18,16 +18,16 @@ quality: # Release stuff pre-release: - python src/alignment/utils/release.py + python src/alignment/release.py pre-patch: - python src/alignment/utils/release.py --patch + python src/alignment/release.py --patch post-release: - python src/alignment/utils/release.py --post_release + python src/alignment/release.py --post_release post-patch: - python src/alignment/utils/release.py --post_release --patch + python src/alignment/release.py --post_release --patch wheels: python setup.py bdist_wheel && python setup.py sdist diff --git a/recipes/dpo/.gitkeep b/recipes/dpo/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/recipes/ppo/.gitkeep b/recipes/ppo/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/recipes/reward_modeling/.gitkeep b/recipes/reward_modeling/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/recipes/sft/.gitkeep b/recipes/sft/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py new file mode 100644 index 0000000..b6f1cba --- /dev/null +++ b/scripts/run_dpo.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def main():
    """Run a full DPO training job: parse arguments, prepare data, train, evaluate, save.

    The steps, in order, are:

    1. Parse model/data/training arguments and configure logging, WandB, the Hub login and the seed.
    2. Load the datasets, apply the chat template with the `dpo` task and rename the resulting
       columns to the `prompt` / `chosen` / `rejected` names passed on to `DPOTrainer`.
    3. Build the model keyword arguments and instantiate `DPOTrainer` (no reference model when PEFT is used).
    4. Train, optionally evaluate, then save the model, model card and config on the main process,
       shard/convert the checkpoint and optionally push it to the Hub.
    """
    parser = H4ArgumentParser((ModelArguments, DataArguments, DPOTrainingArguments))
    model_args, data_args, training_args = parser.parse()

    #######
    # Setup
    #######
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    # Setup WandB
    if training_args.wandb_enabled:
        init_wandb_training(training_args)

    # Login to HuggingFace Hub if needed
    hf_login()

    # Set seed for reproducibility
    set_seed(training_args.seed)

    # Increase distributed timeout to 3h to enable push to Hub to complete
    accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(hours=3))])

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
    logger.info(
        f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )
    column_names = list(raw_datasets["train"].features)

    #####################################
    # Load tokenizer and process datasets
    #####################################
    data_args.truncation_side = "left"  # Truncate from left to ensure we don't lose labels in final turn
    tokenizer = get_tokenizer(model_args, data_args)

    #####################
    # Apply chat template
    #####################
    raw_datasets = raw_datasets.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer, "task": "dpo"},
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        desc="Formatting comparisons with prompt template",
    )

    # Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
    for split in ["train", "test"]:
        raw_datasets[split] = raw_datasets[split].rename_columns(
            {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
        )

    # Log a few random samples from the training set. The sample size is capped by the dataset
    # size because `random.sample` raises ValueError when asked for more items than exist.
    num_train_rows = len(raw_datasets["train"])
    for index in random.sample(range(num_train_rows), min(3, num_train_rows)):
        logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
        logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")
        logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")

    # "auto" and None are passed through untouched; any other value names a `torch` dtype attribute.
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_flash_attention_2=model_args.use_flash_attention_2,
        torch_dtype=torch_dtype,
        # The KV cache is switched off whenever gradient checkpointing is enabled
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map(),
        quantization_config=get_quantization_config(model_args),
    )

    ref_model = model_args.model_name_or_path
    ref_model_kwargs = model_kwargs

    # With PEFT no separate reference model (or kwargs for it) is handed to the trainer
    if model_args.use_peft:
        ref_model = None
        ref_model_kwargs = None

    #########################
    # Instantiate DPO trainer
    #########################
    dpo_trainer = DPOTrainer(
        model_args.model_name_or_path,
        ref_model,
        model_init_kwargs=model_kwargs,
        ref_model_init_kwargs=ref_model_kwargs,
        args=training_args,
        beta=training_args.beta,
        train_dataset=raw_datasets["train"],
        eval_dataset=raw_datasets["test"],
        tokenizer=tokenizer,
        max_length=training_args.max_seq_length,
        max_prompt_length=training_args.max_prompt_length,
        peft_config=get_peft_config(model_args),
    )

    ###############
    # Training loop
    ###############
    train_result = dpo_trainer.train()
    metrics = train_result.metrics
    max_train_samples = (
        data_args.max_train_samples if data_args.max_train_samples is not None else len(raw_datasets["train"])
    )
    metrics["train_samples"] = min(max_train_samples, len(raw_datasets["train"]))
    dpo_trainer.log_metrics("train", metrics)
    dpo_trainer.save_metrics("train", metrics)
    dpo_trainer.save_state()

    logger.info("*** Training complete ***")

    ##########
    # Evaluate
    ##########
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = dpo_trainer.evaluate(eval_dataset=raw_datasets["test"])
        max_eval_samples = (
            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(raw_datasets["test"])
        )
        metrics["eval_samples"] = min(max_eval_samples, len(raw_datasets["test"]))
        dpo_trainer.log_metrics("eval", metrics)
        dpo_trainer.save_metrics("eval", metrics)

    ##################################
    # Save model and create model card
    ##################################
    dpo_trainer.save_model(training_args.output_dir)

    # Save everything else on main process
    if accelerator.is_main_process:
        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
        kwargs["dataset"] = list(data_args.dataset_mixer.keys())
        dpo_trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        dpo_trainer.model.config.use_cache = True
        # Fix custom code paths
        if model_args.trust_remote_code is True:
            auto_map = dpo_trainer.model.config.auto_map
            dpo_trainer.model.config.auto_map = {k: v.split("--")[-1] for k, v in auto_map.items()}
        dpo_trainer.model.config.save_pretrained(training_args.output_dir)
        # FSDP/DeepSpeed save the model as a single `pytorch_model.bin` file, so we need to shard it.
        # We run this in a subprocess to avoid interference from the accelerators.
        # `sys.executable` keeps the child in the same interpreter/virtualenv as this process.
        subprocess.run(
            [
                sys.executable,
                "scripts/training/shard_checkpoint.py",
                f"--output_dir={training_args.output_dir}",
                f"--trust_remote_code={model_args.trust_remote_code}",
            ],
            check=True,
        )
        # Convert torch weights to safetensors for deployment with TGI
        convert_to_safetensors(training_args.output_dir)
        if training_args.push_to_hub_revision:
            is_model_on_hub = push_to_hub_revision(training_args, model_args)
            # Run automatic evaluation once the model is pushed to the Hub
            if is_slurm_available() and is_model_on_hub is True and training_args.do_eval is True:
                logger.info("*** Launching MT Bench ***")
                run_mt_bench_job(training_args, model_args)

    # Ensure we don't timeout on model save / push to Hub
    logger.info("*** Waiting for all processes to finish ***")
    accelerator.wait_for_everyone()
    wandb.finish()

    logger.info("*** Run complete! ***")
def main():
    """Run supervised fine-tuning: parse arguments, prepare data, train, evaluate, save.

    The steps, in order, are:

    1. Parse model/data/training arguments, seed the run and configure logging.
    2. Load the datasets and render every example to a single `text` field with the chat template.
    3. Build the model keyword arguments and hand the model name to `SFTTrainer`, together with
       those keyword arguments as `model_init_kwargs`.
    4. Train and/or evaluate depending on `do_train` / `do_eval`, then save the model, the model
       card and the config, and optionally push to the Hub.
    """
    parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
    model_args, data_args, training_args = parser.parse()

    # Set seed for reproducibility
    set_seed(training_args.seed)
    accelerator = Accelerator()

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process a small summary
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.bf16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)

    logger.info(
        f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )
    # Cap the sample size: `random.sample` raises ValueError when asked for more items than exist.
    num_preview = min(3, len(raw_datasets["train"]))
    with training_args.main_process_first(desc="Log a few random samples from the raw training set"):
        for index in random.sample(range(len(raw_datasets["train"])), num_preview):
            logger.info(f"Sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['messages']}")

    #####################################
    # Load tokenizer and process datasets
    #####################################
    tokenizer = get_tokenizer(model_args, data_args)

    #####################
    # Apply chat template
    #####################
    raw_datasets = raw_datasets.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer, "task": "sft"})
    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["test"]

    with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
        for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
            logger.info(f"Sample {index} of the processed training set:\n\n{train_dataset[index]['text']}")

    #######################
    # Load pretrained model
    #######################
    logger.info("*** Load pretrained model ***")
    # "auto" and None are passed through untouched; any other value names a `torch` dtype attribute.
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )

    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_flash_attention_2=model_args.use_flash_attention_2,
        torch_dtype=torch_dtype,
        # The KV cache is switched off whenever gradient checkpointing is enabled
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map(),
        quantization_config=get_quantization_config(model_args),
    )

    ########################
    # Initialize the Trainer
    ########################
    trainer = SFTTrainer(
        model=model_args.model_name_or_path,
        model_init_kwargs=model_kwargs,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        dataset_text_field="text",
        max_seq_length=training_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        peft_config=get_peft_config(model_args),
    )
    # Only keyword arguments exist before this point: the trainer receives a model *name*,
    # so this message is emitted after the trainer has been constructed.
    logger.info("*** Model loaded! ***")

    ###############
    # Training loop
    ###############
    if training_args.do_train:
        logger.info("*** Train ***")
        train_result = trainer.train()
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    ##########
    # Evaluate
    ##########
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        # exp() of a large loss overflows a float; report infinity instead of crashing
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on main process
    if accelerator.is_main_process:
        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
        kwargs["dataset"] = list(data_args.dataset_mixer.keys())
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

        if training_args.push_to_hub:
            trainer.push_to_hub()

    accelerator.wait_for_everyone()
class H4ArgumentParser(HfArgumentParser):
    """`HfArgumentParser` that can combine a YAML config file with `--key=value` command line overrides."""

    @staticmethod
    def _unwrap_optional(field_type):
        """Return `X` for `Optional[X]`; return any other annotation unchanged."""
        from typing import get_args, get_origin

        if get_origin(field_type) is Union:
            non_none = [t for t in get_args(field_type) if t is not type(None)]
            if len(non_none) == 1:
                return non_none[0]
        return field_type

    def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
        """
        Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.

        Args:
            yaml_arg (`str`):
                The path to the config file used
            other_args (`List[str]`, *optional*):
                A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].

        Returns:
            [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line

        Raises:
            ValueError: If an override is not of the form `--key=value`, or if the same override
                name matches a field in more than one dataclass.
        """
        arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))

        outputs = []
        # Turn the override list into a dict of key-value pairs. `partition` splits on the first
        # "=" only, so values that themselves contain "=" are kept whole.
        overrides = {}
        for arg in other_args or []:
            key, sep, val = arg.partition("=")
            if not sep:
                raise ValueError(f"Expected a command line override of the form --key=value, got: {arg}")
            overrides[key.lstrip("-")] = val
        used_args = {}

        # overwrite the default/loaded value with the value provided to the command line
        # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
        for data_yaml, data_class in zip(arg_list, self.dataclass_types):
            keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
            inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
            for arg, val in overrides.items():
                # add only if in keys
                if arg not in keys:
                    continue

                # `Optional[int]` etc. are cast like their inner type; without unwrapping they
                # would silently stay strings.
                base_type = self._unwrap_optional(data_yaml.__dataclass_fields__[arg].type)
                inputs[arg] = val

                # cast type for ints, floats (default to strings)
                if base_type in [int, float]:
                    inputs[arg] = base_type(val)

                if base_type == List[str]:
                    inputs[arg] = [str(v) for v in val.split(",")]

                # bool of a non-empty string is True, so we manually check for bools
                if base_type == bool:
                    inputs[arg] = val in ["true", "True"]

                # add to used-args so we can check if double add
                if arg not in used_args:
                    used_args[arg] = val
                else:
                    raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")

            obj = data_class(**inputs)
            outputs.append(obj)

        return outputs

    def parse(self) -> Union[DataClassType, Tuple[DataClassType]]:
        """
        Parse `sys.argv` into dataclasses.

        Three invocation styles are recognised: a single `.yaml` path, a `.yaml` path followed by
        `--key=value` overrides, or plain command line arguments.

        Returns:
            The single dataclass when only one was parsed, otherwise the collection of parsed dataclasses.
        """
        if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
            # If we pass only one argument to the script and it's the path to a YAML file,
            # let's parse it to get our arguments.
            output = self.parse_yaml_file(os.path.abspath(sys.argv[1]))
        # parse command line args and yaml file
        elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
            output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:])
        # parse command line args only
        else:
            output = self.parse_args_into_dataclasses()

        if len(output) == 1:
            output = output[0]
        return output
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.

    Raises:
        ValueError: On construction, if both `load_in_8bit` and `load_in_4bit` are set.
    """

    base_model_revision: Optional[str] = field(
        default=None,
        metadata={"help": ("The base model checkpoint for weights initialization with PEFT adapters.")},
    )
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # The default is None, so the annotation has to be Optional
    model_code_revision: Optional[str] = field(default=None, metadata={"help": "The branch of the IFT model"})
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
    use_flash_attention_2: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`"
            )
        },
    )
    use_peft: bool = field(
        default=False,
        metadata={"help": ("Whether to use PEFT or not for training.")},
    )
    lora_r: Optional[int] = field(
        default=16,
        metadata={"help": ("LoRA R value.")},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": ("LoRA alpha.")},
    )
    lora_dropout: Optional[float] = field(
        default=0.05,
        metadata={"help": ("LoRA dropout.")},
    )
    lora_target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": ("LoRA target modules.")},
    )
    lora_modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": ("Model layers to unfreeze & train")},
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"})
    load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"})

    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
    )
    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})

    def __post_init__(self):
        # 8-bit and 4-bit loading are mutually exclusive
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
@dataclass
class DataArguments:
    """
    Options describing the data used for training and evaluation: the chat template, which
    datasets to mix and in what proportion, which splits to load, optional sample caps for
    debugging, and tokenizer/preprocessing knobs.
    """

    # Chat template handling
    chat_template: Optional[str] = field(
        default=None,
        metadata={"help": "The chat template to use."},
    )

    # Dataset selection
    dataset_mixer: Optional[Dict[str, float]] = field(
        default=None,
        metadata={"help": "Datasets and their proportions to be used for training ift/rl."},
    )
    dataset_splits: Optional[List[str]] = field(
        # A factory is required so every instance gets its own list
        default_factory=lambda: ["train", "test"],
        metadata={"help": "List of train test splits to use in the dataset"},
    )

    # Optional caps on the number of examples
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
        },
    )

    # Preprocessing / tokenizer settings
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    truncation_side: Optional[str] = field(
        default=None,
        metadata={"help": "Truncation side to use for the tokenizer."},
    )
@dataclass
class SFTConfig(transformers.TrainingArguments):
    """
    Arguments related to the SFT training process itself. For all inherited parameters, see:
    https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
    """

    # The previous help text described reward-model training; in the SFT script this value is
    # what gets passed to `SFTTrainer(max_seq_length=...)`.
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={"help": ("The maximum sequence length, forwarded to TRL's `SFTTrainer` as `max_seq_length`.")},
    )
    logging_first_step: bool = field(
        default=True,
        metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
    )
    optim: Optional[str] = field(default="adamw_torch")
@dataclass
class DPOConfig(transformers.TrainingArguments):
    """
    Training arguments for DPO runs, extending `transformers.TrainingArguments` with the DPO
    `beta`, the Hub revision to push to and length limits, and overriding the defaults of
    `logging_first_step`, `optim` and `remove_unused_columns`. For the inherited parameters, see:
    https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
    """

    # DPO loss
    beta: Optional[float] = field(
        default=0.1,
        metadata={"help": "The beta factor in DPO loss. Higher beta means less divergence from the initial policy."},
    )

    # Hub publishing
    hub_model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The Hub model branch to push the model to."},
    )

    # Logging
    logging_first_step: bool = field(
        default=True,
        metadata={"help": "Whether to log and evaluate the first global_step or not."},
    )

    # Sequence length limits
    max_prompt_length: Optional[int] = field(
        default=None,
        metadata={"help": "For DPO, the maximum length of the prompt to use for conditioning the model."},
    )
    # NOTE(review): this help text talks about reward modeling and looks copied — confirm the wording.
    max_length: Optional[int] = field(
        default=None,
        metadata={"help": "Used by TRL for reward model training, which tries to read this parameter in init."},
    )

    # Overridden defaults
    optim: Optional[str] = field(default="rmsprop")
    remove_unused_columns: bool = field(default=False)
def apply_chat_template(
    example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n"
):
    """
    Render the conversation(s) in `example` to text with the tokenizer's chat template.

    Args:
        example:
            A mapping holding `messages` (tasks `sft` / `generation`) or `chosen` and `rejected`
            (tasks `rm` / `dpo`). Each is a list of `{"role": ..., "content": ...}` dicts.
        tokenizer:
            Object exposing `apply_chat_template(messages, tokenize=..., add_generation_prompt=...)`.
        task:
            Which formatting to apply. Any other value returns `example` unchanged.
        assistant_prefix:
            Prefix stripped from the start of the `dpo` completions.

    Returns:
        The same `example` object with new keys: `text` for `sft` / `generation`; `text_chosen`
        and `text_rejected` for `rm`; those two plus `text_prompt` for `dpo`.

    Raises:
        ValueError: For `rm` / `dpo` when `chosen` or `rejected` is missing from `example`.

    Note:
        For `sft`, `generation` and `rm`, a missing leading system message is inserted into the
        message lists of `example` in place.
    """

    def _strip_prefix(s, pattern):
        # Use re.escape to escape any special characters in the pattern
        return re.sub(f"^{re.escape(pattern)}", "", s)

    def _ensure_system_message(messages):
        # We add an empty system message if there is none (modifies the list in place)
        if messages[0]["role"] != "system":
            messages.insert(0, {"role": "system", "content": ""})

    def _completion_turns(messages):
        # Everything after the prompt. The prompt is the first message, or the first two when
        # the conversation opens with a system message. Always dropping exactly one message
        # left the user prompt inside the completion in the system-message case.
        first_reply = 2 if messages[0]["role"] == "system" else 1
        return messages[first_reply:]

    if task in ["sft", "generation"]:
        messages = example["messages"]
        _ensure_system_message(messages)
        example["text"] = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True if task == "generation" else False
        )
    elif task == "rm":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            chosen_messages = example["chosen"]
            rejected_messages = example["rejected"]
            _ensure_system_message(chosen_messages)
            _ensure_system_message(rejected_messages)
            example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    elif task == "dpo":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            # Compared to reward modeling, we filter out the prompt, so the text is everything after the last assistant token
            prompt_messages = [[msg for msg in example["chosen"] if msg["role"] == "user"][0]]
            # Insert system message: reuse the conversation's own one when present, else an empty one
            if example["chosen"][0]["role"] != "system":
                prompt_messages.insert(0, {"role": "system", "content": ""})
            else:
                prompt_messages.insert(0, example["chosen"][0])
            chosen_messages = _completion_turns(example["chosen"])
            rejected_messages = _completion_turns(example["rejected"])
            example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
            example["text_prompt"] = tokenizer.apply_chat_template(
                prompt_messages, tokenize=False, add_generation_prompt=True
            )

            # The prompt already ends with the generation prompt, so drop the assistant prefix
            # from the start of each completion.
            example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
            example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    return example
def get_datasets(
    data_config: Union[DataArguments, dict],
    splits: List[str] = ["train", "test"],
    shuffle: bool = True,
) -> DatasetDict:
    """
    Loads one or more datasets with varying training set proportions.

    Args:
        data_config (`DataArguments` or `dict`):
            Dataset configuration and split proportions.
        splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training data.

    Returns:
        [`DatasetDict`]: The dataset dictionary containing the loaded datasets.

    Raises:
        ValueError: If `data_config` is neither a `DataArguments` nor a `dict`.
    """

    # `isinstance` (rather than an exact type comparison) also accepts subclasses
    if isinstance(data_config, DataArguments):
        # Structure of the config to read the datasets and their mix
        # dataset_mixer:
        #     - 'dataset1': 0.5
        #     - 'dataset2': 0.3
        #     - 'dataset3': 0.2
        dataset_mixer = data_config.dataset_mixer
    elif isinstance(data_config, dict):
        # Structure of the input is:
        #     dataset_mixer = {
        #             "dataset1": 0.5,
        #             "dataset2": 0.3,
        #             "dataset3": 0.2,
        #         }
        dataset_mixer = data_config
    else:
        raise ValueError(f"Data config {data_config} not recognized.")

    raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
    return raw_datasets
+ + Returns + [`DatasetDict`]: The dataset dictionary containing the loaded datasets. + """ + + if type(data_config) is DataArguments: + # Structure of the config to read the datasets and their mix + # datasets_mixer: + # - 'dataset1': 0.5 + # - 'dataset2': 0.3 + # - 'dataset3': 0.2 + dataset_mixer = data_config.dataset_mixer + elif type(data_config) is dict: + # Structure of the input is: + # dataset_mixer = { + # "dataset1": 0.5, + # "dataset1": 0.3, + # "dataset1": 0.2, + # } + dataset_mixer = data_config + else: + raise ValueError(f"Data config {data_config} not recognized.") + + raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle) + return raw_datasets + + +def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True) -> DatasetDict: + """ + Loads and mixes datasets according to proportions specified in `dataset_mixer`. + + Args: + dataset_mixer (`dict`): + Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1. + splits (Optional[List[str]], *optional*, defaults to `None`): + Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. + shuffle (`bool`, *optional*, defaults to `True`): + Whether to shuffle the training data. 
+ """ + raw_datasets = DatasetDict() + raw_train_datasets = [] + raw_val_datasets = [] + fracs = [] + for ds, frac in dataset_mixer.items(): + fracs.append(frac) + for split in splits: + if "train" in split: + raw_train_datasets.append( + load_dataset( + ds, + split=split, + ) + ) + elif "test" in split: + raw_val_datasets.append( + load_dataset( + ds, + split=split, + ) + ) + else: + raise ValueError(f"Split type {split} not recognized as one of test or train.") + + if any(frac < 0 for frac in fracs): + raise ValueError("Dataset fractions cannot be negative.") + + if len(raw_train_datasets) > 0: + train_subsets = [] + for dataset, frac in zip(raw_train_datasets, fracs): + train_subset = dataset.select(range(int(frac * len(dataset)))) + train_subsets.append(train_subset) + if shuffle: + raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42) + else: + raw_datasets["train"] = concatenate_datasets(train_subsets) + # No subsampling for test datasets to enable fair comparison across models + if len(raw_val_datasets) > 0: + if shuffle: + raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42) + else: + raw_datasets["test"] = concatenate_datasets(raw_val_datasets) + + if len(raw_datasets) == 0: + raise ValueError( + f"Dataset {dataset_mixer} not recognized with split {split}. Check the dataset has been correctly formatted." + ) + + return raw_datasets diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py new file mode 100644 index 0000000..a01f1b4 --- /dev/null +++ b/src/alignment/model_utils.py @@ -0,0 +1,79 @@ +from typing import Dict, Union + +import torch +from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer + +from accelerate import Accelerator +from peft import LoraConfig, PeftConfig + +from .configs import DataArguments, ModelArguments +from .data import DEFAULT_CHAT_TEMPLATE + + +def get_current_device() -> int: + """Get the current device. 
def get_current_device() -> int | str:
    """Get the current device.

    For GPU we return the local process index to enable multiple GPU training. When CUDA is not available the
    string `"cpu"` is returned instead.
    """
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"


def get_kbit_device_map() -> Dict[str, int] | None:
    """Useful for running inference with quantized models by setting `device_map=get_kbit_device_map()`.

    Returns a map that places the whole model (the `""` key) on the current device, or `None` when CUDA is not
    available.
    """
    return {"": get_current_device()} if torch.cuda.is_available() else None


def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | None:
    """Build the bitsandbytes quantization config requested by `model_args`.

    `load_in_4bit` is checked first and therefore takes precedence over `load_in_8bit`. Returns `None` when
    neither flag is set.
    """
    if model_args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,  # For consistency with model weights, we use the same value as `torch_dtype` which is float16 for PEFT models
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )
    elif model_args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
    else:
        quantization_config = None

    return quantization_config


def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
    """Get the tokenizer for the model.

    The tokenizer is loaded from `model_args.model_name_or_path` at `model_args.model_revision` and then
    adjusted: a missing pad token falls back to the EOS token, `data_args.truncation_side` is applied when set,
    and an oversized `model_max_length` is capped. The chat template comes from `data_args.chat_template` if
    given, otherwise the tokenizer's own template, otherwise `DEFAULT_CHAT_TEMPLATE`.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    if data_args.truncation_side is not None:
        tokenizer.truncation_side = data_args.truncation_side

    # Set reasonable default for models without max length
    if tokenizer.model_max_length > 100_000:
        tokenizer.model_max_length = 2048

    if data_args.chat_template is not None:
        tokenizer.chat_template = data_args.chat_template
    elif tokenizer.chat_template is None:
        tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE

    return tokenizer


def get_peft_config(model_args: ModelArguments) -> Union[PeftConfig, None]:
    """Build the LoRA config described by `model_args`, or return `None` when `use_peft` is `False`."""
    if model_args.use_peft is False:
        return None

    peft_config = LoraConfig(
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=model_args.lora_target_modules,
        modules_to_save=model_args.lora_modules_to_save,
    )

    return peft_config