class OAPrivate(Dataset):
    """Ranking dataset read from a local JSON Lines dump (``rm_<split>.jsonl``).

    Every non-blank line of the file is one JSON object with these keys:

    * ``"prompt"``: the prompt string being answered.
    * ``"history"``: list of earlier turns, each a sequence of strings,
      e.g. ``[["prompt1", "answer1"], ["prompt2", "answer2"]]``.
    * ``"pos"``: the preferred reply.
    * ``"neg_replies"``: list of rejected replies.

    Each record becomes one item ``(context, pairs)`` where ``context`` is the
    last two history turns followed by the prompt, all joined with
    ``sep_token``, and ``pairs`` is a list of ``(pos, neg)`` tuples, one per
    rejected reply.
    """

    # Maps the public split name to the file expected inside ``data_path``.
    split_name_mapping = {
        "train": "rm_train.jsonl",
        "test": "rm_test.jsonl",
        "val": "rm_val.jsonl",
    }

    def __init__(self, split="train", sep_token="", data_path=".cache") -> None:
        """Load every record of ``split`` into memory.

        Args:
            split: One of the keys of ``split_name_mapping``.
            sep_token: String placed between the context segments.
            data_path: Directory that contains the ``rm_*.jsonl`` files.

        Raises:
            KeyError: If ``split`` is not a known split name.
            OSError: If the split file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        super().__init__()
        import json

        jsonl_file = os.path.join(data_path, self.split_name_mapping[split])
        self.pairs = []
        with open(jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at end of
                # file) instead of crashing inside json.loads.
                if not line.strip():
                    continue
                data = json.loads(line)

                # Only the two most recent turns are kept as context.
                segments = []
                for turn in data["history"][-2:]:
                    segments.extend(turn)
                segments.append(data["prompt"])
                # Joining all segments at once keeps separators strictly
                # *between* segments, so a record with empty history yields
                # just the prompt rather than a prompt with a leading
                # separator.
                prefix = sep_token.join(segments)

                pair = [(data["pos"], neg_text) for neg_text in data["neg_replies"]]
                self.pairs.append((prefix, pair))

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Return ``(context, pairs)`` for the record at ``index``."""
        context, pair = self.pairs[index]
        return context, pair
argument_parsing(parser): def get_datasets(dataset_list: List[AnyStr], tokenizer): - from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, WebGPT + from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, OAPrivate, WebGPT from torch.utils.data import ConcatDataset train_datasets, evals = [], {} @@ -141,5 +141,11 @@ def get_datasets(dataset_list: List[AnyStr], tokenizer): eval = AnthropicRLHF("test", tokenizer.sep_token) train_datasets.append(train) evals["anthropic_rlhf"] = eval + elif "oa_private" == dataset_name: + train = OAPrivate(split="train", sep_token=tokenizer.sep_token) + eval = OAPrivate(split="val", sep_token=tokenizer.sep_token) + train_datasets.append(train) + evals["oa_private"] = eval + train = ConcatDataset(train_datasets) return train, evals