mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-01 16:50:12 +08:00
Merge pull request #513 from LAION-AI/sft-hub-dataset
[feature] Add GPTJ synthetic dataset, fix reference removal regex for…
This commit is contained in:
@@ -19,7 +19,7 @@
|
||||
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -35,7 +35,7 @@ class RankGenCollator:
|
||||
max_length: Optional[int] = None
|
||||
max_examples: Optional[int] = None
|
||||
|
||||
def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
|
||||
def __call__(self, batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
|
||||
prefixes = []
|
||||
better_answers = []
|
||||
worse_answers = []
|
||||
@@ -193,3 +193,47 @@ class HFSummary(Dataset):
|
||||
valid_idx = np.random.choice(len(rows), self.max_comparison_per_sample)
|
||||
# optimize the format later
|
||||
return context + self.postfix_prompt, [r for idx, r in enumerate(rows) if idx in valid_idx]
|
||||
|
||||
|
||||
class HFDataset(Dataset):
    """
    Base Hugging Face dataset written to support the simplest
    positive/negative answer-pair format.

    Rows that share the same (whitespace-stripped) question are grouped, so
    each item is one unique question together with all of its
    ``(positive, negative)`` answer pairs.

    We should do something like this for supervised datasets.
    """

    def __init__(
        self, dataset_name, question_field, pos_answer_field, neg_answer_field, subset=None, split=None
    ) -> None:
        """
        Load the dataset and group its rows by question.

        Args:
            dataset_name: first argument passed to ``load_dataset``.
            question_field: row key holding the question text.
            pos_answer_field: row key holding the positive answer text.
            neg_answer_field: row key holding the negative answer text.
            subset: second argument passed to ``load_dataset``.
            split: when not ``None``, key used to select one split from the
                loaded dataset before iterating its rows.
        """
        super().__init__()
        dataset = load_dataset(dataset_name, subset)
        if split is not None:
            dataset = dataset[split]

        self._index_rows(dataset, question_field, pos_answer_field, neg_answer_field)

    def _index_rows(self, rows, question_field, pos_answer_field, neg_answer_field) -> None:
        """Group ``rows`` by stripped question, giving each unique question one index."""
        self.questions = {}
        self.index2question = {}
        for row in rows:
            question = row[question_field].strip()
            pos = row[pos_answer_field]
            neg = row[neg_answer_field]

            if question not in self.questions:
                # Register the index only the first time a question is seen.
                # ``index2question`` is keyed by integer position, so testing
                # the question string against it would never match and every
                # duplicate row would receive its own index.
                self.index2question[len(self.index2question)] = question
                self.questions[question] = []

            self.questions[question].append((pos.strip(), neg.strip()))

    def __len__(self):
        """Return the number of unique questions."""
        return len(self.index2question)

    def __getitem__(self, index):
        """Return ``(question, [(positive, negative), ...])`` for ``index``."""
        question = self.index2question[index]
        rows = self.questions[question]
        # optimize the format later
        return question, rows
|
||||
|
||||
|
||||
class GPTJSynthetic(HFDataset):
    """``train`` split of ``Dahoas/synthetic-instruct-gptj-pairwise`` as prompt/chosen/rejected pairs."""

    def __init__(self) -> None:
        # Same arguments as the positional form, spelled out by name so the
        # mapping from dataset columns to HFDataset fields is explicit.
        super().__init__(
            dataset_name="Dahoas/synthetic-instruct-gptj-pairwise",
            question_field="prompt",
            pos_answer_field="chosen",
            neg_answer_field="rejected",
            subset=None,
            split="train",
        )
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
|
||||
from rank_datasets import DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@@ -25,7 +25,7 @@ def test_webgpt():
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
def test_hf_quality():
|
||||
def test_hf_summary_quality():
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForSummaryScore(tokenizer, max_length=200)
|
||||
@@ -35,6 +35,12 @@ def test_hf_quality():
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_hf_quality()
|
||||
# test_webgpt()
|
||||
def test_gptj_dataset():
    """Run every GPTJSynthetic batch through DataCollatorForPairRank as a smoke test."""
    pair_dataset = GPTJSynthetic()
    collator = DataCollatorForPairRank(
        AutoTokenizer.from_pretrained("bigscience/mt0-large"),
        max_length=1024,
    )

    print(len(pair_dataset))
    loader = DataLoader(pair_dataset, collate_fn=collator, batch_size=32)
    for collated in loader:
        # Nothing is printed or asserted here; the lookup only fails loudly
        # if a collated batch has no "input_ids" entry.
        collated["input_ids"].shape
|
||||
|
||||
@@ -1,29 +1,23 @@
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from models import RankGenModel
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
|
||||
from rank_datasets import DataCollatorForPairRank, RankGenCollator
|
||||
from torch import nn
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
from transformers import (
|
||||
AdamW,
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_datasets, get_tokenizer
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "reward-model"
|
||||
|
||||
@@ -32,11 +26,6 @@ parser = ArgumentParser()
|
||||
parser.add_argument("config", type=str)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomTrainingArguments(TrainingArguments):
|
||||
loss_function: str = "rank"
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, _ = eval_pred
|
||||
predictions = np.argmax(predictions, axis=1)
|
||||
@@ -60,31 +49,12 @@ class RankTrainer(Trainer):
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
model_name: str = None,
|
||||
args: Optional[TrainingArguments] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
loss_function: str = "rank",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
tokenizer,
|
||||
model_init,
|
||||
compute_metrics,
|
||||
callbacks,
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
|
||||
self.loss_function = args.loss_function
|
||||
super().__init__(model, args, **kwargs)
|
||||
self.loss_fct = RankLoss() if loss_function == "rank" else nn.CrossEntropyLoss()
|
||||
self.loss_function = loss_function
|
||||
self.model_name = model_name
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
@@ -163,11 +133,10 @@ if __name__ == "__main__":
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
print("Number of trainable : {}M".format(int(params / 1e6)))
|
||||
|
||||
args = CustomTrainingArguments(
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=training_conf["num_train_epochs"],
|
||||
warmup_steps=500,
|
||||
loss_function=training_conf["loss"],
|
||||
learning_rate=training_conf["learning_rate"],
|
||||
# half_precision_backend="apex",
|
||||
fp16=training_conf["fp16"],
|
||||
@@ -184,22 +153,9 @@ if __name__ == "__main__":
|
||||
save_steps=1000,
|
||||
report_to="wandb",
|
||||
)
|
||||
train_datasets, evals = [], {}
|
||||
if "webgpt" in training_conf["datasets"]:
|
||||
web_dataset = WebGPT()
|
||||
train, eval = train_val_dataset(web_dataset)
|
||||
train_datasets.append(train)
|
||||
evals["webgpt"] = eval
|
||||
if "hfsummary" in training_conf["datasets"]:
|
||||
sum_train = HFSummary(split="train")
|
||||
train_datasets.append(sum_train)
|
||||
sum_eval = HFSummary(split="valid1")
|
||||
assert len(sum_eval) > 0
|
||||
evals["hfsummary"] = sum_eval
|
||||
train = ConcatDataset(train_datasets)
|
||||
|
||||
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
|
||||
|
||||
train, evals = get_datasets(training_conf["datasets"])
|
||||
if "rankgen" in model_name:
|
||||
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
|
||||
else:
|
||||
@@ -224,8 +180,9 @@ if __name__ == "__main__":
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
args=args,
|
||||
loss_function=training_conf["loss"],
|
||||
train_dataset=train,
|
||||
eval_dataset=eval,
|
||||
eval_dataset=evals,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import re
|
||||
from typing import AnyStr, List
|
||||
|
||||
import yaml
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
from transformers import AutoTokenizer, T5Tokenizer
|
||||
|
||||
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
|
||||
# @agoryuno contributed this
|
||||
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
|
||||
|
||||
|
||||
def webgpt_return_format(row):
|
||||
@@ -97,6 +99,32 @@ def argument_parsing(parser):
|
||||
return params
|
||||
|
||||
|
||||
def get_datasets(dataset_list: List[str]):
    """
    Build the combined training set and per-dataset evaluation sets.

    Args:
        dataset_list: dataset names to load; each must be one of
            ``"webgpt"``, ``"hfsummary"`` or ``"gptsynthetic"``.

    Returns:
        A ``(train, evals)`` tuple where ``train`` is a ``ConcatDataset`` of
        all selected training sets and ``evals`` maps each dataset name to
        its evaluation set.

    Raises:
        ValueError: if ``dataset_list`` contains an unsupported name.
    """
    known_names = ("webgpt", "hfsummary", "gptsynthetic")
    unknown = [name for name in dataset_list if name not in known_names]
    if unknown:
        # Fail fast, before any dataset is loaded, instead of silently
        # training on fewer datasets than the configuration asked for.
        raise ValueError(f"Unknown dataset name(s) {unknown}; expected any of {list(known_names)}")

    # Imported inside the function, as in the original code
    # (presumably to avoid a circular import with rank_datasets - TODO confirm).
    from rank_datasets import GPTJSynthetic, HFSummary, WebGPT
    from torch.utils.data import ConcatDataset

    train_datasets, evals = [], {}
    for dataset_name in dataset_list:
        if dataset_name == "webgpt":
            # The second argument is presumably the validation fraction - verify
            # against train_val_dataset.
            train_split, eval_split = train_val_dataset(WebGPT(), 0.2)
            train_datasets.append(train_split)
            evals["webgpt"] = eval_split
        elif dataset_name == "hfsummary":
            # HFSummary is loaded once per split rather than split locally.
            train_datasets.append(HFSummary(split="train"))
            sum_eval = HFSummary(split="valid1")
            assert len(sum_eval) > 0
            evals["hfsummary"] = sum_eval
        elif dataset_name == "gptsynthetic":
            train_split, eval_split = train_val_dataset(GPTJSynthetic(), 0.1)
            train_datasets.append(train_split)
            evals["gptsynthetic"] = eval_split

    train = ConcatDataset(train_datasets)
    return train, evals
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
|
||||
Reference in New Issue
Block a user