From def03d75d2444f2b74108923af0c9d91cc67d88f Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Sat, 28 Jan 2023 03:44:38 +0000 Subject: [PATCH] [fix] Trim anthropic dataset down to last 2 convo only --- model/reward/instructor/rank_datasets.py | 43 +++++++++++++++++++----- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/model/reward/instructor/rank_datasets.py b/model/reward/instructor/rank_datasets.py index 5100d455..330c0a9f 100644 --- a/model/reward/instructor/rank_datasets.py +++ b/model/reward/instructor/rank_datasets.py @@ -139,7 +139,7 @@ class HFSummary(Dataset): """ - def __init__(self, split="train", conf_threshold=-1, max_comparison_per_sample=3) -> None: + def __init__(self, split="train", conf_threshold=-1, max_comparison_per_sample=1) -> None: super().__init__() assert split in ("train", "valid1", "valid2", "test") summaries = {} @@ -242,13 +242,30 @@ class GPTJSynthetic(HFDataset): class AnthropicRLHF(Dataset): """ The data are described in the paper: - Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback. - If you find the data useful, please cite the paper. - The data format is very simple -- each line of the jsonl files contains a pair of texts, - one "chosen" and one "rejected". + Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback. + If you find the data useful, please cite the paper. + The data format is very simple -- each line of the jsonl files contains a pair of texts, + one "chosen" and one "rejected". + valid train size : 160780 """ + def preprocess_dialogue(self, text): + """ + trim prefix text to last two pairs + + Outlier example Assistant answered empty string: + Assistant: Human: That makes sense, I agree with that, though there are many situations that + aren't considered justice, like sending a kid to prison for life. Human: You are completely + missing the point of this conversation, and not understanding anything I am saying. 
Human: + And I don’t know if you’re trying to be funny, but it isn’t. + + """ + last_two_convo = text.split("Human:")[-2:] + if len(last_two_convo[0]) == 0: + return "Human:".join(last_two_convo) + return "Human: " + "Human:".join(last_two_convo) + def __init__(self, split="train", sep_token="") -> None: super().__init__() assert split in ("train", "test") @@ -260,10 +277,18 @@ class AnthropicRLHF(Dataset): major_split = split if "train" == split else "test" dataset = load_dataset("Anthropic/hh-rlhf")[major_split] for data in dataset: - prompt, pos_postfix = data["chosen"].split("Assistant:", maxsplit=1) - pos_postfix = pos_postfix.replace("\n\nHuman: ", sep_token).replace("\n\nAssistant: ", sep_token) - _, neg_postfix = data["rejected"].split("Assistant:", maxsplit=1) - neg_postfix = neg_postfix.replace("\n\nHuman: ", sep_token).replace("\n\nAssistant: ", sep_token) + processed = self.preprocess_dialogue(data["chosen"]) + # roughly 20 of these are invalid conversation + if "Assistant" not in processed: + continue + prompt, pos_postfix = processed.split("Assistant:", maxsplit=1) + prompt = prompt.replace("Human: ", "").strip() + pos_postfix = pos_postfix.replace("Human: ", sep_token).replace("\n\nAssistant: ", sep_token).strip() + processed = self.preprocess_dialogue(data["rejected"]) + if "Assistant" not in processed: + continue + _, neg_postfix = processed.split("Assistant:", maxsplit=1) + neg_postfix = neg_postfix.replace("Human: ", sep_token).replace("\n\nAssistant: ", sep_token).strip() self.pairs.append((prompt, (pos_postfix.strip(), neg_postfix.strip()))) def __len__(self):