Merge pull request #313 from bth5032/bth5032/78-blackcat-trainer

Bth5032/78 blackcat trainer
This commit is contained in:
theblackcat102
2023-01-05 10:50:04 +08:00
committed by GitHub
7 changed files with 162 additions and 32 deletions
@@ -0,0 +1,16 @@
model_name: kalpeshk2011/rankgen-t5-base-all
tokenizer_name: google/t5-v1_1-base
learning_rate: 6e-6
gradient_checkpointing: false
fp16: true
gradient_accumulation_steps: 16
per_device_train_batch_size: 2
warmup_steps: 600
freeze_layer: 20
eval_steps: 200
save_steps: 500
max_length: 400
num_train_epochs: 2
datasets:
- webgpt
- hfsummary
@@ -0,0 +1,19 @@
model_name: kalpeshk2011/rankgen-t5-base-all
# model_name: kalpeshk2011/rankgen-t5-xl-all
# model_name: kalpeshk2011/rankgen-t5-xl-pg19
# model_name: kalpeshk2011/rankgen-t5-large-all
tokenizer_name: google/t5-v1_1-base
learning_rate: 6e-6
gradient_checkpointing: false
fp16: false
gradient_accumulation_steps: 16
per_device_train_batch_size: 2
warmup_steps: 600
freeze_layer: 20
eval_steps: 200
save_steps: 500
max_length: 400
num_train_epochs: 2
datasets:
- webgpt
- hfsummary
+27
View File
@@ -0,0 +1,27 @@
import torch
from transformers import AutoModel
class RankGenModel(torch.nn.Module):
    """Score (prefix, suffix) pairs with a pretrained RankGen encoder.

    The wrapped checkpoint is loaded through ``AutoModel.from_pretrained`` and
    each row's score is the dot product of its prefix and suffix encodings.
    """

    # Checkpoints this wrapper accepts; anything else is rejected in __init__.
    SUPPORTED_MODELS = (
        "kalpeshk2011/rankgen-t5-xl-all",
        "kalpeshk2011/rankgen-t5-xl-pg19",
        "kalpeshk2011/rankgen-t5-base-all",
        "kalpeshk2011/rankgen-t5-large-all",
    )

    def __init__(self, model_name):
        """Load the RankGen checkpoint named ``model_name``.

        Args:
            model_name: Hugging Face hub id; must be one of ``SUPPORTED_MODELS``.

        Raises:
            ValueError: if ``model_name`` is not a supported checkpoint.
        """
        super().__init__()
        # Raise explicitly rather than `assert`: assertions are removed when
        # Python runs with -O, which would let unknown names reach the hub load.
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported RankGen model {model_name!r}; expected one of {list(self.SUPPORTED_MODELS)}"
            )
        self.rankgen_hf_hub = model_name
        # NOTE: trust_remote_code=True allows the checkpoint repository to supply
        # model code that is executed locally; the allow-list above limits which
        # repositories can be loaded this way.
        self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True)

    def forward(self, prefixes, suffixes):
        """Return one score per row for the given prefixes and suffixes.

        Args:
            prefixes: mapping unpacked as keyword arguments into the wrapped
                model (presumably a tokenizer output with ``input_ids`` etc. --
                confirm against the collator).
            suffixes: same structure as ``prefixes``, row-aligned with it.

        Returns:
            The sum over dim 1 of the element-wise product of the two
            encodings (assumes the wrapped model returns a
            ``(batch, dim)`` tensor -- TODO confirm).
        """
        embedded_prefixes = self.model(**prefixes)
        embedded_suffixes = self.model(**suffixes)
        # take dot product of each row independently
        return torch.sum(embedded_prefixes * embedded_suffixes, dim=1)
