mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 17:14:25 +08:00
Add run_orpo.py (#143)
* Add `ORPOConfig`
* Add `task=orpo` and support `(prompt,chosen,rejected)` datasets
* Add missing `model_init_kwargs` and `dataset_num_proc`
* Add `run_orpo.py` (WIP)
* Update `trl` dependency from source
* Add `setup_chat_format` before `apply_chat_template`
* Add `config_full.yaml` for `mistral-7b-orpo`
* Fix comment indentation
* Use `chat_template=chatml` instead
* Add `kaist-ai/mistral-orpo-capybara-7k` recipe
* Rename `DPOTrainer` to `ORPOTrainer` in `config_full.yaml` files
* Run `black --line-length 119 src`
* Add `is_openai_format` to fix `(prompt,chosen,rejected)` formatting
* Run `black --line-length 119 src`
* Fix `isort` in `run_orpo.py`
* Update `mistral-capybara/orpo/config_full.yaml`
* Check if `test` is available split
* Pin `trl` to `alvarobartt/trl` fork (debugging)
* Add `qwen-capybara` recipe
* Update `mistral-capybara` recipe
* Set `add_generation_prompt=True` if `task="orpo"`
* Reduce `logging_steps` to 10
* Unset `add_generation_prompt` when `task=orpo`
* Add filtering based on prompt length

  Done similarly to the original implementation, in order to better reproduce their results
* Fix prompt length filtering
* Update `trl` pinned version
* Remove extra outdate config files
* Update `recipes/mistral-capybara/orpo/config_full.yaml`
* Run `make style`
* Activate BEAST MODE
* Pin deps
* Add readme
* Fix dep

---------

Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
@@ -0,0 +1,26 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@@ -1,5 +1,5 @@

-# Instructions to StarChat2
+# Instructions to train StarChat2

 Similar to how we trained Zephyr 7B Beta in our [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps:

@@ -0,0 +1,23 @@

# Instructions to train Zephyr-141B-A35B with ORPO

This model is fine-tuned via a novel alignment algorithm called [Odds Ratio Preference Optimization (ORPO)](https://huggingface.co/papers/2403.07691). ORPO does not require an SFT step to achieve high performance and is thus much more computationally efficient than methods like DPO and PPO. To train Zephyr-141B-A35B, we used the [`argilla/distilabel-capybara-dpo-7k-binarized`](https://huggingface.co/datasets/argilla/distilabel-capybara-dpo-7k-binarized) preference dataset, which consists of synthetic, high-quality, multi-turn preferences that have been scored via LLMs.

See below for commands to train this model using FSDP. **Note:** we found it was not possible to train this large model with DeepSpeed ZeRO-3 due to unresolved NCCL errors which cause GPUs to hang.

## Full training examples

You will require 4 nodes of 8 GPUs (80GB of VRAM each) to train the full model. Alternatively, you may be able to train on fewer GPUs by adjusting `per_device_train_batch_size`, `gradient_accumulation_steps`, and `num_train_epochs` to keep the global batch size constant. A recipe involving QLoRA will come later 🤗.

To run with Slurm, use:

```shell
sbatch --job-name=handbook_sft --nodes=4 recipes/launch.slurm zephyr-141b-A35b orpo full fsdp
```

Under the hood, this calls the following script which can be adapted to other models and datasets:


```shell
ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file recipes/accelerate_configs/fsdp.yaml scripts/run_orpo.py recipes/zephyr-141b-A35b/orpo/config_full.yaml
```
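The README's advice about keeping the global batch size constant can be illustrated with a little arithmetic. The sketch below is only an illustration: the helper name `global_batch_size` is hypothetical, and it assumes one training process per GPU, so that the global batch size is the product of the GPU count, the per-device batch size, and the gradient accumulation steps.

```python
def global_batch_size(
    num_nodes: int,
    gpus_per_node: int,
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
) -> int:
    """Number of samples contributing to one optimizer step (assumes one process per GPU)."""
    return num_nodes * gpus_per_node * per_device_train_batch_size * gradient_accumulation_steps


# Setup described in the README: 4 nodes of 8 GPUs, batch size 1, no accumulation.
full_setup = global_batch_size(4, 8, 1, 1)

# Hypothetical smaller setup: a single node, compensated with 4 accumulation steps.
single_node = global_batch_size(1, 8, 1, 4)

print(full_setup, single_node)
```

Both configurations give the same global batch size, which is the quantity the README suggests holding fixed when moving to fewer GPUs. Whether the model actually fits in memory on fewer GPUs is a separate question that this arithmetic does not answer.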
@@ -0,0 +1,39 @@
# Model arguments
model_name_or_path: mistral-community/Mixtral-8x22B-v0.1
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true

# Data training arguments
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
dataset_mixer:
  argilla/distilabel-capybara-dpo-7k-binarized: 1.0
dataset_splits:
- train
preprocessing_num_workers: 8

# ORPOTrainer arguments
bf16: true
beta: 0.05
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
hub_model_id: zephyr-orpo-141b-A35b
learning_rate: 5.0e-6
log_level: info
logging_steps: 10
lr_scheduler_type: inverse_sqrt
max_length: 2048
max_prompt_length: 1792
num_train_epochs: 3
optim: adamw_bnb_8bit
output_dir: data/zephyr-orpo-141b-A35b
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- tensorboard
- wandb
save_strategy: "no"
seed: 42
warmup_steps: 100
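To make the `chat_template` string in this recipe easier to read, the sketch below emulates it in plain Python. This is an approximation rather than the actual rendering path: the function name `render_chat` is hypothetical, the `"</s>"` end-of-sequence token is an assumption about the tokenizer, and the whitespace handling assumes the Jinja template is compiled with `trim_blocks` and `lstrip_blocks` enabled, so that newlines directly after block tags are dropped.

```python
from typing import Dict, List


def render_chat(
    messages: List[Dict[str, str]],
    eos_token: str = "</s>",  # assumption: the tokenizer's EOS token
    add_generation_prompt: bool = False,
) -> str:
    """Plain-Python approximation of the recipe's Jinja chat template."""
    parts = []
    for message in messages:
        role = message["role"]
        # The template only handles these three roles; anything else renders nothing.
        if role in ("user", "system", "assistant"):
            parts.append(f"<|{role}|>\n{message['content']}{eos_token}\n")
    # The template appends the assistant header after the last message when requested.
    if add_generation_prompt and messages:
        parts.append("<|assistant|>\n")
    return "".join(parts)


conversation = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi"},
]
print(render_chat(conversation, add_generation_prompt=True))
```

Each turn is wrapped as a role header, the content, and the EOS token on its own line, which is why the training script does not need ChatML special tokens for this recipe.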
@@ -0,0 +1,271 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
import sys
from typing import Any, Dict

import torch
import transformers
from transformers import AutoModelForCausalLM, set_seed

from alignment import (
    DataArguments,
    H4ArgumentParser,
    ModelArguments,
    apply_chat_template,
    decontaminate_humaneval,
    get_checkpoint,
    get_datasets,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    get_tokenizer,
)
from alignment.configs import ORPOConfig
from trl import ORPOTrainer, setup_chat_format


logger = logging.getLogger(__name__)


def main():
    parser = H4ArgumentParser((ModelArguments, DataArguments, ORPOConfig))
    model_args, data_args, training_args = parser.parse()

    #######
    # Setup
    #######
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    # Set seed for reproducibility
    set_seed(training_args.seed)

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(
        data_args,
        splits=data_args.dataset_splits,
        configs=data_args.dataset_configs,
        columns_to_keep=[
            "prompt",
            "chosen",
            "rejected",
        ],
    )
    logger.info(
        f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )
    column_names = list(raw_datasets["train"].features)

    #####################################
    # Load tokenizer and process datasets
    #####################################
    data_args.truncation_side = "left"  # Truncate from left to ensure we don't lose labels in final turn
    tokenizer = get_tokenizer(model_args, data_args)

    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_flash_attention_2=model_args.use_flash_attention_2,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # For ChatML we need to add special tokens and resize the embedding layer
    if "<|im_start|>" in tokenizer.chat_template:
        model, tokenizer = setup_chat_format(model, tokenizer)

    #####################
    # Apply chat template
    #####################
    raw_datasets = raw_datasets.map(
        apply_chat_template,
        fn_kwargs={
            "tokenizer": tokenizer,
            "task": "orpo",
            "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
        },
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        desc="Formatting comparisons with prompt template",
    )

    #############################
    # Filter out seq > max_length
    #############################
    if training_args.max_prompt_length is not None:
        unfiltered_train_samples = len(raw_datasets["train"])
        if "test" in raw_datasets:
            unfiltered_test_samples = len(raw_datasets["test"])

        def filter_fn(sample: Dict[str, Any]) -> bool:
            prompt_length = tokenizer(
                sample["text_prompt"],
                return_tensors="pt",
            )[
                "input_ids"
            ].size(dim=-1)

            return prompt_length < training_args.max_prompt_length

        raw_datasets = raw_datasets.filter(
            filter_fn,
            desc="Filtering out the samples where len(text_prompt) > max_prompt_length",
        )

        filtered_train_samples = unfiltered_train_samples - len(raw_datasets["train"])
        logger.info(
            f"Filtered out {filtered_train_samples} training samples out of the {unfiltered_train_samples} samples."
        )
        if "test" in raw_datasets:
            filtered_test_samples = unfiltered_test_samples - len(raw_datasets["test"])
            logger.info(
                f"Filtered out {filtered_test_samples} test samples out of the {unfiltered_test_samples} samples."
            )

    ##########################
    # Decontaminate benchmarks
    ##########################
    num_raw_train_samples = len(raw_datasets["train"])
    raw_datasets = raw_datasets.filter(
        decontaminate_humaneval,
        fn_kwargs={"text_column": "text_chosen"},
        batched=True,
        batch_size=10_000,
        num_proc=1,
        desc="Decontaminating HumanEval samples",
    )
    num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"])
    logger.info(
        f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set."
    )

    # Replace column names with what TRL needs, text_prompt -> prompt, text_chosen -> chosen and text_rejected -> rejected
    for split in raw_datasets.keys():
        raw_datasets[split] = raw_datasets[split].rename_columns(
            {
                "text_prompt": "prompt",
                "text_chosen": "chosen",
                "text_rejected": "rejected",
            }
        )

    # Log a few random samples from the training set:
    for index in random.sample(range(len(raw_datasets["train"])), 3):
        logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
        logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")
        logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")

    ##########################
    # Instantiate ORPO trainer
    ##########################
    trainer = ORPOTrainer(
        model,
        args=training_args,
        train_dataset=raw_datasets["train"],
        eval_dataset=raw_datasets["test"] if "test" in raw_datasets else None,
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_args),  # type: ignore
    )

    ###############
    # Training loop
    ###############
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(raw_datasets["train"])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    logger.info("*** Training complete ***")

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on main process
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "dataset": list(data_args.dataset_mixer.keys()),
        "dataset_tags": list(data_args.dataset_mixer.keys()),
        "tags": ["alignment-handbook"],
    }
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    ##########
    # Evaluate
    ##########
    if training_args.do_eval and "test" in raw_datasets:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(raw_datasets["test"])
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.push_to_hub is True:
        logger.info("Pushing to hub...")
        trainer.push_to_hub(**kwargs)

    logger.info("*** Training complete! ***")


if __name__ == "__main__":
    main()
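The prompt-length filtering step in the script above can be tried in isolation. The sketch below is a simplified stand-in, not the script's code: `make_prompt_filter` is a hypothetical helper, and whitespace splitting replaces the real tokenizer purely so the example runs without downloading a model.

```python
from typing import Any, Callable, Dict, List


def make_prompt_filter(
    tokenize: Callable[[str], List[str]],
    max_prompt_length: int,
) -> Callable[[Dict[str, Any]], bool]:
    """Build a predicate that keeps samples whose prompt is strictly shorter than the limit."""

    def keep(sample: Dict[str, Any]) -> bool:
        return len(tokenize(sample["text_prompt"])) < max_prompt_length

    return keep


samples = [
    {"text_prompt": "a b c"},      # 3 tokens
    {"text_prompt": "a b c d"},    # 4 tokens
    {"text_prompt": "a b c d e"},  # 5 tokens
]

# Assumption for illustration only: whitespace splitting stands in for the tokenizer.
keep = make_prompt_filter(str.split, max_prompt_length=4)
kept = [sample for sample in samples if keep(sample)]
print(len(samples) - len(kept), "samples filtered out")
```

As in the script, the comparison is a strict `<`, so a prompt of exactly `max_prompt_length` tokens is dropped as well as longer ones.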
@@ -41,10 +41,10 @@ if stale_egg_info.exists():
 # IMPORTANT: all dependencies should be listed here with their version requirements, if any.
 # * If a dependency is fast-moving (e.g. transformers), pin to the exact version
 _deps = [
-    "accelerate==0.27.2",
+    "accelerate>=0.29.2",
     "bitsandbytes==0.41.2.post2",
     "black==23.1.0",
-    "datasets==2.14.6",
+    "datasets>=2.18.0",
     "deepspeed==0.12.2",
     "einops>=0.6.1",
     "evaluate==0.4.0",
@@ -65,8 +65,8 @@ _deps = [
     "scipy",
     "tensorboard",
     "torch==2.1.2",
-    "transformers @ git+https://github.com/huggingface/transformers.git@831bc25d8fdb85768402f772cf65cc3d7872b211",  # Enable StarCoder2
-    "trl==0.7.10",
+    "transformers>=4.39.3",
+    "trl>=0.8.2",
     "jinja2>=3.0.0",
     "tqdm>=4.64.1",
 ]
@@ -281,3 +281,62 @@ class DPOConfig(transformers.TrainingArguments):
     optim: Optional[str] = field(default="rmsprop")
     remove_unused_columns: bool = field(default=False)
     loss_type: Optional[str] = field(default="sigmoid", metadata={"help": ("The loss type for DPO.")})
+
+
+@dataclass
+class ORPOConfig(transformers.TrainingArguments):
+    max_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "The maximum length of the sequences in the batch."},
+    )
+    max_prompt_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "The maximum length of the prompt."},
+    )
+    max_completion_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "The maximum length of the completions."},
+    )
+
+    beta: float = field(
+        default=0.1,
+        metadata={
+            "help": "The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss."
+        },
+    )
+    disable_dropout: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to disable dropouts in `model`."},
+    )
+
+    label_pad_token_id: int = field(
+        default=-100,
+        metadata={"help": "The label pad token id."},
+    )
+    padding_value: Optional[int] = field(
+        default=None,
+        metadata={"help": "The padding value if it is different to the tokenizer's pad_token_id."},
+    )
+    truncation_mode: str = field(
+        default="keep_end",
+        metadata={"help": "The truncation mode to use, either `keep_end` or `keep_start`."},
+    )
+
+    generate_during_eval: bool = field(
+        default=False,
+        metadata={"help": "Whether to sample and log generations during evaluation step."},
+    )
+    is_encoder_decoder: Optional[bool] = field(
+        default=None,
+        metadata={"help": ("If no model is provided, we need to know if the model_init returns an encoder-decoder.")},
+    )
+
+    model_init_kwargs: Optional[Dict] = field(
+        default=None,
+        metadata={"help": ("Dict of Optional kwargs to pass when instantiating the model from a string")},
+    )
+
+    dataset_num_proc: Optional[int] = field(
+        default=None,
+        metadata={"help": ("The number of workers to use to tokenize the data.")},
+    )
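The role of `beta` described in the help text above can be made concrete with a scalar sketch of the ORPO objective. This is an illustration under stated assumptions, not the `trl` implementation: the function names are hypothetical, the inputs are assumed to be length-averaged log-probabilities of the chosen and rejected responses (so each is strictly negative), and the formulation follows the ORPO paper as an SFT term plus `beta` times an odds-ratio term.

```python
import math


def log_odds(avg_logp: float) -> float:
    """log(p / (1 - p)) for p = exp(avg_logp); requires avg_logp < 0."""
    return avg_logp - math.log1p(-math.exp(avg_logp))


def orpo_loss(chosen_avg_logp: float, rejected_avg_logp: float, beta: float = 0.05) -> float:
    """SFT loss on the chosen response plus beta times the odds-ratio penalty."""
    sft_loss = -chosen_avg_logp
    log_odds_ratio = log_odds(chosen_avg_logp) - log_odds(rejected_avg_logp)
    # -log(sigmoid(x)) written as log(1 + exp(-x))
    odds_ratio_loss = math.log1p(math.exp(-log_odds_ratio))
    return sft_loss + beta * odds_ratio_loss


# When the rejected response is much less likely than the chosen one,
# the odds-ratio penalty shrinks and the total loss approaches the SFT loss.
print(orpo_loss(-0.5, -0.5))
print(orpo_loss(-0.5, -2.0))
```

The default of `beta=0.05` here mirrors the value in the Zephyr recipe in this commit, while `ORPOConfig` itself defaults to `0.1`.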
+45
-12
@@ -12,8 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os
-from typing import List, Literal, Optional
+from typing import Any, List, Literal, Optional

 from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
 from datasets.builder import DatasetGenerationError
@@ -50,7 +51,9 @@ def apply_chat_template(
         if auto_insert_empty_system_msg:
             maybe_insert_system_message(messages, tokenizer)
         example["text"] = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True if task == "generation" else False
+            messages,
+            tokenize=False,
+            add_generation_prompt=True if task == "generation" else False,
         )
     elif task == "rm":
         if all(k in example.keys() for k in ("chosen", "rejected")):
@@ -67,32 +70,58 @@ def apply_chat_template(
             raise ValueError(
                 f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
             )
-    elif task == "dpo":
+    elif task in ["dpo", "orpo"]:
         if all(k in example.keys() for k in ("chosen", "rejected")):
-            # For DPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
+            if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]):
+                raise ValueError(
+                    f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages"
+                )
+
+            # For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
             # We therefore need to extract the N-1 turns to form the prompt
-            prompt_messages = example["chosen"][:-1]
+            if "prompt" in example and is_openai_format(example["prompt"]):
+                prompt_messages = example["prompt"]
+                chosen_messages = example["chosen"]
+                rejected_messages = example["rejected"]
+            else:
+                prompt_messages = example["chosen"][:-1]
+                # Now we extract the final turn to define chosen/rejected responses
+                chosen_messages = example["chosen"][-1:]
+                rejected_messages = example["rejected"][-1:]
+
             # Prepend a system message if the first message is not a system message
             if auto_insert_empty_system_msg:
                 maybe_insert_system_message(prompt_messages, tokenizer)

-            # Now we extract the final turn to define chosen/rejected responses
-            chosen_messages = example["chosen"][-1:]
-            rejected_messages = example["rejected"][-1:]
+            example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
             example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
             example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
-            example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
         else:
             raise ValueError(
-                f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
+                f"Could not format example as dialogue for `{task}` task! Require either the "
+                f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}"
             )
     else:
         raise ValueError(
-            f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}"
+            f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']"
         )
     return example


+def is_openai_format(messages: Any) -> bool:
+    """
+    Check if the input messages are in OpenAI format.
+    Args:
+        messages (`Any`):
+            Messages to check.
+    Returns:
+        `bool`: Whether the messages are in OpenAI format.
+    """
+    if isinstance(messages, list) and all(isinstance(message, dict) for message in messages):
+        return all("role" in message and "content" in message for message in messages)
+    return False
+
+
 def get_datasets(
     data_config: DataArguments | dict,
     splits: Optional[List[str]] = None,
@@ -138,7 +167,11 @@ def get_datasets(
         raise ValueError(f"Data config {data_config} not recognized.")

     raw_datasets = mix_datasets(
-        dataset_mixer, splits=splits, configs=configs, columns_to_keep=columns_to_keep, shuffle=shuffle
+        dataset_mixer,
+        splits=splits,
+        configs=configs,
+        columns_to_keep=columns_to_keep,
+        shuffle=shuffle,
     )
     return raw_datasets

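The two dataset layouts now accepted by the `dpo`/`orpo` branch of `apply_chat_template` can be illustrated without a tokenizer. In the sketch below, `is_openai_format` is copied from the diff above, while `split_preference_example` is a hypothetical helper that mirrors only the prompt/chosen/rejected selection logic and omits the system-message insertion and chat templating.

```python
from typing import Any, Dict, List, Tuple

Messages = List[Dict[str, str]]


def is_openai_format(messages: Any) -> bool:
    """Check that `messages` is a list of dicts that each carry `role` and `content`."""
    if isinstance(messages, list) and all(isinstance(message, dict) for message in messages):
        return all("role" in message and "content" in message for message in messages)
    return False


def split_preference_example(example: Dict[str, Any]) -> Tuple[Messages, Messages, Messages]:
    """Hypothetical helper: return (prompt, chosen, rejected) message lists."""
    if "prompt" in example and is_openai_format(example["prompt"]):
        # Explicit (prompt, chosen, rejected) triple: use the columns as they are.
        return example["prompt"], example["chosen"], example["rejected"]
    # Otherwise the first N-1 turns of `chosen` form the prompt,
    # and the final turn of each dialogue is the response.
    return example["chosen"][:-1], example["chosen"][-1:], example["rejected"][-1:]


example = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5"},
    ],
}
prompt, chosen, rejected = split_preference_example(example)
print(prompt, chosen, rejected)
```

A `prompt` column that is a plain string does not pass `is_openai_format`, so such datasets fall through to the second branch and the prompt is derived from the `chosen` dialogue instead.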