mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
65 lines
2.4 KiB
Python
65 lines
2.4 KiB
Python
"""
Classification-based ranking datasets.
"""
|
|
import json
|
|
import os
|
|
import random
|
|
|
|
from datasets import load_dataset
|
|
from torch.utils.data import Dataset
|
|
|
|
from .utils import webgpt_return_format
|
|
|
|
|
|
class WebGPTDataset(Dataset):
    """Ranking dataset built from the train split of ``openai/webgpt_comparisons``.

    Each item is ``(question, pos, neg)`` taken from the row produced by
    ``webgpt_return_format``. When an ``additional_dataset`` file is supplied,
    each item also carries one randomly chosen model-generated negative:
    ``(question, pos, neg, gen_neg)``.
    """

    # When a comparison index has no generated negatives, the search for an
    # entry to borrow starts this many indices earlier and walks backwards.
    # NOTE(review): the origin of 900 is undocumented — confirm against the
    # script that produced the additional dataset.
    _FALLBACK_OFFSET = 900

    def __init__(self, mode="train", index_cache="dataset/webgpt_train_idx.pt", additional_dataset=None) -> None:
        """Load the comparisons and, optionally, extra generated negatives.

        Args:
            mode: "train" or "val", intended for a validation split that has
                nothing to do with the original dataset split. Currently not
                read by this class.
            index_cache: path of an index cache file. Currently not read by
                this class; only the ``dataset`` directory is created
                (presumably for this cache — confirm).
            additional_dataset: optional path to a JSON-lines file. Each line
                is an object with:
                    idx: must match the position of the comparison in the
                        ``enumerate`` order of the train split;
                    question: for validation purposes (not read here);
                    negatives: list of K generations for the question prompt.
                (Earlier documentation called this field ``texts``; the code
                reads ``negatives``.)

        Raises:
            ValueError: if generated negatives are missing for some comparison
                index and no loaded entry exists to borrow from.
        """
        super().__init__()
        os.makedirs("dataset", exist_ok=True)
        dataset = load_dataset("openai/webgpt_comparisons")

        self.dataset = []
        # Position in self.dataset -> comparison index in the train split's
        # enumerate order; __getitem__ uses it to look up generated negatives.
        self.dataset_index = []
        for idx, row in enumerate(dataset["train"]):
            self.dataset.append(webgpt_return_format(row))
            self.dataset_index.append(idx)

        # The comparisons were produced with a 176B GPT-3 model, so extra
        # samples generated by the starting model are needed: the reward model
        # must learn to rank GPT-3 answers above those of the starting model.
        self.sample_additional = additional_dataset is not None
        if self.sample_additional:
            self.additional = self._load_additional(additional_dataset)

    def _load_additional(self, path):
        """Return ``{comparison index: generated negatives}`` read from a JSON-lines file.

        Indices without an entry borrow the list of the closest index at or
        below ``index - _FALLBACK_OFFSET`` that has one (including indices
        already filled earlier in this pass).

        Raises:
            ValueError: if no entry can be found to borrow for some index.
        """
        wanted = set(self.dataset_index)  # O(1) membership per line
        additional = {}
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline
                row = json.loads(line)
                if row["idx"] in wanted:
                    additional[row["idx"]] = row["negatives"]

        missing = [i for i in self.dataset_index if i not in additional]
        if not missing:
            return additional
        if not additional:
            raise ValueError(f"{path} contains generated negatives for none of the comparison indices")

        # Keys added while filling are always above an existing key, so
        # `lowest` stays the minimum and bounds the backwards search.
        lowest = min(additional)
        for match_idx in missing:
            idx = match_idx - self._FALLBACK_OFFSET
            while idx >= lowest and idx not in additional:
                idx -= 1
            if idx < lowest:
                raise ValueError(
                    f"no generated negatives to borrow for comparison index {match_idx} "
                    f"(searched from {match_idx - self._FALLBACK_OFFSET} down to {lowest})"
                )
            additional[match_idx] = additional[idx]
        return additional

    def __len__(self):
        """Return the number of comparisons."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(question, pos, neg)``, plus one random generated negative when enabled."""
        row = self.dataset[index]
        if not self.sample_additional:
            return row["question"], row["pos"], row["neg"]

        # random.choice raises IndexError if the stored negatives list is empty.
        gen_neg = random.choice(self.additional[self.dataset_index[index]])
        return row["question"], row["pos"], row["neg"], gen_neg
|