+30
View File
@@ -22,11 +22,41 @@ from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
@dataclass
class RankGenCollator:
    """Collate ranked question/answer examples into RankGen model inputs.

    Each batch item is a ``(question, pairs)`` tuple where ``pairs`` is an
    iterable of ``(better_answer, worse_answer)`` strings. Every pair becomes
    one row: the question is repeated as the prefix for that row.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    # Accepted for configuration, but not read anywhere in this collator.
    max_examples: Optional[int] = None

    def _tokenize(self, texts):
        """Tokenize one group of strings with the collator's shared settings."""
        return self.tokenizer(
            texts, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
        )

    def __call__(self, batch: list[tuple[str, list[tuple[str, str]]]]) -> dict[str, torch.Tensor]:
        """Flatten ``batch`` into row-aligned prefix / positive / negative inputs.

        Args:
            batch: ``(question, pairs)`` items; ``pairs`` holds
                ``(better_answer, worse_answer)`` string tuples.

        Returns:
            A dict with keys ``"prefix"``, ``"positive"`` and ``"negative"``,
            each holding the tokenizer's output for that group of strings.
            NOTE(review): the declared value type is ``torch.Tensor`` but the
            values are whatever the tokenizer returns (the model unpacks them
            with ``**``) -- confirm and tighten the annotation.
        """
        prefixes = []
        better_answers = []
        worse_answers = []
        for question, pairs in batch:
            for pos, neg in pairs:
                # "pre " / "suffi " are the markers prepended to prefix and
                # suffix text before tokenization.
                prefixes.append("pre " + question)
                better_answers.append("suffi " + pos)
                worse_answers.append("suffi " + neg)

        return {
            "prefix": self._tokenize(prefixes),
            "positive": self._tokenize(better_answers),
            "negative": self._tokenize(worse_answers),
        }
