mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-30 16:40:05 +08:00
testing rankgen integration into instructor trainer
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
model_name: kalpeshk2011/rankgen-t5-base-all
|
||||
tokenizer_name: google/t5-v1_1-base
|
||||
learning_rate: 6e-6
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 3
|
||||
warmup_steps: 600
|
||||
freeze_layer: 20
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 400
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
from transformers import AutoModel
|
||||
|
||||
class RankGenModel(torch.nn.Module):
|
||||
def __init__(self, model_name):
|
||||
super().__init__()
|
||||
self.rankgen_hf_hub = model_name
|
||||
assert model_name in ["kalpeshk2011/rankgen-t5-xl-all",
|
||||
"kalpeshk2011/rankgen-t5-xl-pg19",
|
||||
"kalpeshk2011/rankgen-t5-base-all",
|
||||
"kalpeshk2011/rankgen-t5-large-all"]
|
||||
self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True)
|
||||
|
||||
def forward(self, prefixes, suffixes):
|
||||
embedded_prefixes = self.model(**prefixes)
|
||||
embedded_suffixes = self.model(**suffixes)
|
||||
# take dot product of each row independently
|
||||
dot_products = torch.sum(embedded_prefixes * embedded_suffixes, dim=1)
|
||||
|
||||
print(f"{prefixes=}, {suffixes=}, {embedded_prefixes=}, {embedded_suffixes=}, {dot_products=}")
|
||||
|
||||
return dot_products
|
||||
@@ -24,9 +24,32 @@ from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
@dataclass
|
||||
class RankGenCollator():
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
|
||||
def __call__(self, batch : list[dict[str, str]]) -> dict[str, torch.Tensor]:
|
||||
prefixes = []
|
||||
better_answers = []
|
||||
worse_answers = []
|
||||
for question, pairs in batch:
|
||||
for (pos, neg) in pairs:
|
||||
prefixes.append("pre " + question)
|
||||
better_answers.append("suffi " + pos)
|
||||
worse_answers.append("suffi " + neg)
|
||||
|
||||
tokenized_prefixes = self.tokenizer(prefixes, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True)
|
||||
tokenized_pos = self.tokenizer(better_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True)
|
||||
tokenized_neg = self.tokenizer(worse_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True)
|
||||
return {"prefix" : tokenized_prefixes,
|
||||
"positive": tokenized_pos,
|
||||
"negative": tokenized_neg}
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForPairRank:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
datasets==2.8.0
|
||||
evaluate==0.4.0
|
||||
scikit-learn==1.2.0
|
||||
torch==1.12.1+cu116
|
||||
torch>=1.12.1
|
||||
transformers==4.25.1
|
||||
wandb==0.13.7
|
||||
sentencepiece==0.1.97
|
||||
|
||||
@@ -7,10 +7,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
|
||||
from torch import nn
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
@@ -20,6 +21,7 @@ from transformers import (
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
)
|
||||
from models import RankGenModel
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "reward-model"
|
||||
@@ -47,14 +49,17 @@ class RankLoss(nn.Module):
|
||||
self.log_sigmoid = nn.LogSigmoid()
|
||||
|
||||
def forward(self, pos, neg):
|
||||
return -self.log_sigmoid(pos - neg + self.eps).mean()
|
||||
loss = -self.log_sigmoid(pos - neg + self.eps).mean()
|
||||
print(f"in loss {pos=}, {neg=}, {loss=}")
|
||||
return loss
|
||||
|
||||
|
||||
class RankTrainer(Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
model_name: str = None,
|
||||
args: Optional[TrainingArguments] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
@@ -80,15 +85,26 @@ class RankTrainer(Trainer):
|
||||
)
|
||||
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
|
||||
self.loss_function = args.loss_function
|
||||
self.model_name = model_name
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits").view(-1, 2)
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(logits[:, 0], logits[:, 1])
|
||||
if "rankgen" in self.model_name:
|
||||
print(f"{inputs=}")
|
||||
positive_outputs = model(inputs["prefix"], inputs["positive"])
|
||||
negative_outputs = model(inputs["prefix"], inputs["negative"])
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(positive_outputs, negative_outputs)
|
||||
else:
|
||||
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
|
||||
outputs = torch.hstack((positive_outputs, negative_outputs)) #logits
|
||||
else:
|
||||
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits").view(-1, 2)
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(logits[:, 0], logits[:, 1])
|
||||
else:
|
||||
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
@@ -110,32 +126,44 @@ class RankTrainer(Trainer):
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
with torch.inference_mode():
|
||||
if "rankgen" in self.model_name:
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
positive_outputs = model(inputs["prefix"], inputs["positive"])
|
||||
negative_outputs = model(inputs["prefix"], inputs["negative"])
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(positive_outputs, negative_outputs)
|
||||
else:
|
||||
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
|
||||
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
|
||||
return (loss, outputs, None)
|
||||
else:
|
||||
# compute loss on predict data
|
||||
loss, logits = self._compute_loss(model, inputs)
|
||||
|
||||
with torch.no_grad():
|
||||
# compute loss on predict data
|
||||
loss, logits = self._compute_loss(model, inputs)
|
||||
loss = loss.mean().detach()
|
||||
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
loss = loss.mean().detach()
|
||||
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
return (loss, logits, labels)
|
||||
return (loss, logits, labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
training_conf = argument_parsing(parser)
|
||||
|
||||
model_name = training_conf["model_name"]
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
|
||||
if "rankgen-t5" in model_name:
|
||||
model = RankGenModel(model_name)
|
||||
else:
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
|
||||
if "freeze_layer" in training_conf:
|
||||
num_layer = training_conf["freeze_layer"]
|
||||
model = freeze_top_n_layers(model, num_layer)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
print("Number of trainable : {}M".format(int(params / 1e6)))
|
||||
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
|
||||
args = CustomTrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=training_conf["num_train_epochs"],
|
||||
@@ -170,17 +198,30 @@ if __name__ == "__main__":
|
||||
assert len(sum_eval) > 0
|
||||
evals["hfsummary"] = sum_eval
|
||||
train = ConcatDataset(train_datasets)
|
||||
collate_fn = DataCollatorForPairRank(
|
||||
tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
|
||||
)
|
||||
|
||||
if "tokenizer_name" in training_conf:
|
||||
tokenizer=get_tokenizer(training_conf["tokenizer_name"])
|
||||
else:
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
|
||||
if "rankgen" in model_name:
|
||||
collate_fn = RankGenCollator(
|
||||
tokenizer, max_length=training_conf["max_length"]
|
||||
)
|
||||
else:
|
||||
collate_fn = DataCollatorForPairRank(
|
||||
tokenizer, max_length=training_conf["max_length"]
|
||||
)
|
||||
assert len(evals) > 0
|
||||
trainer = RankTrainer(
|
||||
model,
|
||||
args,
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
args=args,
|
||||
train_dataset=train,
|
||||
eval_dataset=eval,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
# trainer.evaluate()
|
||||
trainer.train()
|
||||
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import yaml
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoTokenizer, T5Tokenizer
|
||||
|
||||
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
|
||||
|
||||
@@ -26,7 +26,10 @@ def webgpt_return_format(row):
|
||||
|
||||
|
||||
def get_tokenizer(tokenizer_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
if "t5" in tokenizer_name: #rankgen
|
||||
tokenizer = T5Tokenizer.from_pretrained(tokenizer_name, truncation_side="left")
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
if "galactica" in tokenizer_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user