class AnthropicRLHF(Dataset):
    """Pairwise preference dataset built from ``Anthropic/hh-rlhf``.

    The data are described in the paper:
    Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback.
    If you find the data useful, please cite the paper.
    Each record holds a pair of texts, one "chosen" and one "rejected".

    Every transcript is cut at its first ``"Assistant:"`` marker. The text
    before the marker (taken from the chosen transcript) is kept verbatim as
    the context; the text after it is the candidate continuation, in which
    later ``"\\n\\nHuman: "`` / ``"\\n\\nAssistant: "`` turn markers are replaced
    by ``sep_token`` and surrounding whitespace is stripped.

    Args:
        split: Which split to load, ``"train"`` or ``"test"``.
        sep_token: String substituted for the later turn markers. ``None``
            (e.g. a tokenizer without a separator token) falls back to ``" . "``.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"test"``, or if a
            transcript contains no ``"Assistant:"`` marker.
    """

    _SPLITS = ("train", "test")
    _ASSISTANT_MARKER = "Assistant:"
    # Replaced in this order, matching the original chained ``str.replace`` calls.
    _TURN_MARKERS = ("\n\nHuman: ", "\n\nAssistant: ")

    def __init__(self, split="train", sep_token="") -> None:
        super().__init__()
        # Explicit raise instead of ``assert``: asserts vanish under ``python -O``.
        if split not in self._SPLITS:
            raise ValueError(f"split must be one of {self._SPLITS}, got {split!r}")
        if sep_token is None:
            sep_token = " . "
        self.pairs = []
        # using prompt as our index will allow us
        # to add additional generated prompt later
        dataset = load_dataset("Anthropic/hh-rlhf")[split]
        for data in dataset:
            prompt, pos_postfix = self._split_dialogue(data["chosen"], sep_token)
            # The context is taken from the chosen transcript only.
            _, neg_postfix = self._split_dialogue(data["rejected"], sep_token)
            self.pairs.append((prompt, (pos_postfix, neg_postfix)))

    @staticmethod
    def _split_dialogue(text, sep_token):
        """Split one transcript into ``(prompt, continuation)``.

        The prompt is everything before the first ``"Assistant:"`` and is
        returned untouched. The continuation has its remaining turn markers
        replaced by ``sep_token`` and is stripped.

        Raises:
            ValueError: If ``text`` has no ``"Assistant:"`` marker.
        """
        prompt, marker, reply = text.partition(AnthropicRLHF._ASSISTANT_MARKER)
        if not marker:
            raise ValueError(
                f"transcript has no {AnthropicRLHF._ASSISTANT_MARKER!r} marker: {text[:80]!r}"
            )
        for turn_marker in AnthropicRLHF._TURN_MARKERS:
            reply = reply.replace(turn_marker, sep_token)
        return prompt, reply.strip()

    def __len__(self):
        """Return the number of (context, (chosen, rejected)) entries."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Return ``(context, [(chosen, rejected)])`` for the entry at ``index``."""
        context, pair = self.pairs[index]

        return context, [pair]
b/model/reward/instructor/trainer.py index 940c0708..abb05a42 100644 --- a/model/reward/instructor/trainer.py +++ b/model/reward/instructor/trainer.py @@ -155,7 +155,7 @@ if __name__ == "__main__": ) tokenizer = get_tokenizer(training_conf["tokenizer_name"]) - train, evals = get_datasets(training_conf["datasets"]) + train, evals = get_datasets(training_conf["datasets"], tokenizer) if "rankgen" in model_name: collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"]) else: diff --git a/model/reward/instructor/utils.py b/model/reward/instructor/utils.py index a6f3da4e..e8c7160b 100644 --- a/model/reward/instructor/utils.py +++ b/model/reward/instructor/utils.py @@ -99,8 +99,8 @@ def argument_parsing(parser): return params -def get_datasets(dataset_list: List[AnyStr]): - from rank_datasets import GPTJSynthetic, HFSummary, WebGPT +def get_datasets(dataset_list: List[AnyStr], tokenizer): + from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, WebGPT from torch.utils.data import ConcatDataset train_datasets, evals = [], {} @@ -121,6 +121,11 @@ def get_datasets(dataset_list: List[AnyStr]): train, eval = train_val_dataset(dataset, 0.1) train_datasets.append(train) evals["gptsynthetic"] = eval + elif "anthropic_rlhf" == dataset_name: + train = AnthropicRLHF("train", tokenizer.sep_token) + eval = AnthropicRLHF("test", tokenizer.sep_token) + train_datasets.append(train) + evals["anthropic_rlhf"] = eval train = ConcatDataset(train_datasets) return train, evals