mirror of
https://github.com/wassname/SimPO.git
synced 2026-06-28 01:43:29 +08:00
257 lines
11 KiB
Python
257 lines
11 KiB
Python
# coding=utf-8
|
|
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
from typing import Any, List, Literal, Optional
|
|
|
|
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
|
|
from datasets.builder import DatasetGenerationError
|
|
|
|
from .configs import DataArguments
|
|
|
|
|
|
# Fallback Jinja chat template: renders each message as a `<|role|>` header line followed by
# the message content and the EOS token, and appends a trailing `<|assistant|>` header after
# the last message when `add_generation_prompt` is set.
DEFAULT_CHAT_TEMPLATE = (
    "{% for message in messages %}\n"
    "{% if message['role'] == 'user' %}\n"
    "{{ '<|user|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'system' %}\n"
    "{{ '<|system|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'assistant' %}\n"
    "{{ '<|assistant|>\n' + message['content'] + eos_token }}\n"
    "{% endif %}\n"
    "{% if loop.last and add_generation_prompt %}\n"
    "{{ '<|assistant|>' }}\n"
    "{% endif %}\n"
    "{% endfor %}"
)
|
|
|
|
|
|
def maybe_insert_system_message(messages, tokenizer):
    """
    Prepend an empty system message to `messages` (in place) when the conversation has none.

    Nothing is inserted when the first message already has the `system` role, when the tokenizer
    exposes no chat template, or when the template mentions neither `system` nor `<|im_start|>`.

    Args:
        messages (`list` of `dict`):
            Conversation turns, each with a `role` key. The list is modified in place.
            An empty list is treated as a conversation without a system message.
        tokenizer:
            Object whose `chat_template` (or, as a fallback, `default_chat_template`) attribute
            holds the Jinja chat template.
    """
    # Guard on emptiness first: an empty conversation has no first message to inspect
    if messages and messages[0]["role"] == "system":
        return

    # chat template can be one of two attributes, we check in order
    chat_template = getattr(tokenizer, "chat_template", None)
    if chat_template is None:
        # Not every tokenizer defines `default_chat_template`, so do not assume it exists
        chat_template = getattr(tokenizer, "default_chat_template", None)
    if chat_template is None:
        # Without a template we cannot tell whether system messages are supported
        return

    # confirm the jinja template refers to a system message before inserting
    if "system" in chat_template or "<|im_start|>" in chat_template:
        messages.insert(0, {"role": "system", "content": ""})
|
|
|
|
|
|
def apply_chat_template(
    example,
    tokenizer,
    task: Literal["sft", "generation", "rm", "dpo", "orpo"],
    auto_insert_empty_system_msg: bool = True,
):
    """
    Render the conversations of `example` to text with the tokenizer's chat template.

    Args:
        example (`dict`):
            A single dataset row. `sft`/`generation` read the `messages` key; `rm`, `dpo` and `orpo`
            read `chosen` and `rejected` (and `dpo`/`orpo` optionally `prompt`).
        tokenizer:
            Object providing `apply_chat_template(messages, tokenize=False, ...)`.
        task (`str`):
            One of `sft`, `generation`, `rm`, `dpo` or `orpo`.
        auto_insert_empty_system_msg (`bool`, *optional*, defaults to `True`):
            Whether to call `maybe_insert_system_message` on the conversations before rendering.

    Returns:
        `dict`: The same `example` object with the rendered text added: `text` for `sft`/`generation`,
        `text_chosen`/`text_rejected` for `rm`, and additionally `text_prompt` for `dpo`/`orpo`.

    Raises:
        `ValueError`: If the task is unknown, required keys are missing, or (for `dpo`/`orpo`)
        `chosen`/`rejected` are not in OpenAI message format.
    """
    if task in ["sft", "generation"]:
        messages = example["messages"]
        # We add an empty system message if there is none
        if auto_insert_empty_system_msg:
            maybe_insert_system_message(messages, tokenizer)
        example["text"] = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            # Only generation leaves the assistant turn open for the model to complete
            add_generation_prompt=task == "generation",
        )
    elif task == "rm":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            chosen_messages = example["chosen"]
            rejected_messages = example["rejected"]
            # We add an empty system message if there is none
            if auto_insert_empty_system_msg:
                maybe_insert_system_message(chosen_messages, tokenizer)
                maybe_insert_system_message(rejected_messages, tokenizer)

            example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    elif task in ["dpo", "orpo"]:
        if all(k in example.keys() for k in ("chosen", "rejected")):
            if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]):
                raise ValueError(
                    f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages"
                )

            # For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
            # We therefore need to extract the N-1 turns to form the prompt
            if "prompt" in example and is_openai_format(example["prompt"]):
                prompt_messages = example["prompt"]
                chosen_messages = example["chosen"]
                rejected_messages = example["rejected"]
            else:
                prompt_messages = example["chosen"][:-1]
                # Now we extract the final turn to define chosen/rejected responses
                chosen_messages = example["chosen"][-1:]
                rejected_messages = example["rejected"][-1:]

            # Prepend a system message if the first message is not a system message
            if auto_insert_empty_system_msg:
                maybe_insert_system_message(prompt_messages, tokenizer)

            example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
            example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `{task}` task! Require either the "
                f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}"
            )
    else:
        raise ValueError(
            f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']"
        )
    return example
|
|
|
|
|
|
def is_openai_format(messages: Any) -> bool:
    """
    Tell whether `messages` follows the OpenAI chat layout.

    Args:
        messages (`Any`):
            Candidate object to inspect.

    Returns:
        `bool`: `True` when `messages` is a list whose items are all dicts carrying both a `role`
        and a `content` key (an empty list qualifies), `False` otherwise.
    """
    if not isinstance(messages, list):
        return False
    for entry in messages:
        # Any non-dict item, or a dict missing one of the two keys, disqualifies the whole list
        if not isinstance(entry, dict):
            return False
        if "role" not in entry or "content" not in entry:
            return False
    return True
