mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
[feature] add initial version of anthropic dataset
This commit is contained in:
@@ -237,3 +237,39 @@ class HFDataset(Dataset):
|
||||
class GPTJSynthetic(HFDataset):
    """HFDataset preset for ``Dahoas/synthetic-instruct-gptj-pairwise``.

    All arguments are forwarded positionally to ``HFDataset.__init__``.
    NOTE(review): presumably they are the dataset name, the prompt column,
    the preferred and rejected answer columns, an unused option and the
    split -- confirm against ``HFDataset``.
    """

    def __init__(self) -> None:
        super().__init__(
            "Dahoas/synthetic-instruct-gptj-pairwise",
            "prompt",
            "chosen",
            "rejected",
            None,
            "train",
        )
|
||||
|
||||
|
||||
class AnthropicRLHF(Dataset):
    """Pairwise preference dataset built from ``Anthropic/hh-rlhf``.

    The data are described in the paper:
    Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback.
    If you find the data useful, please cite the paper.
    The data format is very simple -- each line of the jsonl files contains a pair of texts,
    one "chosen" and one "rejected".

    Every record is cut at its first ``"Assistant:"`` marker. The text before
    the marker in the *chosen* transcript becomes the prompt; the text after
    the marker in the chosen and rejected transcripts becomes the positive
    and negative continuation. Inside the continuations the later
    ``"\\n\\nHuman: "`` / ``"\\n\\nAssistant: "`` turn markers are replaced by
    ``sep_token``.

    Items are returned as ``(prompt, [(chosen, rejected)])``.
    """

    # Splits accepted by the constructor.
    _SPLITS = ("train", "test")

    def __init__(self, split="train", sep_token="<sep>") -> None:
        """Load one split and pre-compute all (prompt, (chosen, rejected)) pairs.

        Args:
            split: ``"train"`` or ``"test"``.
            sep_token: string substituted for the turn markers inside the
                continuations; ``None`` falls back to ``" . "``.

        Raises:
            ValueError: if ``split`` is not one of the supported splits.
        """
        super().__init__()
        # Raise explicitly instead of asserting: asserts vanish under ``-O``
        # and an unknown split must never be silently mapped to another one.
        if split not in self._SPLITS:
            raise ValueError(f"split must be one of {self._SPLITS}, got {split!r}")
        if sep_token is None:
            sep_token = " . "
        self.pairs = []
        # using prompt as our index will allow us
        # to add additional generated prompts later
        dataset = load_dataset("Anthropic/hh-rlhf")[split]
        for data in dataset:
            prompt, pos_postfix = data["chosen"].split("Assistant:", maxsplit=1)
            # The prefix of the rejected transcript is discarded; only the
            # chosen transcript supplies the prompt.
            _, neg_postfix = data["rejected"].split("Assistant:", maxsplit=1)
            self.pairs.append(
                (
                    prompt,
                    (
                        self._flatten_turns(pos_postfix, sep_token),
                        self._flatten_turns(neg_postfix, sep_token),
                    ),
                )
            )

    @staticmethod
    def _flatten_turns(text, sep_token):
        """Replace the dialogue turn markers in ``text`` with ``sep_token`` and strip it."""
        flattened = text.replace("\n\nHuman: ", sep_token).replace("\n\nAssistant: ", sep_token)
        return flattened.strip()

    def __len__(self):
        """Return the number of stored preference pairs."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Return ``(prompt, [(chosen, rejected)])`` for the record at ``index``."""
        context, pair = self.pairs[index]

        # The pair is wrapped in a list so that one prompt can carry a list of pairs.
        return context, [pair]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
|
||||
from rank_datasets import DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT
|
||||
from rank_datasets import AnthropicRLHF, DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@@ -25,6 +25,16 @@ def test_webgpt():
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
def test_anthropic_rlhf():
    """Smoke test: collate the AnthropicRLHF test split and print each batch's ``input_ids`` shape."""
    tok = AutoTokenizer.from_pretrained("bigscience/mt0-large")
    pair_collator = DataCollatorForPairRank(tok, max_length=200)
    loader = DataLoader(
        AnthropicRLHF("test", sep_token=tok.sep_token),
        collate_fn=pair_collator,
        batch_size=32,
    )
    for batch in loader:
        print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
def test_hf_summary_quality():
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
|
||||
@@ -155,7 +155,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
|
||||
train, evals = get_datasets(training_conf["datasets"])
|
||||
train, evals = get_datasets(training_conf["datasets"], tokenizer)
|
||||
if "rankgen" in model_name:
|
||||
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
|
||||
else:
|
||||
|
||||
@@ -99,8 +99,8 @@ def argument_parsing(parser):
|
||||
return params
|
||||
|
||||
|
||||
def get_datasets(dataset_list: List[AnyStr]):
|
||||
from rank_datasets import GPTJSynthetic, HFSummary, WebGPT
|
||||
def get_datasets(dataset_list: List[AnyStr], tokenizer):
|
||||
from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, WebGPT
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
train_datasets, evals = [], {}
|
||||
@@ -121,6 +121,11 @@ def get_datasets(dataset_list: List[AnyStr]):
|
||||
train, eval = train_val_dataset(dataset, 0.1)
|
||||
train_datasets.append(train)
|
||||
evals["gptsynthetic"] = eval
|
||||
elif "anthropic_rlhf" == dataset_name:
|
||||
train = AnthropicRLHF("train", tokenizer.sep_token)
|
||||
eval = AnthropicRLHF("test", tokenizer.sep_token)
|
||||
train_datasets.append(train)
|
||||
evals["anthropic_rlhf"] = eval
|
||||
train = ConcatDataset(train_datasets)
|
||||
return train, evals
|
||||
|
||||
|
||||
Reference in New Issue
Block a user