Merge pull request #513 from LAION-AI/sft-hub-dataset

[feature] Add GPTJ synthetic dataset, fix reference removal regex for…
This commit is contained in:
sanagnos
2023-01-08 10:47:48 +02:00
committed by GitHub
4 changed files with 98 additions and 63 deletions
+46 -2
View File
@@ -19,7 +19,7 @@
"""
from dataclasses import dataclass
from typing import Optional, Union
from typing import Dict, List, Optional, Union
import numpy as np
import torch
@@ -35,7 +35,7 @@ class RankGenCollator:
max_length: Optional[int] = None
max_examples: Optional[int] = None
def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
def __call__(self, batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
prefixes = []
better_answers = []
worse_answers = []
@@ -193,3 +193,47 @@ class HFSummary(Dataset):
valid_idx = np.random.choice(len(rows), self.max_comparison_per_sample)
# optimize the format later
return context + self.postfix_prompt, [r for idx, r in enumerate(rows) if idx in valid_idx]
class HFDataset(Dataset):
    """
    Base Hugging Face dataset supporting the simplest pos-neg pair format.

    Rows are grouped by their stripped question text: each item is one unique
    question together with every (positive, negative) answer pair seen for it.
    We should do something like this for supervised datasets.
    """

    def __init__(
        self, dataset_name, question_field, pos_answer_field, neg_answer_field, subset=None, split=None
    ) -> None:
        """
        Load the dataset and index its rows by question.

        Args:
            dataset_name: name passed to ``load_dataset``.
            question_field: row key holding the question text.
            pos_answer_field: row key holding the preferred answer.
            neg_answer_field: row key holding the rejected answer.
            subset: second positional argument passed to ``load_dataset``.
            split: when not None, the loaded object is indexed with this key
                before iterating; when None it is iterated as returned.
        """
        super().__init__()
        dataset = load_dataset(dataset_name, subset)
        if split is not None:
            dataset = dataset[split]

        self.questions, self.index2question = self._group_rows(
            dataset, question_field, pos_answer_field, neg_answer_field
        )

    @staticmethod
    def _group_rows(rows, question_field, pos_answer_field, neg_answer_field):
        """Group rows by stripped question; return ``(questions, index2question)``."""
        questions = {}
        index2question = {}
        for row in rows:
            question = row[question_field].strip()
            pos = row[pos_answer_field]
            neg = row[neg_answer_field]
            # index2question is keyed by integer position, so first-seen
            # detection must look at `questions` (keyed by question text).
            if question not in questions:
                index2question[len(index2question)] = question
                questions[question] = []
            questions[question].append((pos.strip(), neg.strip()))
        return questions, index2question

    def __len__(self):
        """Return the number of unique questions."""
        return len(self.index2question)

    def __getitem__(self, index):
        """Return ``(question, [(pos, neg), ...])`` for the question at ``index``."""
        question = self.index2question[index]
        rows = self.questions[question]
        # optimize the format later
        return question, rows
class GPTJSynthetic(HFDataset):
    """Pairwise dataset backed by ``Dahoas/synthetic-instruct-gptj-pairwise`` (``train`` split)."""

    def __init__(self) -> None:
        # Field names follow the columns this class reads from the hub dataset.
        super().__init__(
            dataset_name="Dahoas/synthetic-instruct-gptj-pairwise",
            question_field="prompt",
            pos_answer_field="chosen",
            neg_answer_field="rejected",
            subset=None,
            split="train",
        )
+11 -5
View File
@@ -1,5 +1,5 @@
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
from rank_datasets import DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
@@ -25,7 +25,7 @@ def test_webgpt():
print(batch["input_ids"].shape)
def test_hf_quality():
def test_hf_summary_quality():
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
collate_fn = DataCollatorForSummaryScore(tokenizer, max_length=200)
@@ -35,6 +35,12 @@ def test_hf_quality():
print(batch["input_ids"].shape)
if __name__ == "__main__":
test_hf_quality()
# test_webgpt()
def test_gptj_dataset():
    """Smoke test: collate every batch of GPTJSynthetic with the pair-rank collator."""
    gptj_data = GPTJSynthetic()
    tok = AutoTokenizer.from_pretrained("bigscience/mt0-large")
    pair_collator = DataCollatorForPairRank(tok, max_length=1024)
    print(len(gptj_data))
    for batch in DataLoader(gptj_data, collate_fn=pair_collator, batch_size=32):
        # Touching the shape only checks that the collated batch has input_ids.
        _ = batch["input_ids"].shape
+12 -55
View File
@@ -1,29 +1,23 @@
import os
from argparse import ArgumentParser
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import evaluate
import numpy as np
import torch
from models import RankGenModel
from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
from rank_datasets import DataCollatorForPairRank, RankGenCollator
from torch import nn
from torch.utils.data import ConcatDataset, Dataset
from transformers import (
AdamW,
AutoModelForSequenceClassification,
DataCollator,
EvalPrediction,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
TrainingArguments,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
from utils import argument_parsing, freeze_top_n_layers, get_datasets, get_tokenizer
os.environ["WANDB_PROJECT"] = "reward-model"
@@ -32,11 +26,6 @@ parser = ArgumentParser()
parser.add_argument("config", type=str)
@dataclass
class CustomTrainingArguments(TrainingArguments):
loss_function: str = "rank"
def compute_metrics(eval_pred):
predictions, _ = eval_pred
predictions = np.argmax(predictions, axis=1)
@@ -60,31 +49,12 @@ class RankTrainer(Trainer):
model: Union[PreTrainedModel, nn.Module] = None,
model_name: str = None,
args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Dataset] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Callable[[], PreTrainedModel] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
loss_function: str = "rank",
**kwargs,
):
super().__init__(
model,
args,
data_collator,
train_dataset,
eval_dataset,
tokenizer,
model_init,
compute_metrics,
callbacks,
optimizers,
preprocess_logits_for_metrics,
)
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
self.loss_function = args.loss_function
super().__init__(model, args, **kwargs)
self.loss_fct = RankLoss() if loss_function == "rank" else nn.CrossEntropyLoss()
self.loss_function = loss_function
self.model_name = model_name
def compute_loss(self, model, inputs, return_outputs=False):
@@ -163,11 +133,10 @@ if __name__ == "__main__":
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of trainable : {}M".format(int(params / 1e6)))
args = CustomTrainingArguments(
args = TrainingArguments(
output_dir=f"{model_name}-finetuned",
num_train_epochs=training_conf["num_train_epochs"],
warmup_steps=500,
loss_function=training_conf["loss"],
learning_rate=training_conf["learning_rate"],
# half_precision_backend="apex",
fp16=training_conf["fp16"],
@@ -184,22 +153,9 @@ if __name__ == "__main__":
save_steps=1000,
report_to="wandb",
)
train_datasets, evals = [], {}
if "webgpt" in training_conf["datasets"]:
web_dataset = WebGPT()
train, eval = train_val_dataset(web_dataset)
train_datasets.append(train)
evals["webgpt"] = eval
if "hfsummary" in training_conf["datasets"]:
sum_train = HFSummary(split="train")
train_datasets.append(sum_train)
sum_eval = HFSummary(split="valid1")
assert len(sum_eval) > 0
evals["hfsummary"] = sum_eval
train = ConcatDataset(train_datasets)
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
train, evals = get_datasets(training_conf["datasets"])
if "rankgen" in model_name:
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
else:
@@ -224,8 +180,9 @@ if __name__ == "__main__":
model=model,
model_name=model_name,
args=args,
loss_function=training_conf["loss"],
train_dataset=train,
eval_dataset=eval,
eval_dataset=evals,
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
+29 -1
View File
@@ -1,11 +1,13 @@
import re
from typing import AnyStr, List
import yaml
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from transformers import AutoTokenizer, T5Tokenizer
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
# @agoryuno contributed this
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
def webgpt_return_format(row):
@@ -97,6 +99,32 @@ def argument_parsing(parser):
return params
def get_datasets(dataset_list: List[AnyStr]):
    """
    Build the training dataset and per-dataset evaluation sets.

    Recognised names are ``"webgpt"``, ``"hfsummary"`` and ``"gptsynthetic"``.
    Any other name is skipped with a warning instead of being dropped silently.

    Args:
        dataset_list: names of the datasets to include.

    Returns:
        A ``(train, evals)`` tuple: ``train`` is a ``ConcatDataset`` over the
        selected training sets, ``evals`` maps each dataset name to its
        evaluation set.
    """
    import warnings

    from rank_datasets import GPTJSynthetic, HFSummary, WebGPT
    from torch.utils.data import ConcatDataset

    train_datasets, evals = [], {}
    for dataset_name in dataset_list:
        if dataset_name == "webgpt":
            # NOTE(review): second argument is presumably the validation fraction — confirm.
            train_split, eval_split = train_val_dataset(WebGPT(), 0.2)
            train_datasets.append(train_split)
            evals["webgpt"] = eval_split
        elif dataset_name == "hfsummary":
            # HFSummary is loaded by named split rather than via train_val_dataset.
            train_datasets.append(HFSummary(split="train"))
            sum_eval = HFSummary(split="valid1")
            assert len(sum_eval) > 0
            evals["hfsummary"] = sum_eval
        elif dataset_name == "gptsynthetic":
            train_split, eval_split = train_val_dataset(GPTJSynthetic(), 0.1)
            train_datasets.append(train_split)
            evals["gptsynthetic"] = eval_split
        else:
            # A typo in the config should be visible, not silently train on less data.
            warnings.warn(f"Unknown dataset name {dataset_name!r}; skipping it.")

    # NOTE(review): ConcatDataset presumably rejects an empty list — confirm
    # what should happen when no recognised dataset was requested.
    train = ConcatDataset(train_datasets)
    return train, evals
if __name__ == "__main__":
from transformers import AutoModelForSequenceClassification