Files
2023-01-02 00:01:45 +00:00

65 lines
2.4 KiB
Python

"""
classification based ranking
"""
import json
import os
import random
from datasets import load_dataset
from torch.utils.data import Dataset
from .utils import webgpt_return_format
class WebGPTDataset(Dataset):
    """Pairwise ranking dataset built from ``openai/webgpt_comparisons``.

    Every row of the ``train`` split is converted with
    ``webgpt_return_format`` and served as ``(question, pos, neg)``. When
    ``additional_dataset`` is given, each item carries a fourth element: one
    extra negative drawn at random from the candidates stored for that row.

    Args:
        mode: "train" or "val"; described upstream as being for validation
            purposes and unrelated to the original split. Accepted for
            interface compatibility -- this implementation does not read it.
        index_cache: path of an index cache file. Accepted for interface
            compatibility -- this implementation does not read it.
        additional_dataset: optional path to a JSON-lines file. Each line is
            an object with
              * ``idx``: position of the row in the enumeration order of the
                ``train`` split,
              * ``negatives``: list of generated candidate answers that
                should rank below the dataset's answers.
            Lines whose ``idx`` is not a known row are ignored.

    Raises:
        ValueError: if ``additional_dataset`` is given but a row without its
            own negatives has no earlier row to borrow them from.
    """

    # How far back (in row positions) the search for a donor row starts when a
    # row has no negatives of its own.
    _FALLBACK_OFFSET = 900

    def __init__(
        self,
        mode="train",
        index_cache="dataset/webgpt_train_idx.pt",
        additional_dataset=None,
    ) -> None:
        super().__init__()
        os.makedirs("dataset", exist_ok=True)
        dataset = load_dataset("openai/webgpt_comparisons")

        self.dataset = []
        self.dataset_index = []
        for idx, row in enumerate(dataset["train"]):
            self.dataset.append(webgpt_return_format(row))
            # Remember the enumeration index: it is the key that ties a row
            # to its entry in the additional-negatives file.
            self.dataset_index.append(idx)

        # The comparisons were generated by a much larger GPT-3 model, so
        # extra samples generated by the starting model are mixed in: the
        # ranker must learn to score the GPT-3 answers above those.
        self.sample_additional = additional_dataset is not None
        if self.sample_additional:
            self.additional = self._load_additional(additional_dataset)

    def _load_additional(self, path):
        """Read extra negatives from ``path`` and return them keyed by row index.

        Rows of ``self.dataset_index`` that have no entry in the file borrow
        the negatives of another row: the search starts ``_FALLBACK_OFFSET``
        positions earlier and walks downwards until a row with negatives is
        found.
        """
        known = set(self.dataset_index)  # O(1) membership instead of a list scan
        additional = {}
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                record = json.loads(line)
                if record["idx"] in known:
                    additional[record["idx"]] = record["negatives"]

        for match_idx in self.dataset_index:
            if match_idx in additional:
                continue
            donor = match_idx - self._FALLBACK_OFFSET
            # Row indices are enumeration positions, so nothing exists below 0;
            # stopping there prevents an endless downward search.
            while donor >= 0 and donor not in additional:
                donor -= 1
            if donor < 0:
                raise ValueError(
                    f"no additional negatives available for row {match_idx}: "
                    f"no entry found at or below index "
                    f"{match_idx - self._FALLBACK_OFFSET} in {path!r}"
                )
            additional[match_idx] = additional[donor]
        return additional

    def __len__(self):
        """Return the number of comparison rows."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(question, pos, neg)``, plus a random extra negative if enabled."""
        row = self.dataset[index]
        if not self.sample_additional:
            return row["question"], row["pos"], row["neg"]
        gen_neg = random.choice(self.additional[self.dataset_index[index]])
        return row["question"], row["pos"], row["neg"], gen_neg