[feature] add initial version of anthropic dataset

This commit is contained in:
theblackcat102
2023-01-25 05:03:43 +00:00
parent 4146930eb9
commit 3215a7bbf8
4 changed files with 55 additions and 4 deletions
+36
View File
@@ -237,3 +237,39 @@ class HFDataset(Dataset):
class GPTJSynthetic(HFDataset):
def __init__(self) -> None:
super().__init__("Dahoas/synthetic-instruct-gptj-pairwise", "prompt", "chosen", "rejected", None, "train")
class AnthropicRLHF(Dataset):
"""
The data are described in the paper:
Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback.
If you find the data useful, please cite the paper.
The data format is very simple -- each line of the jsonl files contains a pair of texts,
one "chosen" and one "rejected".
"""
def __init__(self, split="train", sep_token="<sep>") -> None:
super().__init__()
assert split in ("train", "test")
if sep_token is None:
sep_token = " . "
self.pairs = []
# using prompt as our index will allows us
# to add additional generated prompt later
major_split = split if "train" == split else "test"
dataset = load_dataset("Anthropic/hh-rlhf")[major_split]
for data in dataset:
prompt, pos_postfix = data["chosen"].split("Assistant:", maxsplit=1)
pos_postfix = pos_postfix.replace("\n\nHuman: ", sep_token).replace("\n\nAssistant: ", sep_token)
_, neg_postfix = data["rejected"].split("Assistant:", maxsplit=1)
neg_postfix = neg_postfix.replace("\n\nHuman: ", sep_token).replace("\n\nAssistant: ", sep_token)
self.pairs.append((prompt, (pos_postfix.strip(), neg_postfix.strip())))
def __len__(self):
return len(self.pairs)
def __getitem__(self, index):
context, pair = self.pairs[index]
return context, [pair]
+11 -1
View File
@@ -1,5 +1,5 @@
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
from rank_datasets import DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT
from rank_datasets import AnthropicRLHF, DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
@@ -25,6 +25,16 @@ def test_webgpt():
print(batch["input_ids"].shape)
def test_anthropic_rlhf():
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
collate_fn = DataCollatorForPairRank(tokenizer, max_length=200)
dataset = AnthropicRLHF("test", sep_token=tokenizer.sep_token)
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
for batch in dataloader:
print(batch["input_ids"].shape)
def test_hf_summary_quality():
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
+1 -1
View File
@@ -155,7 +155,7 @@ if __name__ == "__main__":
)
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
train, evals = get_datasets(training_conf["datasets"])
train, evals = get_datasets(training_conf["datasets"], tokenizer)
if "rankgen" in model_name:
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
else:
+7 -2
View File
@@ -99,8 +99,8 @@ def argument_parsing(parser):
return params
def get_datasets(dataset_list: List[AnyStr]):
from rank_datasets import GPTJSynthetic, HFSummary, WebGPT
def get_datasets(dataset_list: List[AnyStr], tokenizer):
from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, WebGPT
from torch.utils.data import ConcatDataset
train_datasets, evals = [], {}
@@ -121,6 +121,11 @@ def get_datasets(dataset_list: List[AnyStr]):
train, eval = train_val_dataset(dataset, 0.1)
train_datasets.append(train)
evals["gptsynthetic"] = eval
elif "anthropic_rlhf" == dataset_name:
train = AnthropicRLHF("train", tokenizer.sep_token)
eval = AnthropicRLHF("test", tokenizer.sep_token)
train_datasets.append(train)
evals["anthropic_rlhf"] = eval
train = ConcatDataset(train_datasets)
return train, evals