|
|
|
|
|
|
def get_datasets(
    data_config: DataArguments | dict,
    splits: Optional[List[str]] = None,
    configs: Optional[List[str]] = None,
    columns_to_keep: Optional[List[str]] = None,
    shuffle: bool = True,
) -> DatasetDict:
    """
    Loads one or more datasets with varying training set proportions.

    Args:
        data_config (`DataArguments` or `dict`):
            Dataset configuration and split proportions. Either a `DataArguments` instance whose
            `dataset_mixer` attribute is used, or the mixer dictionary itself.
        splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
        configs (Optional[List[str]], *optional*, defaults to `None`):
            List of dataset config names. If given must be the same length as 'data_config' keys.
        columns_to_keep (Optional[List[str]], *optional*, defaults to `None`):
            Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts,
            and for cpt this should be (at least) the text column.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.

    Returns:
        [`DatasetDict`]: The dataset dictionary containing the loaded datasets.

    Raises:
        `ValueError`: If `data_config` is neither a `DataArguments` instance nor a `dict`.
    """
    # `isinstance` (rather than an exact type comparison) also accepts subclasses of `DataArguments`
    if isinstance(data_config, DataArguments):
        # Structure of the config to read the datasets and their mix
        # datasets_mixer:
        #     - 'dataset1': 0.5
        #     - 'dataset2': 0.3
        #     - 'dataset3': 0.2
        dataset_mixer = data_config.dataset_mixer
    elif isinstance(data_config, dict):
        # Structure of the input is:
        #     dataset_mixer = {
        #             "dataset1": 0.5,
        #             "dataset2": 0.3,
        #             "dataset3": 0.2,
        #         }
        dataset_mixer = data_config
    else:
        raise ValueError(f"Data config {data_config} not recognized.")

    raw_datasets = mix_datasets(
        dataset_mixer,
        splits=splits,
        configs=configs,
        columns_to_keep=columns_to_keep,
        shuffle=shuffle,
    )
    return raw_datasets
|
|
|
|
|
|
def mix_datasets(
    dataset_mixer: dict,
    splits: Optional[List[str]] = None,
    configs: Optional[List[str]] = None,
    columns_to_keep: Optional[List[str]] = None,
    shuffle: bool = True,
) -> DatasetDict:
    """
    Loads and mixes datasets according to proportions specified in `dataset_mixer`.

    Args:
        dataset_mixer (`dict`):
            Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
        splits (Optional[List[str]], *optional*, defaults to `None`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
            `None` is treated as `['train', 'test']`.
        configs (Optional[List[str]], *optional*, defaults to `None`):
            List of dataset config names. If given must be the same length as 'dataset_mixer' keys.
        columns_to_keep (Optional[List[str]], *optional*, defaults to `None`):
            Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts,
            and for cpt this should be (at least) the text column. Every column not listed is removed,
            so `None` (treated as an empty list) removes all columns.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.

    Returns:
        [`DatasetDict`]: Dictionary with a `train` entry (the leading `frac` portion of each train split,
        concatenated) and/or a `test` entry (all test splits, concatenated without subsampling).

    Raises:
        `ValueError`: If `configs` and `dataset_mixer` differ in length, a fraction is negative, a split
        name contains neither `train` nor `test`, or no dataset was loaded at all.
    """
    splits = ["train", "test"] if splits is None else splits
    configs = [None] * len(dataset_mixer) if not configs else configs
    columns_to_keep = [] if columns_to_keep is None else columns_to_keep

    if len(configs) != len(dataset_mixer):
        raise ValueError("The number of given dataset config names must be the same as the given number of datasets.")

    # Validate the proportions up front so a bad mixer fails before any dataset is loaded
    if any(frac < 0 for frac in dataset_mixer.values()):
        raise ValueError("Dataset fractions cannot be negative.")

    raw_datasets = DatasetDict()
    raw_train_datasets = []
    # One fraction per loaded *train split* (not per dataset), so that the pairing below stays
    # aligned even when several train splits are requested for each dataset
    train_fracs = []
    raw_val_datasets = []
    for (ds, frac), ds_config in zip(dataset_mixer.items(), configs):
        for split in splits:
            try:
                # Try first if dataset on a Hub repo
                dataset = load_dataset(ds, ds_config, split=split)
            except DatasetGenerationError:
                # If not, check local dataset
                dataset = load_from_disk(os.path.join(ds, split))

            # Remove redundant columns to avoid schema conflicts on load
            dataset = dataset.remove_columns([col for col in dataset.column_names if col not in columns_to_keep])
            if "train" in split:
                raw_train_datasets.append(dataset)
                train_fracs.append(frac)
            elif "test" in split:
                raw_val_datasets.append(dataset)
            else:
                raise ValueError(f"Split type {split} not recognized as one of test or train.")

    if raw_train_datasets:
        # Keep only the leading `frac` portion of every train split
        train_subsets = [
            dataset.select(range(int(frac * len(dataset))))
            for dataset, frac in zip(raw_train_datasets, train_fracs)
        ]
        train_data = concatenate_datasets(train_subsets)
        raw_datasets["train"] = train_data.shuffle(seed=42) if shuffle else train_data

    # No subsampling for test datasets to enable fair comparison across models
    if raw_val_datasets:
        test_data = concatenate_datasets(raw_val_datasets)
        raw_datasets["test"] = test_data.shuffle(seed=42) if shuffle else test_data

    if len(raw_datasets) == 0:
        raise ValueError(
            f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted."
        )

    return raw_datasets
|