diff --git a/model/reward/instructor/rank_datasets.py b/model/reward/instructor/rank_datasets.py index 5e7da948..bd53eb02 100644 --- a/model/reward/instructor/rank_datasets.py +++ b/model/reward/instructor/rank_datasets.py @@ -19,7 +19,7 @@ """ from dataclasses import dataclass -from typing import Optional, Union +from typing import Dict, List, Optional, Union import numpy as np import torch @@ -35,7 +35,7 @@ class RankGenCollator: max_length: Optional[int] = None max_examples: Optional[int] = None - def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]: + def __call__(self, batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]: prefixes = [] better_answers = [] worse_answers = [] @@ -193,3 +193,47 @@ class HFSummary(Dataset): valid_idx = np.random.choice(len(rows), self.max_comparison_per_sample) # optimize the format later return context + self.postfix_prompt, [r for idx, r in enumerate(rows) if idx in valid_idx] + + +class HFDataset(Dataset): + """ + This is a base huggingface dataset which written to support the + simplest pos-neg pair format + + we should do something like this for supervised datasets + """ + + def __init__( + self, dataset_name, question_field, pos_answer_field, neg_answer_field, subset=None, split=None + ) -> None: + super().__init__() + dataset = load_dataset(dataset_name, subset) + if split is not None: + dataset = dataset[split] + + self.questions = {} + self.index2question = {} + for row in dataset: + question = row[question_field].strip() + pos = row[pos_answer_field] + neg = row[neg_answer_field] + if question not in self.index2question: + self.index2question[len(self.index2question)] = question + + if question not in self.questions: + self.questions[question] = [] + self.questions[question].append((pos.strip(), neg.strip())) + + def __len__(self): + return len(self.index2question) + + def __getitem__(self, index): + question = self.index2question[index] + rows = self.questions[question] + # optimize the format later + return question, rows + + +class GPTJSynthetic(HFDataset): + def __init__(self) -> None: + super().__init__("Dahoas/synthetic-instruct-gptj-pairwise", "prompt", "chosen", "rejected", None, "train") diff --git a/model/reward/instructor/tests/test_dataset.py b/model/reward/instructor/tests/test_dataset.py index 746a3c1e..832aace3 100644 --- a/model/reward/instructor/tests/test_dataset.py +++ b/model/reward/instructor/tests/test_dataset.py @@ -1,5 +1,5 @@ from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality -from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT +from rank_datasets import DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -25,7 +25,7 @@ def test_webgpt(): print(batch["input_ids"].shape) -def test_hf_quality(): +def test_hf_summary_quality(): tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large") collate_fn = DataCollatorForSummaryScore(tokenizer, max_length=200) @@ -35,6 +35,12 @@ def test_hf_quality(): print(batch["input_ids"].shape) -if __name__ == "__main__": - test_hf_quality() - # test_webgpt() +def test_gptj_dataset(): + dataset = GPTJSynthetic() + tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large") + collate_fn = DataCollatorForPairRank(tokenizer, max_length=1024) + + print(len(dataset)) + dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32) + for batch in dataloader: + batch["input_ids"].shape diff --git a/model/reward/instructor/trainer.py b/model/reward/instructor/trainer.py index 68a58a38..940c0708 100644 --- a/model/reward/instructor/trainer.py +++ b/model/reward/instructor/trainer.py @@ -1,29 +1,23 @@ import os from argparse import ArgumentParser -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import evaluate import numpy as np import torch from models import RankGenModel -from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT +from rank_datasets import DataCollatorForPairRank, RankGenCollator from torch import nn -from torch.utils.data import ConcatDataset, Dataset from transformers import ( AdamW, AutoModelForSequenceClassification, - DataCollator, - EvalPrediction, PreTrainedModel, - PreTrainedTokenizerBase, Trainer, - TrainerCallback, TrainingArguments, get_cosine_schedule_with_warmup, get_linear_schedule_with_warmup, ) -from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset +from utils import argument_parsing, freeze_top_n_layers, get_datasets, get_tokenizer os.environ["WANDB_PROJECT"] = "reward-model" @@ -32,11 +26,6 @@ parser = ArgumentParser() parser.add_argument("config", type=str) -@dataclass -class CustomTrainingArguments(TrainingArguments): - loss_function: str = "rank" - - def compute_metrics(eval_pred): predictions, _ = eval_pred predictions = np.argmax(predictions, axis=1) @@ -60,31 +49,12 @@ class RankTrainer(Trainer): model: Union[PreTrainedModel, nn.Module] = None, model_name: str = None, args: Optional[TrainingArguments] = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Dataset] = None, - tokenizer: Optional[PreTrainedTokenizerBase] = None, - model_init: Callable[[], PreTrainedModel] = None, - compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - callbacks: Optional[List[TrainerCallback]] = None, - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, + loss_function: str = "rank", + **kwargs, ): - super().__init__( - model, - args, - data_collator, - train_dataset, - eval_dataset, - tokenizer, - model_init, - compute_metrics, - callbacks, - optimizers, - preprocess_logits_for_metrics, - ) - self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss() - self.loss_function = args.loss_function + super().__init__(model, args, **kwargs) + self.loss_fct = RankLoss() if loss_function == "rank" else nn.CrossEntropyLoss() + self.loss_function = loss_function self.model_name = model_name def compute_loss(self, model, inputs, return_outputs=False): @@ -163,11 +133,10 @@ if __name__ == "__main__": params = sum([np.prod(p.size()) for p in model_parameters]) print("Number of trainable : {}M".format(int(params / 1e6))) - args = CustomTrainingArguments( + args = TrainingArguments( output_dir=f"{model_name}-finetuned", num_train_epochs=training_conf["num_train_epochs"], warmup_steps=500, - loss_function=training_conf["loss"], learning_rate=training_conf["learning_rate"], # half_precision_backend="apex", fp16=training_conf["fp16"], @@ -184,22 +153,9 @@ if __name__ == "__main__": save_steps=1000, report_to="wandb", ) - train_datasets, evals = [], {} - if "webgpt" in training_conf["datasets"]: - web_dataset = WebGPT() - train, eval = train_val_dataset(web_dataset) - train_datasets.append(train) - evals["webgpt"] = eval - if "hfsummary" in training_conf["datasets"]: - sum_train = HFSummary(split="train") - train_datasets.append(sum_train) - sum_eval = HFSummary(split="valid1") - assert len(sum_eval) > 0 - evals["hfsummary"] = sum_eval - train = ConcatDataset(train_datasets) tokenizer = get_tokenizer(training_conf["tokenizer_name"]) - + train, evals = get_datasets(training_conf["datasets"]) if "rankgen" in model_name: collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"]) else: @@ -224,8 +180,9 @@ if __name__ == "__main__": model=model, model_name=model_name, args=args, + loss_function=training_conf["loss"], train_dataset=train, - eval_dataset=eval, + eval_dataset=evals, data_collator=collate_fn, tokenizer=tokenizer, compute_metrics=compute_metrics, diff --git a/model/reward/instructor/utils.py b/model/reward/instructor/utils.py index fe52c2ef..a6f3da4e 100644 --- a/model/reward/instructor/utils.py +++ b/model/reward/instructor/utils.py @@ -1,11 +1,13 @@ import re +from typing import AnyStr, List import yaml from sklearn.model_selection import train_test_split from torch.utils.data import Subset from transformers import AutoTokenizer, T5Tokenizer -re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]") +# @agoryuno contributed this +re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]") def webgpt_return_format(row): @@ -97,6 +99,32 @@ def argument_parsing(parser): return params +def get_datasets(dataset_list: List[AnyStr]): + from rank_datasets import GPTJSynthetic, HFSummary, WebGPT + from torch.utils.data import ConcatDataset + + train_datasets, evals = [], {} + for dataset_name in dataset_list: + if "webgpt" == dataset_name: + web_dataset = WebGPT() + train, eval = train_val_dataset(web_dataset, 0.2) + train_datasets.append(train) + evals["webgpt"] = eval + elif "hfsummary" == dataset_name: + sum_train = HFSummary(split="train") + train_datasets.append(sum_train) + sum_eval = HFSummary(split="valid1") + assert len(sum_eval) > 0 + evals["hfsummary"] = sum_eval + elif "gptsynthetic" == dataset_name: + dataset = GPTJSynthetic() + train, eval = train_val_dataset(dataset, 0.1) + train_datasets.append(train) + evals["gptsynthetic"] = eval + train = ConcatDataset(train_datasets) + return train, evals + + if __name__ == "__main__": from transformers import AutoModelForSequenceClassification