fixed linting

This commit is contained in:
Bobak Hashemi
2023-01-03 20:47:33 -05:00
parent 45c147362e
commit 4569bcf354
5 changed files with 51 additions and 46 deletions
+24 -20
View File
@@ -1,24 +1,28 @@
# -*- coding: utf-8 -*-
import torch
from transformers import AutoModel
# Pre-reformatting copy of RankGenModel as shown in this diff hunk; the same
# class appears again directly below with only layout changes.
# Wraps a RankGen checkpoint loaded through AutoModel and scores
# (prefix, suffix) pairs by the row-wise dot product of the model outputs.
class RankGenModel(torch.nn.Module):
def __init__(self, model_name):
super().__init__()
# Keep the requested checkpoint name; it is passed to from_pretrained below.
self.rankgen_hf_hub = model_name
# Only these four checkpoint names are accepted.
# NOTE(review): assert statements are skipped when Python runs with -O,
# so this check does not protect optimized runs.
assert model_name in ["kalpeshk2011/rankgen-t5-xl-all",
"kalpeshk2011/rankgen-t5-xl-pg19",
"kalpeshk2011/rankgen-t5-base-all",
"kalpeshk2011/rankgen-t5-large-all"]
# trust_remote_code=True lets transformers execute the modeling code that
# ships with the checkpoint repository.
self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True)
# prefixes / suffixes are mappings unpacked as keyword arguments to the
# wrapped model (presumably tokenizer output; confirm against the caller).
def forward(self, prefixes, suffixes):
# print(list(self.model.parameters()))
# raise Exception("stop")
# The same wrapped model is applied to both sides.
embedded_prefixes = self.model(**prefixes)
embedded_suffixes = self.model(**suffixes)
# take dot product of each row independently
# (elementwise product, then sum over dim=1)
dot_products = torch.sum(embedded_prefixes * embedded_suffixes, dim=1)
# print(f"{embedded_prefixes.shape=}, {embedded_suffixes.shape=}, {prefixes['input_ids'].shape=}, {suffixes['input_ids'].shape=}, {embedded_prefixes=}, {embedded_suffixes=}, {dot_products=}")
# raise Exception("stop")
return dot_products
class RankGenModel(torch.nn.Module):
    """Score (prefix, suffix) pairs with a pretrained RankGen checkpoint.

    The same wrapped model is applied to the prefix inputs and to the suffix
    inputs; the score of each pair is the dot product of the two outputs,
    taken row by row.
    """

    # Checkpoint names accepted by the constructor.
    SUPPORTED_MODELS = (
        "kalpeshk2011/rankgen-t5-xl-all",
        "kalpeshk2011/rankgen-t5-xl-pg19",
        "kalpeshk2011/rankgen-t5-base-all",
        "kalpeshk2011/rankgen-t5-large-all",
    )

    def __init__(self, model_name):
        """Load the RankGen checkpoint called ``model_name``.

        Args:
            model_name: one of ``SUPPORTED_MODELS``.

        Raises:
            ValueError: if ``model_name`` is not a supported checkpoint.
        """
        # Validate with an explicit raise rather than ``assert``: assert
        # statements are removed under ``python -O``, which would let an
        # unsupported name reach ``from_pretrained``. Checking before
        # ``super().__init__()`` also avoids a half-initialised module.
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported RankGen model {model_name!r}; "
                f"expected one of: {', '.join(self.SUPPORTED_MODELS)}"
            )
        super().__init__()
        self.rankgen_hf_hub = model_name
        # trust_remote_code=True lets transformers execute the modeling code
        # shipped inside the checkpoint repository.
        self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True)

    def forward(self, prefixes, suffixes):
        """Return one score per (prefix, suffix) row.

        Args:
            prefixes: mapping unpacked as keyword arguments to the wrapped
                model (presumably tokenizer output; confirm against caller).
            suffixes: same kind of mapping for the suffix side.

        Returns:
            Tensor obtained by multiplying the two model outputs elementwise
            and summing over ``dim=1``.
        """
        embedded_prefixes = self.model(**prefixes)
        embedded_suffixes = self.model(**suffixes)
        # Row-wise dot product: elementwise product reduced over dim=1.
        return torch.sum(embedded_prefixes * embedded_suffixes, dim=1)
