mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-03 17:10:10 +08:00
240 lines
8.1 KiB
Python
240 lines
8.1 KiB
Python
"""
|
|
author: theblackcat102
|
|
|
|
Dataset output format from __getitem__
|
|
|
|
- question / prompt : string
|
|
|
|
- answers / rows : list of tuple pair. The first element in the tuple pair must be the positive pair (rank higher than the second element)
|
|
|
|
A list of rank based dataset for training using rank loss
|
|
|
|
Some nice features to have
|
|
|
|
[] support additional negative samples generated from other models.
|
|
|
|
For example we can use galactica-125m to generate a TLDR and assume it was
|
|
inferior to the human preference one
|
|
|
|
|
|
"""
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from datasets import load_dataset
|
|
from torch.utils.data import Dataset
|
|
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
|
|
|
|
|
@dataclass
class RankGenCollator:
    """
    Collate (question, pairs) items into three tokenized batches.

    Each batch item is unpacked as ``(question, pairs)`` where ``pairs`` is an
    iterable of ``(pos, neg)`` strings. Every pair contributes one row to each
    of the "prefix", "positive" and "negative" outputs; the question is
    prefixed with "pre " and both answers with "suffi " before tokenization.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    # not read anywhere in this class
    max_examples: Optional[int] = None

    def __call__(self, batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
        # Flatten to one (prefix, better, worse) triple per comparison pair.
        triples = [
            ("pre " + question, "suffi " + pos, "suffi " + neg)
            for question, pairs in batch
            for (pos, neg) in pairs
        ]
        prefixes = [triple[0] for triple in triples]
        better_answers = [triple[1] for triple in triples]
        worse_answers = [triple[2] for triple in triples]

        def encode(texts):
            # All three groups share the same tokenizer settings.
            return self.tokenizer(
                texts, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
            )

        return {
            "prefix": encode(prefixes),
            "positive": encode(better_answers),
            "negative": encode(worse_answers),
        }
|
|
|
|
|
|
@dataclass
class DataCollatorForPairRank:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.

    Each feature is unpacked as ``(question, pairs)`` where ``pairs`` is an
    iterable of ``(pos, neg)`` answers. Every answer is tokenized together with
    its question, and the encodings are padded into a single batch by
    ``tokenizer.pad``.
    """

    tokenizer: PreTrainedTokenizerBase
    # not read anywhere in this class
    num_choices: int = 2
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    drop_token_type: bool = False  # galactica

    def __call__(self, features):
        """Tokenize every (question, answer) combination and pad them together.

        The returned batch is flat: for each comparison pair the positive
        encoding is appended first, immediately followed by the negative one.
        """
        flatten_features = []
        for question, pairs in features:
            for pos, neg in pairs:
                # keep the (pos, neg) order so consecutive rows form one pair
                for answer in (pos, neg):
                    flatten_features.append(
                        self.tokenizer(question, answer, truncation=True, max_length=self.max_length)
                    )

        batch = self.tokenizer.pad(
            flatten_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        if self.drop_token_type:
            # Pop with a default: a tokenizer that emits no token type ids
            # must not make this raise KeyError.
            batch.pop("token_type_ids", None)
        return batch
|
|
|
|
|
|
class WebGPT(Dataset):
    """
    Answer comparisons loaded from the "openai/webgpt_comparisons" train split.

    Rows that share the same question text are grouped, so one dataset item is
    a question together with every (preferred answer, other answer) pair
    recorded for it.
    """

    def __init__(self) -> None:
        super().__init__()
        # question text -> list of (preferred answer, other answer)
        self.questions = {}
        # using prompt as our index will allow us
        # to add additional generated prompt later
        self.index2question = {}

        dataset = load_dataset("openai/webgpt_comparisons")
        self._add_rows(dataset["train"])

    def _add_rows(self, rows) -> None:
        """Group comparison rows by question and index each new question once."""
        for row in rows:
            question = row["question"]["full_text"]
            # Dedup on the question-keyed dict: index2question is keyed by
            # integer position, so testing the text against it never matches.
            if question not in self.questions:
                self.index2question[len(self.index2question)] = question
                self.questions[question] = []

            if row["score_0"] > row["score_1"]:
                pair = (row["answer_0"], row["answer_1"])
            else:
                # NOTE(review): equal scores also land here and are stored as
                # "answer_1 preferred" - confirm ties should be kept.
                pair = (row["answer_1"], row["answer_0"])
            self.questions[question].append(pair)

    def __len__(self):
        """Number of distinct questions."""
        return len(self.index2question)

    def __getitem__(self, index):
        """Return ``(question, pairs)`` for the question at ``index``."""
        question = self.index2question[index]
        rows = self.questions[question]
        # optimize the format later
        return question, rows
|
|
|
|
|
|
class HFSummary(Dataset):
    """
    Human feedback data from OpenAI
    https://github.com/openai/summarize-from-feedback

    labeling method : pair comparison, 0 or 1

    Comparisons are grouped by their context (the row's article, or its post
    when no article is present). One dataset item is a context with the
    " TLDR;" suffix appended, plus at most ``max_comparison_per_sample``
    (preferred summary, other summary) pairs.
    """

    def __init__(self, split="train", conf_threshold=-1, max_comparison_per_sample=3) -> None:
        super().__init__()
        assert split in ("train", "valid1", "valid2", "test")
        # context text -> list of (preferred summary, other summary)
        self.summaries = {}
        # using prompt as our index will allow us
        # to add additional generated prompt later
        self.index2summary = {}
        self.max_comparison_per_sample = max_comparison_per_sample
        self.postfix_prompt = " TLDR;"

        # every non-train split is read from "validation" and then filtered
        # on the row's own "split" field
        major_split = split if "train" == split else "validation"
        dataset = load_dataset("openai/summarize_from_feedback", "comparisons")[major_split]
        self._add_rows(dataset, split, conf_threshold)

    def _add_rows(self, rows, split, conf_threshold) -> None:
        """Filter comparison rows and group the surviving pairs by context."""
        for data in rows:
            if (
                "extra" in data
                and "confidence" in data["extra"]
                and data["extra"]["confidence"] is not None
                and conf_threshold > data["extra"]["confidence"]
            ):
                print("skipping {}".format(data["info"]["id"]))
                continue

            if split != "train" and split != data["split"]:
                continue

            info = data["info"]
            if "article" in info and info["article"] is not None:
                context = info["article"]
            elif "post" in info:
                context = info["post"]
            else:
                # No usable context: skip instead of hitting an unbound name
                # or silently reusing the previous row's context.
                continue

            # Dedup on the context-keyed dict: index2summary is keyed by
            # integer position, so testing the text against it never matches.
            if context not in self.summaries:
                self.index2summary[len(self.index2summary)] = context
                self.summaries[context] = []

            pos, neg = (0, 1) if data["choice"] == 0 else (1, 0)
            self.summaries[context].append((data["summaries"][pos]["text"], data["summaries"][neg]["text"]))

    def __len__(self):
        """Number of distinct contexts."""
        return len(self.index2summary)

    def __getitem__(self, index):
        """Return ``(context + " TLDR;", pairs)`` for the context at ``index``.

        All pairs are returned when they fit within
        ``max_comparison_per_sample``; otherwise that many distinct pairs are
        drawn at random and returned in their stored order.
        """
        context = self.index2summary[index]
        # return pairs of comparison
        rows = self.summaries[context]
        if len(rows) > self.max_comparison_per_sample:
            # replace=False so the draw never wastes slots on repeated indices
            keep = np.random.choice(len(rows), self.max_comparison_per_sample, replace=False)
            selected = [rows[idx] for idx in sorted(keep)]
        else:
            # copy so callers cannot mutate the stored list
            selected = list(rows)
        # optimize the format later
        return context + self.postfix_prompt, selected
|
|
|
|
|
|
class HFDataset(Dataset):
    """
    Base Hugging Face dataset written to support the
    simplest pos-neg pair format.

    we should do something like this for supervised datasets

    Rows are grouped by their stripped question text; one dataset item is a
    question plus every (positive answer, negative answer) pair seen for it.
    """

    def __init__(
        self, dataset_name, question_field, pos_answer_field, neg_answer_field, subset=None, split=None
    ) -> None:
        super().__init__()
        dataset = load_dataset(dataset_name, subset)
        # NOTE(review): when split is None the loaded object is iterated
        # as-is - confirm that it yields rows rather than split names.
        if split is not None:
            dataset = dataset[split]

        # question text -> list of (positive answer, negative answer)
        self.questions = {}
        # integer position -> question text
        self.index2question = {}
        self._add_rows(dataset, question_field, pos_answer_field, neg_answer_field)

    def _add_rows(self, rows, question_field, pos_answer_field, neg_answer_field) -> None:
        """Group rows by question and index each new question once."""
        for row in rows:
            question = row[question_field].strip()
            pos = row[pos_answer_field]
            neg = row[neg_answer_field]
            # Dedup on the question-keyed dict: index2question is keyed by
            # integer position, so testing the text against it never matches.
            if question not in self.questions:
                self.index2question[len(self.index2question)] = question
                self.questions[question] = []
            self.questions[question].append((pos.strip(), neg.strip()))

    def __len__(self):
        """Number of distinct questions."""
        return len(self.index2question)

    def __getitem__(self, index):
        """Return ``(question, pairs)`` for the question at ``index``."""
        question = self.index2question[index]
        rows = self.questions[question]
        # optimize the format later
        return question, rows
|
|
|
|
|
|
class GPTJSynthetic(HFDataset):
    """HFDataset bound to the train split of "Dahoas/synthetic-instruct-gptj-pairwise"."""

    def __init__(self) -> None:
        # "prompt" is the question column; "chosen"/"rejected" are the
        # positive/negative answer columns.
        super().__init__(
            dataset_name="Dahoas/synthetic-instruct-gptj-pairwise",
            question_field="prompt",
            pos_answer_field="chosen",
            neg_answer_field="rejected",
            subset=None,
            split="train",
        )
|