mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
[feature] Add OA private RM dataset
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
|
||||
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
@@ -298,3 +299,44 @@ class AnthropicRLHF(Dataset):
|
||||
context, pair = self.pairs[index]
|
||||
|
||||
return context, [pair]
|
||||
|
||||
|
||||
class OAPrivate(Dataset):
|
||||
"""
|
||||
{
|
||||
"prompt": <prompt string>,
|
||||
"history": [("prompt1", "answer2"), ("prompt2", "answer2")],
|
||||
"pos": <pos answer string>,
|
||||
"neg_replies": [list of bad answers]
|
||||
}
|
||||
"""
|
||||
|
||||
split_name_mapping = {
|
||||
"train": "rm_train.jsonl",
|
||||
"test": "rm_test.jsonl",
|
||||
"val": "rm_val.jsonl",
|
||||
}
|
||||
|
||||
def __init__(self, split="train", sep_token="<sep>", data_path=".cache") -> None:
|
||||
super().__init__()
|
||||
import json
|
||||
|
||||
jsonl_file = os.path.join(data_path, self.split_name_mapping[split])
|
||||
self.pairs = []
|
||||
with open(jsonl_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
prefix = sep_token.join([sep_token.join(p) for p in data["history"][-2:]])
|
||||
prefix += sep_token + data["prompt"]
|
||||
pair = []
|
||||
for neg_text in data["neg_replies"]:
|
||||
pair.append((data["pos"], neg_text))
|
||||
self.pairs.append((prefix, pair))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
context, pair = self.pairs[index]
|
||||
|
||||
return context, pair
|
||||
|
||||
@@ -115,7 +115,7 @@ def argument_parsing(parser):
|
||||
|
||||
|
||||
def get_datasets(dataset_list: List[AnyStr], tokenizer):
|
||||
from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, WebGPT
|
||||
from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, OAPrivate, WebGPT
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
train_datasets, evals = [], {}
|
||||
@@ -141,5 +141,11 @@ def get_datasets(dataset_list: List[AnyStr], tokenizer):
|
||||
eval = AnthropicRLHF("test", tokenizer.sep_token)
|
||||
train_datasets.append(train)
|
||||
evals["anthropic_rlhf"] = eval
|
||||
elif "oa_private" == dataset_name:
|
||||
train = OAPrivate(split="train", sep_token=tokenizer.sep_token)
|
||||
eval = OAPrivate(split="val", sep_token=tokenizer.sep_token)
|
||||
train_datasets.append(train)
|
||||
evals["oa_private"] = eval
|
||||
|
||||
train = ConcatDataset(train_datasets)
|
||||
return train, evals
|
||||
|
||||
Reference in New Issue
Block a user