[fix] Trim the Anthropic dataset down to the last two conversation turns only

This commit is contained in:
theblackcat102
2023-01-28 03:44:38 +00:00
parent f43435efc9
commit def03d75d2
+34 -9
View File
@@ -139,7 +139,7 @@ class HFSummary(Dataset):
"""
def __init__(self, split="train", conf_threshold=-1, max_comparison_per_sample=3) -> None:
def __init__(self, split="train", conf_threshold=-1, max_comparison_per_sample=1) -> None:
super().__init__()
assert split in ("train", "valid1", "valid2", "test")
summaries = {}
@@ -242,13 +242,30 @@ class GPTJSynthetic(HFDataset):
class AnthropicRLHF(Dataset):
"""
The data are described in the paper:
Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback.
If you find the data useful, please cite the paper.
The data format is very simple -- each line of the jsonl files contains a pair of texts,
one "chosen" and one "rejected".
Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback.
If you find the data useful, please cite the paper.
The data format is very simple -- each line of the jsonl files contains a pair of texts,
one "chosen" and one "rejected".
valid train size : 160780
"""
def preprocess_dialogue(self, text):
    """
    Trim a dialogue transcript down to its last two "Human:" turns.

    The transcript is split on the literal marker "Human:" and only the final
    two segments are kept; each segment still carries whatever followed that
    marker (for example the "Assistant:" reply). The kept segments are joined
    back together with "Human:".

    Args:
        text: raw dialogue string containing "Human:" turn markers.

    Returns:
        The trimmed dialogue. When the first kept segment is empty or
        whitespace-only, the transcript had at most one "Human:" turn and the
        join already restores its marker, so nothing is prepended. Otherwise
        the marker that the split consumed in front of the first kept segment
        is restored by prepending "Human: ".

    Outlier example where the Assistant answered with an empty string, so
    several "Human:" turns follow each other:
        Assistant: Human: That makes sense, I agree with that, though there are
        many situations that aren't considered justice, like sending a kid to
        prison for life. Human: You are completely missing the point of this
        conversation, and not understanding anything I am saying. Human: And I
        dont know if youre trying to be funny, but it isnt.
    """
    # str.split always yields at least one element, so indexing [0] is safe
    # even for an empty transcript.
    segments = text.split("Human:")[-2:]
    # A whitespace-only leading segment (e.g. the "\n\n" in front of the first
    # "Human:") must count as empty; checking only len() == 0 would emit a
    # duplicated marker such as "Human: \n\nHuman: ...".
    if not segments[0].strip():
        return "Human:".join(segments)
    return "Human: " + "Human:".join(segments)
def __init__(self, split="train", sep_token="<sep>") -> None:
super().__init__()
assert split in ("train", "test")
@@ -260,10 +277,18 @@ class AnthropicRLHF(Dataset):
major_split = split if "train" == split else "test"
dataset = load_dataset("Anthropic/hh-rlhf")[major_split]
for data in dataset:
prompt, pos_postfix = data["chosen"].split("Assistant:", maxsplit=1)
pos_postfix = pos_postfix.replace("\n\nHuman: ", sep_token).replace("\n\nAssistant: ", sep_token)
_, neg_postfix = data["rejected"].split("Assistant:", maxsplit=1)
neg_postfix = neg_postfix.replace("\n\nHuman: ", sep_token).replace("\n\nAssistant: ", sep_token)
processed = self.preprocess_dialogue(data["chosen"])
# roughly 20 of these are invalid conversation
if "Assistant" not in processed:
continue
prompt, pos_postfix = processed.split("Assistant:", maxsplit=1)
prompt = prompt.replace("Human: ", "").strip()
pos_postfix = pos_postfix.replace("Human: ", sep_token).replace("\n\nAssistant: ", sep_token).strip()
processed = self.preprocess_dialogue(data["rejected"])
if "Assistant" not in processed:
continue
_, neg_postfix = processed.split("Assistant:", maxsplit=1)
neg_postfix = neg_postfix.replace("Human: ", sep_token).replace("\n\nAssistant: ", sep_token).strip()
self.pairs.append((prompt, (pos_postfix.strip(), neg_postfix.strip())))
def __len__(self):