@dataclass
class DataCollatorForPairRank:
"""
+2 -1
View File
@@ -1,6 +1,7 @@
datasets==2.8.0
evaluate==0.4.0
scikit-learn==1.2.0
torch==1.12.1+cu116
sentencepiece==0.1.97
torch>=1.12.1
transformers==4.25.1
wandb==0.13.7
+57 -26
View File
@@ -6,7 +6,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import evaluate
import numpy as np
import torch
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
from models import RankGenModel
from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
from torch import nn
from torch.utils.data import ConcatDataset, Dataset
from transformers import (
@@ -46,14 +47,16 @@ class RankLoss(nn.Module):
self.log_sigmoid = nn.LogSigmoid()
def forward(self, pos, neg):
    """Return the mean pairwise ranking loss over the batch.

    Computes ``-log_sigmoid(pos - neg + eps)`` element-wise and averages it,
    so the loss shrinks as ``pos`` scores exceed their paired ``neg`` scores.

    Args:
        pos: scores of the preferred items.
        neg: scores of the rejected items, aligned with ``pos``.
    """
    return -self.log_sigmoid(pos - neg + self.eps).mean()
class RankTrainer(Trainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
model_name: str = None,
args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Dataset] = None,
@@ -79,15 +82,25 @@ class RankTrainer(Trainer):
)
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
self.loss_function = args.loss_function
self.model_name = model_name
def compute_loss(self, model, inputs, return_outputs=False):
    """Compute the training loss for one batch.

    Two input layouts are handled:

    * RankGen models (``"rankgen"`` in ``self.model_name``): ``inputs`` holds
      ``"prefix"``, ``"positive"`` and ``"negative"`` entries; the model is
      called once per (prefix, candidate) group and only the ranking loss is
      supported.
    * Any other model: ``inputs`` is unpacked into the model and its
      ``"logits"`` are reshaped to two columns, column 0 being the preferred
      item and column 1 the rejected one.

    Args:
        model: the model being trained.
        inputs: collated batch, in one of the layouts above.
        return_outputs: when true, return ``(loss, outputs)`` instead of
            just ``loss``.

    Raises:
        NotImplementedError: for a RankGen model with a non-ranking loss.
    """
    # `model_name` defaults to None in __init__; treat that as "not RankGen"
    # instead of letting `in` raise TypeError on None.
    if "rankgen" in (self.model_name or ""):
        # Fail before spending two forward passes on an unsupported setup.
        if self.loss_function != "rank":
            raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
        positive_outputs = model(inputs["prefix"], inputs["positive"])
        negative_outputs = model(inputs["prefix"], inputs["negative"])
        loss = self.loss_fct(positive_outputs, negative_outputs)
        # NOTE(review): for 1-D score tensors hstack concatenates them
        # end-to-end (length 2N) rather than forming an (N, 2) matrix like
        # the other branch's logits -- confirm what consumers expect.
        outputs = torch.hstack((positive_outputs, negative_outputs))  # logits
    else:
        outputs = model(**inputs)
        logits = outputs.get("logits").view(-1, 2)
        if self.loss_function == "rank":
            loss = self.loss_fct(logits[:, 0], logits[:, 1])
        else:
            # Classification-style loss: the preferred item is always class 0.
            loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))

    return (loss, outputs) if return_outputs else loss
@@ -109,24 +122,37 @@ class RankTrainer(Trainer):
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
with torch.inference_mode():
if "rankgen" in self.model_name:
inputs = self._prepare_inputs(inputs)
positive_outputs = model(inputs["prefix"], inputs["positive"])
negative_outputs = model(inputs["prefix"], inputs["negative"])
if self.loss_function == "rank":
loss = self.loss_fct(positive_outputs, negative_outputs)
else:
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
return (loss, outputs, None)
else:
# compute loss on predict data
loss, logits = self._compute_loss(model, inputs)
with torch.no_grad():
# compute loss on predict data
loss, logits = self._compute_loss(model, inputs)
loss = loss.mean().detach()
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
if self.args.prediction_loss_only:
return (loss, None, None)
loss = loss.mean().detach()
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
if self.args.prediction_loss_only:
return (loss, None, None)
return (loss, logits, labels)
return (loss, logits, labels)
if __name__ == "__main__":
training_conf = argument_parsing(parser)
model_name = training_conf["model_name"]
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
if "rankgen-t5" in model_name:
model = RankGenModel(model_name)
else:
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
if "freeze_layer" in training_conf:
num_layer = training_conf["freeze_layer"]
model = freeze_top_n_layers(model, num_layer)
@@ -134,7 +160,6 @@ if __name__ == "__main__":
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of trainable : {}M".format(int(params / 1e6)))
tokenizer = get_tokenizer(model_name)
args = CustomTrainingArguments(
output_dir=f"{model_name}-finetuned",
num_train_epochs=training_conf["num_train_epochs"],
@@ -142,7 +167,7 @@ if __name__ == "__main__":
loss_function=training_conf["loss"],
learning_rate=training_conf["learning_rate"],
# half_precision_backend="apex",
fp16=True,
fp16=training_conf["fp16"],
gradient_checkpointing=training_conf["gradient_checkpointing"],
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
@@ -154,7 +179,7 @@ if __name__ == "__main__":
evaluation_strategy="steps",
eval_steps=training_conf["eval_steps"],
save_steps=1000,
report_to="wandb",
report_to="local",
)
train_datasets, evals = [], {}
if "webgpt" in training_conf["datasets"]:
@@ -169,17 +194,23 @@ if __name__ == "__main__":
assert len(sum_eval) > 0
evals["hfsummary"] = sum_eval
train = ConcatDataset(train_datasets)
collate_fn = DataCollatorForPairRank(
tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
)
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
if "rankgen" in model_name:
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
else:
collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf["max_length"])
assert len(evals) > 0
trainer = RankTrainer(
model,
args,
model=model,
model_name=model_name,
args=args,
train_dataset=train,
eval_dataset=eval,
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)
# trainer.evaluate()
trainer.train()
+11 -5
View File
@@ -3,7 +3,7 @@ import re
import yaml
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from transformers import AutoTokenizer
from transformers import AutoTokenizer, T5Tokenizer
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
@@ -25,7 +25,10 @@ def webgpt_return_format(row):
def get_tokenizer(tokenizer_name):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if "t5" in tokenizer_name: # rankgen
tokenizer = T5Tokenizer.from_pretrained(tokenizer_name, truncation_side="left")
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if "galactica" in tokenizer_name:
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
@@ -67,6 +70,10 @@ def freeze_top_n_layers(model, target_layers):
def argument_parsing(parser):
args = parser.parse_args()
with open(args.config, "r", encoding="utf-8") as f:
training_conf = yaml.safe_load(f.read())
default_params = {
"num_train_epochs": 4,
"learning_rate": 3e-5,
@@ -78,10 +85,9 @@ def argument_parsing(parser):
"gradient_accumulation_steps": 8,
"gradient_checkpointing": False,
"datasets": ["webgpt"],
"fp16": True,
"tokenizer_name": training_conf["model_name"],
}
args = parser.parse_args()
with open(args.config, "r", encoding="utf-8") as f:
training_conf = yaml.safe_load(f.read())
params = {**default_params, **training_conf}
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])