+16 -10
View File
@@ -23,19 +23,20 @@ from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
from datasets import load_dataset
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
@dataclass
class RankGenCollator():
class RankGenCollator:
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
max_examples: Optional[int] = None
def __call__(self, batch : list[dict[str, str]]) -> dict[str, torch.Tensor]:
def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
prefixes = []
better_answers = []
worse_answers = []
@@ -44,13 +45,18 @@ class RankGenCollator():
prefixes.append("pre " + question)
better_answers.append("suffi " + pos)
worse_answers.append("suffi " + neg)
tokenized_prefixes = self.tokenizer(prefixes, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True)
tokenized_pos = self.tokenizer(better_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True)
tokenized_neg = self.tokenizer(worse_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True)
return {"prefix" : tokenized_prefixes,
"positive": tokenized_pos,
"negative": tokenized_neg}
tokenized_prefixes = self.tokenizer(
prefixes, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
)
tokenized_pos = self.tokenizer(
better_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
)
tokenized_neg = self.tokenizer(
worse_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
)
return {"prefix": tokenized_prefixes, "positive": tokenized_pos, "negative": tokenized_neg}
@dataclass
class DataCollatorForPairRank:
+1 -1
View File
@@ -1,7 +1,7 @@
datasets==2.8.0
evaluate==0.4.0
scikit-learn==1.2.0
sentencepiece==0.1.97
torch>=1.12.1
transformers==4.25.1
wandb==0.13.7
sentencepiece==0.1.97
+9 -14
View File
@@ -7,11 +7,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import evaluate
import numpy as np
import torch
from models import RankGenModel
from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
from torch import nn
from torch.utils.data import ConcatDataset, Dataset
from transformers import (
AutoModel,
AutoModelForSequenceClassification,
DataCollator,
EvalPrediction,
@@ -21,7 +21,6 @@ from transformers import (
TrainerCallback,
TrainingArguments,
)
from models import RankGenModel
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
os.environ["WANDB_PROJECT"] = "reward-model"
@@ -95,7 +94,7 @@ class RankTrainer(Trainer):
loss = self.loss_fct(positive_outputs, negative_outputs)
else:
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
outputs = torch.hstack((positive_outputs, negative_outputs)) #logits
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
else:
outputs = model(**inputs)
logits = outputs.get("logits").view(-1, 2)
@@ -133,7 +132,7 @@ class RankTrainer(Trainer):
loss = self.loss_fct(positive_outputs, negative_outputs)
else:
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
return (loss, outputs, None)
else:
# compute loss on predict data
@@ -161,7 +160,7 @@ if __name__ == "__main__":
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of trainable : {}M".format(int(params / 1e6)))
args = CustomTrainingArguments(
output_dir=f"{model_name}-finetuned",
num_train_epochs=training_conf["num_train_epochs"],
@@ -196,20 +195,16 @@ if __name__ == "__main__":
assert len(sum_eval) > 0
evals["hfsummary"] = sum_eval
train = ConcatDataset(train_datasets)
if "tokenizer_name" in training_conf:
tokenizer=get_tokenizer(training_conf["tokenizer_name"])
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
else:
tokenizer = get_tokenizer(model_name)
if "rankgen" in model_name:
collate_fn = RankGenCollator(
tokenizer, max_length=training_conf["max_length"]
)
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
else:
collate_fn = DataCollatorForPairRank(
tokenizer, max_length=training_conf["max_length"]
)
collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf["max_length"])
assert len(evals) > 0
trainer = RankTrainer(
model=model,
+1 -1
View File
@@ -26,7 +26,7 @@ def webgpt_return_format(row):
def get_tokenizer(tokenizer_name):
if "t5" in tokenizer_name: #rankgen
if "t5" in tokenizer_name: # rankgen
tokenizer = T5Tokenizer.from_pretrained(tokenizer_name, truncation_side="left")
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)