From 1aa921d74898d772ce80384a3eaa0dd3d7c82643 Mon Sep 17 00:00:00 2001 From: Yu Meng Date: Sun, 4 Aug 2024 23:12:21 -0400 Subject: [PATCH] fix alignment-handbook version --- alignment/__init__.py | 31 +++++ alignment/configs.py | 271 +++++++++++++++++++++++++++++++++++++ alignment/data.py | 256 +++++++++++++++++++++++++++++++++++ alignment/decontaminate.py | 91 +++++++++++++ alignment/model_utils.py | 128 ++++++++++++++++++ alignment/release.py | 125 +++++++++++++++++ scripts/run_simpo.py | 2 +- scripts/simpo_config.py | 1 - 8 files changed, 903 insertions(+), 2 deletions(-) create mode 100644 alignment/__init__.py create mode 100644 alignment/configs.py create mode 100644 alignment/data.py create mode 100644 alignment/decontaminate.py create mode 100644 alignment/model_utils.py create mode 100644 alignment/release.py diff --git a/alignment/__init__.py b/alignment/__init__.py new file mode 100644 index 0000000..f4fa850 --- /dev/null +++ b/alignment/__init__.py @@ -0,0 +1,31 @@ +__version__ = "0.3.0.dev0" + +from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig +from .data import apply_chat_template, get_datasets +# from .decontaminate import decontaminate_humaneval +from .model_utils import ( + get_checkpoint, + get_kbit_device_map, + get_peft_config, + get_quantization_config, + get_tokenizer, + is_adapter_model, +) + + +__all__ = [ + "DataArguments", + "DPOConfig", + "H4ArgumentParser", + "ModelArguments", + "SFTConfig", + "apply_chat_template", + "get_datasets", + "decontaminate_humaneval", + "get_checkpoint", + "get_kbit_device_map", + "get_peft_config", + "get_quantization_config", + "get_tokenizer", + "is_adapter_model", +] diff --git a/alignment/configs.py b/alignment/configs.py new file mode 100644 index 0000000..aff0792 --- /dev/null +++ b/alignment/configs.py @@ -0,0 +1,271 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses +import os +import sys +from dataclasses import dataclass, field +from typing import Any, Dict, List, NewType, Optional, Tuple + +from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser + +import trl + + +MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +DataClassType = NewType("DataClassType", Any) + + +class H4ArgumentParser(HfArgumentParser): + def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]: + """ + Parse a YAML file and overwrite the default/loaded values with the values provided to the command line. + + Args: + yaml_arg (`str`): + The path to the config file used + other_args (`List[str]`, *optional`): + A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2']. + + Returns: + [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line + """ + arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg)) + + outputs = [] + # strip other args list into dict of key-value pairs + other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args} + used_args = {} + + # overwrite the default/loaded value with the value provided to the command line + # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327 + for data_yaml, data_class in zip(arg_list, self.dataclass_types): + keys = {f.name for f in dataclasses.fields(data_yaml) if f.init} + inputs = {k: v for k, v in vars(data_yaml).items() if k in keys} + for arg, val in other_args.items(): + # add only if in keys + + if arg in keys: + base_type = data_yaml.__dataclass_fields__[arg].type + inputs[arg] = val + + # cast type for ints, floats (default to strings) + if base_type in [int, float]: + inputs[arg] = base_type(val) + + if base_type == List[str]: + inputs[arg] = [str(v) for v in val.split(",")] + + # bool of a non-empty string is True, so we manually check for bools + if base_type is bool: + if val in ["true", "True"]: + inputs[arg] = True + else: + inputs[arg] = False + + # add to used-args so we can check if double add + if arg not in used_args: + used_args[arg] = val + else: + raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior") + + obj = data_class(**inputs) + outputs.append(obj) + + return outputs + + def parse(self) -> DataClassType | Tuple[DataClassType]: + if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + # If we pass only one argument to the script and it's the path to a YAML file, + # let's parse it to get our arguments. + output = self.parse_yaml_file(os.path.abspath(sys.argv[1])) + # parse command line args and yaml file + elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"): + output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:]) + # parse command line args only + else: + output = self.parse_args_into_dataclasses() + + if len(output) == 1: + output = output[0] + return output + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune. + """ + + base_model_revision: Optional[str] = field( + default=None, + metadata={"help": ("The base model checkpoint for weights initialization with PEFT adapters.")}, + ) + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." + ) + }, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + model_code_revision: str = field(default=None, metadata={"help": "The branch of the IFT model"}) + torch_dtype: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " + "dtype will be automatically derived from the model's weights." + ), + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The path to the tokenizer. Useful if you want to use a different tokenizer to the one stored in `model_name_or_path`." + ) + }, + ) + trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) + attn_implementation: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Which attention implementation to use; you can use --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`" + ) + }, + ) + use_peft: bool = field( + default=False, + metadata={"help": ("Whether to use PEFT or not for training.")}, + ) + lora_r: Optional[int] = field( + default=16, + metadata={"help": ("LoRA R value.")}, + ) + lora_alpha: Optional[int] = field( + default=32, + metadata={"help": ("LoRA alpha.")}, + ) + lora_dropout: Optional[float] = field( + default=0.05, + metadata={"help": ("LoRA dropout.")}, + ) + lora_target_modules: Optional[List[str]] = field( + default=None, + metadata={"help": ("LoRA target modules.")}, + ) + lora_modules_to_save: Optional[List[str]] = field( + default=None, + metadata={"help": ("Model layers to unfreeze & train")}, + ) + load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"}) + load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"}) + + bnb_4bit_quant_type: Optional[str] = field( + default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"} + ) + use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) + bnb_4bit_quant_storage: Optional[str] = field( + default="uint8", + metadata={"help": "storage type to pack the quanitzed 4-bit prarams."}, + ) + + def __post_init__(self): + if self.load_in_8bit and self.load_in_4bit: + raise ValueError("You can't use 8 bit and 4 bit precision at the same time") + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."}) + dataset_mixer: Optional[Dict[str, float]] = field( + default=None, + metadata={"help": ("Datasets and their proportions to be used for training ift/rl.")}, + ) + text_column: Optional[str] = field( + default="text", + metadata={"help": "The column name to use for the text in the dataset (only used for continued pretraining)."}, + ) + dataset_splits: Optional[List[str]] = field( + default_factory=lambda: ["train", "test"], + metadata={"help": ("List of train test splits to use in the dataset")}, + ) + dataset_configs: Optional[List[str]] = field( + default=None, + metadata={"help": "List of dataset config names. If given must be the same length as 'dataset_mixer' keys."}, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + truncation_side: Optional[str] = field( + default=None, metadata={"help": "Truncation side to use for the tokenizer."} + ) + auto_insert_empty_system_msg: bool = field( + default=True, + metadata={ + "help": ( + "Whether to automatically insert an empty system message as the first message if `system` is mentioned in the chat template." + ) + }, + ) + + +@dataclass +class SFTConfig(trl.SFTConfig): + """ + Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments + Also used for the continued pretraining task. + """ + + hub_model_revision: Optional[str] = field( + default="main", + metadata={"help": ("The Hub model branch to push the model to.")}, + ) + logging_first_step: bool = field( + default=True, + metadata={"help": ("Whether to log and evaluate the first global_step or not.")}, + ) + + +@dataclass +class DPOConfig(trl.DPOConfig): + """ + Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments + """ + + hub_model_revision: Optional[str] = field( + default="main", + metadata={"help": ("The Hub model branch to push the model to.")}, + ) + logging_first_step: bool = field( + default=True, + metadata={"help": ("Whether to log and evaluate the first global_step or not.")}, + ) + optim: Optional[str] = field(default="rmsprop") + remove_unused_columns: bool = field(default=False) diff --git a/alignment/data.py b/alignment/data.py new file mode 100644 index 0000000..84544c6 --- /dev/null +++ b/alignment/data.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any, List, Literal, Optional + +from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk +from datasets.builder import DatasetGenerationError + +from .configs import DataArguments + + +DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" + + +def maybe_insert_system_message(messages, tokenizer): + if messages[0]["role"] == "system": + return + + # chat template can be one of two attributes, we check in order + chat_template = tokenizer.chat_template + if chat_template is None: + chat_template = tokenizer.default_chat_template + + # confirm the jinja template refers to a system message before inserting + if "system" in chat_template or "<|im_start|>" in chat_template: + messages.insert(0, {"role": "system", "content": ""}) + + +def apply_chat_template( + example, + tokenizer, + task: Literal["sft", "generation", "rm", "dpo"], + auto_insert_empty_system_msg: bool = True, +): + if task in ["sft", "generation"]: + messages = example["messages"] + # We add an empty system message if there is none + if auto_insert_empty_system_msg: + maybe_insert_system_message(messages, tokenizer) + example["text"] = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True if task == "generation" else False, + ) + elif task == "rm": + if all(k in example.keys() for k in ("chosen", "rejected")): + chosen_messages = example["chosen"] + rejected_messages = example["rejected"] + # We add an empty system message if there is none + if auto_insert_empty_system_msg: + maybe_insert_system_message(chosen_messages, tokenizer) + maybe_insert_system_message(rejected_messages, tokenizer) + + example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) + example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) + else: + raise ValueError( + f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" + ) + elif task in ["dpo", "orpo"]: + if all(k in example.keys() for k in ("chosen", "rejected")): + if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]): + raise ValueError( + f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages" + ) + + # For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue + # We therefore need to extract the N-1 turns to form the prompt + if "prompt" in example and is_openai_format(example["prompt"]): + prompt_messages = example["prompt"] + chosen_messages = example["chosen"] + rejected_messages = example["rejected"] + else: + prompt_messages = example["chosen"][:-1] + # Now we extract the final turn to define chosen/rejected responses + chosen_messages = example["chosen"][-1:] + rejected_messages = example["rejected"][-1:] + + # Prepend a system message if the first message is not a system message + if auto_insert_empty_system_msg: + maybe_insert_system_message(prompt_messages, tokenizer) + + example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False) + example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) + example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) + else: + raise ValueError( + f"Could not format example as dialogue for `{task}` task! Require either the " + f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}" + ) + else: + raise ValueError( + f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']" + ) + return example + + +def is_openai_format(messages: Any) -> bool: + """ + Check if the input messages are in OpenAI format. + Args: + messages (`Any`): + Messages to check. + Returns: + `bool`: Whether the messages are in OpenAI format. + """ + if isinstance(messages, list) and all(isinstance(message, dict) for message in messages): + return all("role" in message and "content" in message for message in messages) + return False + + +def get_datasets( + data_config: DataArguments | dict, + splits: Optional[List[str]] = None, + configs: Optional[List[str]] = None, + columns_to_keep: Optional[List[str]] = None, + shuffle: bool = True, +) -> DatasetDict: + """ + Loads one or more datasets with varying training set proportions. + + Args: + data_config (`DataArguments` or `dict`): + Dataset configuration and split proportions. + splits (`List[str]`, *optional*, defaults to `['train', 'test']`): + Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. + configs (Optional[List[str]], *optional*, defaults to `None`): + List of dataset config names. If given must be the same length as 'data_config' keys. + columns_to_keep (Optional[List[str]], *optional*, defaults to `None`): + Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts, + and for cpt this should be (at least) the text column. + shuffle (`bool`, *optional*, defaults to `True`): + Whether to shuffle the training and testing/validation data. + + Returns + [`DatasetDict`]: The dataset dictionary containing the loaded datasets. + """ + if type(data_config) is DataArguments: + # Structure of the config to read the datasets and their mix + # datasets_mixer: + # - 'dataset1': 0.5 + # - 'dataset2': 0.3 + # - 'dataset3': 0.2 + dataset_mixer = data_config.dataset_mixer + elif isinstance(data_config, dict): + # Structure of the input is: + # dataset_mixer = { + # "dataset1": 0.5, + # "dataset1": 0.3, + # "dataset1": 0.2, + # } + dataset_mixer = data_config + else: + raise ValueError(f"Data config {data_config} not recognized.") + + raw_datasets = mix_datasets( + dataset_mixer, + splits=splits, + configs=configs, + columns_to_keep=columns_to_keep, + shuffle=shuffle, + ) + return raw_datasets + + +def mix_datasets( + dataset_mixer: dict, + splits: Optional[List[str]] = None, + configs: Optional[List[str]] = None, + columns_to_keep: Optional[List[str]] = None, + shuffle=True, +) -> DatasetDict: + """ + Loads and mixes datasets according to proportions specified in `dataset_mixer`. + + Args: + dataset_mixer (`dict`): + Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1. + splits (Optional[List[str]], *optional*, defaults to `None`): + Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. + configs (Optional[List[str]], *optional*, defaults to `None`): + List of dataset config names. If given must be the same length as 'dataset_mixer' keys. + columns_to_keep (Optional[List[str]], *optional*, defaults to `None`): + Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts, + and for cpt this should be (at least) the text column. + shuffle (`bool`, *optional*, defaults to `True`): + Whether to shuffle the training and testing/validation data. + """ + splits = ["train", "test"] if splits is None else splits + configs = [None] * len(dataset_mixer) if not configs else configs + columns_to_keep = [] if columns_to_keep is None else columns_to_keep + + if configs is not None and len(configs) != len(dataset_mixer): + raise ValueError("The number of given dataset config names must be the same as the given number of datasets.") + + raw_datasets = DatasetDict() + raw_train_datasets = [] + raw_val_datasets = [] + fracs = [] + for (ds, frac), ds_config in zip(dataset_mixer.items(), configs): + fracs.append(frac) + for split in splits: + try: + # Try first if dataset on a Hub repo + dataset = load_dataset(ds, ds_config, split=split) + except DatasetGenerationError: + # If not, check local dataset + dataset = load_from_disk(os.path.join(ds, split)) + + # Remove redundant columns to avoid schema conflicts on load + dataset = dataset.remove_columns([col for col in dataset.column_names if col not in columns_to_keep]) + if "train" in split: + raw_train_datasets.append(dataset) + elif "test" in split: + raw_val_datasets.append(dataset) + else: + raise ValueError(f"Split type {split} not recognized as one of test or train.") + + if any(frac < 0 for frac in fracs): + raise ValueError("Dataset fractions cannot be negative.") + + if len(raw_train_datasets) > 0: + train_subsets = [] + for dataset, frac in zip(raw_train_datasets, fracs): + train_subset = dataset.select(range(int(frac * len(dataset)))) + train_subsets.append(train_subset) + if shuffle: + raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42) + else: + raw_datasets["train"] = concatenate_datasets(train_subsets) + # No subsampling for test datasets to enable fair comparison across models + if len(raw_val_datasets) > 0: + if shuffle: + raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42) + else: + raw_datasets["test"] = concatenate_datasets(raw_val_datasets) + + if len(raw_datasets) == 0: + raise ValueError( + f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted." + ) + + return raw_datasets diff --git a/alignment/decontaminate.py b/alignment/decontaminate.py new file mode 100644 index 0000000..45cba95 --- /dev/null +++ b/alignment/decontaminate.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List + +from datasets import load_dataset + + +# HumanEval solutions that are considered simple/generic enough to be kept in the training dataset +HUMAN_EVAL_STRINGS_OK = ["return x + y", "return len(string)", "return n**2", "return " ".join(strings)"] + + +def extract_docstring(prompt: str) -> str: + if '"""' in prompt: + if prompt.count('"""') == 2: + return prompt.split('"""')[1].strip() + elif prompt.count('"""') == 4: + return prompt.split('"""')[3].strip() + else: + raise ValueError() + elif "'''" in prompt: + assert prompt.count("'''") == 2 + return prompt.split("'''")[1].strip() + else: + raise ValueError() + + +def human_eval_docstrings() -> List[str]: + ds = load_dataset("openai_humaneval", split="test") + docstrings = [extract_docstring(v["prompt"]) for v in ds] + return docstrings + + +def load_dataset_column(dataset: str, column: str, split: str, name=None) -> List[str]: + ds = load_dataset(dataset, split=split, name=name) + res = [sample[column].strip() for sample in ds] + # Only return non-empty strings + return [sample for sample in res if len(sample) > 0] + + +FILTER_OUT = { + "human_eval_docstrings": human_eval_docstrings(), + "human_eval_solutions": [ + s + for s in load_dataset_column("openai_humaneval", "canonical_solution", "test") + if s not in HUMAN_EVAL_STRINGS_OK + ], +} + + +def normalize_whitespace(text: str) -> str: + return " ".join(text.split()) + + +def decontaminate_humaneval( + samples: List[Dict[str, Any]], text_column: str = "text", filter_out: Dict[str, List[str]] = FILTER_OUT +) -> List[Dict[str, Any]]: + """ + filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be + filtered-out. + Return a list where each element is True if the corresponding file should be included in the dataset. + Otherwise, the element is False. + """ + output = [] + + for content in samples[text_column]: + content = normalize_whitespace(content.lower()) + matched = False + for _, substrings in filter_out.items(): + for substring in substrings: + if normalize_whitespace(substring.lower()) in content: + matched = True + break + if matched: + break + # we keep files that are not matched + output.append(not matched) + + return output diff --git a/alignment/model_utils.py b/alignment/model_utils.py new file mode 100644 index 0000000..122da3c --- /dev/null +++ b/alignment/model_utils.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import Path +from typing import Dict + +import torch +from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer +from transformers.trainer_utils import get_last_checkpoint + +from accelerate import Accelerator +from huggingface_hub import list_repo_files +from huggingface_hub.utils._errors import RepositoryNotFoundError +from huggingface_hub.utils._validators import HFValidationError +from peft import LoraConfig, PeftConfig + +from .configs import DataArguments, DPOConfig, ModelArguments, SFTConfig +from .data import DEFAULT_CHAT_TEMPLATE + + +def get_current_device() -> int: + """Get the current device. For GPU we return the local process index to enable multiple GPU training.""" + return Accelerator().local_process_index if torch.cuda.is_available() else "cpu" + + +def get_kbit_device_map() -> Dict[str, int] | None: + """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`""" + return {"": get_current_device()} if torch.cuda.is_available() else None + + +def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | None: + if model_args.load_in_4bit: + compute_dtype = torch.float16 + if model_args.torch_dtype not in {"auto", None}: + compute_dtype = getattr(torch, model_args.torch_dtype) + + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, + bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, + bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage, + ) + elif model_args.load_in_8bit: + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + ) + else: + quantization_config = None + + return quantization_config + + +def get_tokenizer( + model_args: ModelArguments, data_args: DataArguments, auto_set_chat_template: bool = True +) -> PreTrainedTokenizer: + """Get the tokenizer for the model.""" + tokenizer = AutoTokenizer.from_pretrained( + ( + model_args.model_name_or_path + if model_args.tokenizer_name_or_path is None + else model_args.tokenizer_name_or_path + ), + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + ) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + if data_args.truncation_side is not None: + tokenizer.truncation_side = data_args.truncation_side + + # Set reasonable default for models without max length + if tokenizer.model_max_length > 100_000: + tokenizer.model_max_length = 2048 + + if data_args.chat_template is not None: + tokenizer.chat_template = data_args.chat_template + elif auto_set_chat_template and tokenizer.chat_template is None and tokenizer.default_chat_template is None: + tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE + + return tokenizer + + +def get_peft_config(model_args: ModelArguments) -> PeftConfig | None: + if model_args.use_peft is False: + return None + + peft_config = LoraConfig( + r=model_args.lora_r, + lora_alpha=model_args.lora_alpha, + lora_dropout=model_args.lora_dropout, + bias="none", + task_type="CAUSAL_LM", + target_modules=model_args.lora_target_modules, + modules_to_save=model_args.lora_modules_to_save, + ) + + return peft_config + + +def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool: + try: + # Try first if model on a Hub repo + repo_files = list_repo_files(model_name_or_path, revision=revision) + except (HFValidationError, RepositoryNotFoundError): + # If not, check local repo + repo_files = os.listdir(model_name_or_path) + return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files + + +def get_checkpoint(training_args: SFTConfig | DPOConfig) -> Path | None: + last_checkpoint = None + if os.path.isdir(training_args.output_dir): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + return last_checkpoint diff --git a/alignment/release.py b/alignment/release.py new file mode 100644 index 0000000..5a9f660 --- /dev/null +++ b/alignment/release.py @@ -0,0 +1,125 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import re + +import packaging.version + + +REPLACE_PATTERNS = { + "init": ( + re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), + '__version__ = "VERSION"\n', + ), + "setup": ( + re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), + r'\1version="VERSION",', + ), + "citation": (re.compile(r"^version:\s+[^ ]+", re.MULTILINE), "version: VERSION"), + "readme": ( + re.compile(r"version\s+=\s+\{[^}]+\}", re.MULTILINE), + "version = {VERSION}", + ), +} + +README_FILE = "README.md" + +REPLACE_FILES = { + "init": "src/alignment/__init__.py", + "setup": "setup.py", + "citation": "CITATION.cff", + "readme": README_FILE, +} + + +def update_version_in_file(fname, version, pattern): + """Update the version in one file using a specific pattern.""" + with open(fname, "r", encoding="utf-8", newline="\n") as f: + code = f.read() + re_pattern, replace = REPLACE_PATTERNS[pattern] + replace = replace.replace("VERSION", version) + code = re_pattern.sub(replace, code) + with open(fname, "w", encoding="utf-8", newline="\n") as f: + f.write(code) + + +def global_version_update(version, patch=False): + """Update the version in all needed files.""" + for pattern, fname in REPLACE_FILES.items(): + update_version_in_file(fname, version, pattern) + + +def get_version(): + """Reads the current version in the __init__.""" + with open(REPLACE_FILES["init"], "r") as f: + code = f.read() + default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0] + return packaging.version.parse(default_version) + + +def pre_release_work(patch=False): + """Do all the necessary pre-release steps.""" + # First let's get the default version: base version if we are in dev, bump minor otherwise. + default_version = get_version() + if patch and default_version.is_devrelease: + raise ValueError("Can't create a patch version from the dev branch, checkout a released version!") + if default_version.is_devrelease: + default_version = default_version.base_version + elif patch: + default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}" + else: + default_version = f"{default_version.major}.{default_version.minor + 1}.0" + + # Now let's ask nicely if that's the right one. + version = input(f"Which version are you releasing? [{default_version}]") + if len(version) == 0: + version = default_version + + print(f"Updating version to {version}.") + global_version_update(version, patch=patch) + + +def post_release_work(): + """Do all the necessary post-release steps.""" + # First let's get the current version + current_version = get_version() + dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0" + current_version = current_version.base_version + + # Check with the user we got that right. + version = input(f"Which version are we developing now? [{dev_version}]") + if len(version) == 0: + version = dev_version + + print(f"Updating version to {version}.") + global_version_update(version) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--post_release", + action="store_true", + help="Whether this is pre or post release.", + ) + parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.") + args = parser.parse_args() + if not args.post_release: + pre_release_work(patch=args.patch) + elif args.patch: + print("Nothing to do after a patch :-)") + else: + post_release_work() diff --git a/scripts/run_simpo.py b/scripts/run_simpo.py index 4ff7457..be69f3f 100644 --- a/scripts/run_simpo.py +++ b/scripts/run_simpo.py @@ -218,7 +218,7 @@ def main(): use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, - attn_implementation=training_args.attn_implementation, + attn_implementation=model_args.attn_implementation, ) model = model_args.model_name_or_path diff --git a/scripts/simpo_config.py b/scripts/simpo_config.py index 41247d5..8bdce08 100644 --- a/scripts/simpo_config.py +++ b/scripts/simpo_config.py @@ -69,4 +69,3 @@ class SimPOConfig(TrainingArguments): dataset_num_proc: Optional[int] = None - attn_implementation: str = None