diff --git a/model/reward/instructor/models.py b/model/reward/instructor/models.py index dc7692bf..084cfa51 100644 --- a/model/reward/instructor/models.py +++ b/model/reward/instructor/models.py @@ -1,24 +1,28 @@ +# -*- coding: utf-8 -*- import torch from transformers import AutoModel -class RankGenModel(torch.nn.Module): - def __init__(self, model_name): - super().__init__() - self.rankgen_hf_hub = model_name - assert model_name in ["kalpeshk2011/rankgen-t5-xl-all", - "kalpeshk2011/rankgen-t5-xl-pg19", - "kalpeshk2011/rankgen-t5-base-all", - "kalpeshk2011/rankgen-t5-large-all"] - self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True) - def forward(self, prefixes, suffixes): - # print(list(self.model.parameters())) - # raise Exception("stop") - embedded_prefixes = self.model(**prefixes) - embedded_suffixes = self.model(**suffixes) - # take dot product of each row independently - dot_products = torch.sum(embedded_prefixes * embedded_suffixes, dim=1) - - # print(f"{embedded_prefixes.shape=}, {embedded_suffixes.shape=}, {prefixes['input_ids'].shape=}, {suffixes['input_ids'].shape=}, {embedded_prefixes=}, {embedded_suffixes=}, {dot_products=}") - # raise Exception("stop") - return dot_products \ No newline at end of file +class RankGenModel(torch.nn.Module): + def __init__(self, model_name): + super().__init__() + self.rankgen_hf_hub = model_name + assert model_name in [ + "kalpeshk2011/rankgen-t5-xl-all", + "kalpeshk2011/rankgen-t5-xl-pg19", + "kalpeshk2011/rankgen-t5-base-all", + "kalpeshk2011/rankgen-t5-large-all", + ] + self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True) + + def forward(self, prefixes, suffixes): + # print(list(self.model.parameters())) + # raise Exception("stop") + embedded_prefixes = self.model(**prefixes) + embedded_suffixes = self.model(**suffixes) + # take dot product of each row independently + dot_products = torch.sum(embedded_prefixes * embedded_suffixes, dim=1) + + # print(f"{embedded_prefixes.shape=}, {embedded_suffixes.shape=}, {prefixes['input_ids'].shape=}, {suffixes['input_ids'].shape=}, {embedded_prefixes=}, {embedded_suffixes=}, {dot_products=}") + # raise Exception("stop") + return dot_products diff --git a/model/reward/instructor/rank_datasets.py b/model/reward/instructor/rank_datasets.py index 965893ce..a63c9e02 100644 --- a/model/reward/instructor/rank_datasets.py +++ b/model/reward/instructor/rank_datasets.py @@ -23,19 +23,20 @@ from dataclasses import dataclass from typing import Optional, Union import numpy as np -from datasets import load_dataset import torch +from datasets import load_dataset from torch.utils.data import Dataset from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase + @dataclass -class RankGenCollator(): +class RankGenCollator: tokenizer: PreTrainedTokenizerBase padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None max_examples: Optional[int] = None - def __call__(self, batch : list[dict[str, str]]) -> dict[str, torch.Tensor]: + def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]: prefixes = [] better_answers = [] worse_answers = [] @@ -44,13 +45,18 @@ class RankGenCollator(): prefixes.append("pre " + question) better_answers.append("suffi " + pos) worse_answers.append("suffi " + neg) - - tokenized_prefixes = self.tokenizer(prefixes, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True) - tokenized_pos = self.tokenizer(better_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True) - tokenized_neg = self.tokenizer(worse_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True) - return {"prefix" : tokenized_prefixes, - "positive": tokenized_pos, - "negative": tokenized_neg} + + tokenized_prefixes = self.tokenizer( + prefixes, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True + ) + tokenized_pos = self.tokenizer( + better_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True + ) + tokenized_neg = self.tokenizer( + worse_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True + ) + return {"prefix": tokenized_prefixes, "positive": tokenized_pos, "negative": tokenized_neg} + @dataclass class DataCollatorForPairRank: diff --git a/model/reward/instructor/requirements.txt b/model/reward/instructor/requirements.txt index eaaf36e6..ca3935e4 100644 --- a/model/reward/instructor/requirements.txt +++ b/model/reward/instructor/requirements.txt @@ -1,7 +1,7 @@ datasets==2.8.0 evaluate==0.4.0 scikit-learn==1.2.0 +sentencepiece==0.1.97 torch>=1.12.1 transformers==4.25.1 wandb==0.13.7 -sentencepiece==0.1.97 diff --git a/model/reward/instructor/trainer.py b/model/reward/instructor/trainer.py index c6f58f66..124c28f8 100644 --- a/model/reward/instructor/trainer.py +++ b/model/reward/instructor/trainer.py @@ -7,11 +7,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import evaluate import numpy as np import torch +from models import RankGenModel from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT from torch import nn from torch.utils.data import ConcatDataset, Dataset from transformers import ( - AutoModel, AutoModelForSequenceClassification, DataCollator, EvalPrediction, @@ -21,7 +21,6 @@ from transformers import ( TrainerCallback, TrainingArguments, ) -from models import RankGenModel from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset os.environ["WANDB_PROJECT"] = "reward-model" @@ -95,7 +94,7 @@ class RankTrainer(Trainer): loss = self.loss_fct(positive_outputs, negative_outputs) else: raise NotImplementedError("Only ranking loss has been implemented for rankgen model") - outputs = torch.hstack((positive_outputs, negative_outputs)) #logits + outputs = torch.hstack((positive_outputs, negative_outputs)) # logits else: outputs = model(**inputs) logits = outputs.get("logits").view(-1, 2) @@ -133,7 +132,7 @@ class RankTrainer(Trainer): loss = self.loss_fct(positive_outputs, negative_outputs) else: raise NotImplementedError("Only ranking loss has been implemented for rankgen model") - outputs = torch.hstack((positive_outputs, negative_outputs)) # logits + outputs = torch.hstack((positive_outputs, negative_outputs)) # logits return (loss, outputs, None) else: # compute loss on predict data @@ -161,7 +160,7 @@ if __name__ == "__main__": model_parameters = filter(lambda p: p.requires_grad, model.parameters()) params = sum([np.prod(p.size()) for p in model_parameters]) print("Number of trainable : {}M".format(int(params / 1e6))) - + args = CustomTrainingArguments( output_dir=f"{model_name}-finetuned", num_train_epochs=training_conf["num_train_epochs"], @@ -196,20 +195,16 @@ if __name__ == "__main__": assert len(sum_eval) > 0 evals["hfsummary"] = sum_eval train = ConcatDataset(train_datasets) - + if "tokenizer_name" in training_conf: - tokenizer=get_tokenizer(training_conf["tokenizer_name"]) + tokenizer = get_tokenizer(training_conf["tokenizer_name"]) else: tokenizer = get_tokenizer(model_name) - + if "rankgen" in model_name: - collate_fn = RankGenCollator( - tokenizer, max_length=training_conf["max_length"] - ) + collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"]) else: - collate_fn = DataCollatorForPairRank( - tokenizer, max_length=training_conf["max_length"] - ) + collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf["max_length"]) assert len(evals) > 0 trainer = RankTrainer( model=model, diff --git a/model/reward/instructor/utils.py b/model/reward/instructor/utils.py index 59165598..780ac9c8 100644 --- a/model/reward/instructor/utils.py +++ b/model/reward/instructor/utils.py @@ -26,7 +26,7 @@ def webgpt_return_format(row): def get_tokenizer(tokenizer_name): - if "t5" in tokenizer_name: #rankgen + if "t5" in tokenizer_name: # rankgen tokenizer = T5Tokenizer.from_pretrained(tokenizer_name, truncation_side="left") else: tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)