mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-30 16:40:05 +08:00
[feature] added summary quality rater
This commit is contained in:
@@ -13,7 +13,7 @@ transformers
|
||||
torch==1.12
|
||||
```
|
||||
|
||||
Start training
|
||||
Start training reward model
|
||||
|
||||
|
||||
```bash
|
||||
@@ -21,6 +21,22 @@ python trainer.py configs/electra-base-dis-webgpt.yml
|
||||
```
|
||||
|
||||
|
||||
Additional axis labeling, this outputs a 4 summary quality evaluation metrics (score are normalized to 0-1 )
|
||||
|
||||
```bash
|
||||
python summary_quality_trainer.py configs/test-bloomz-560m-quality.yml
|
||||
```
|
||||
|
||||
The four summary are :
|
||||
|
||||
* overall
|
||||
|
||||
* accuracy
|
||||
|
||||
* coverage
|
||||
|
||||
* coherence
|
||||
|
||||
## Dataset
|
||||
|
||||
For now we only supports webgpt and summary dataset from OpenAI. Once open-asisstant dataset are available it will be added here.
|
||||
|
||||
@@ -9,7 +9,7 @@ Some other reward features we can use
|
||||
|
||||
* each labeling has a labeling `note`, basically comments by labeler, not sure what else we can use
|
||||
|
||||
* Use the score for "overall", "accuracy", "coverage", "coherence" from axis/evals to train an addition model (rank additional aspect of the policy model)
|
||||
* ~~Use the score for "overall", "accuracy", "coverage", "coherence" from axis/evals to train an addition model (rank additional aspect of the policy model)~~
|
||||
|
||||
* this should be placed under experimental_dataset.py
|
||||
|
||||
|
||||
@@ -8,13 +8,93 @@
|
||||
Should be better than just a preference score
|
||||
|
||||
'''
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
import numpy as np
|
||||
from dataset import load_dataset
|
||||
from collections import defaultdict
|
||||
from datasets import load_dataset
|
||||
from dataclasses import dataclass
|
||||
from torch.utils.data import Dataset
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSummaryScore:
|
||||
"""
|
||||
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
"""
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
num_choices: int = 2
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
drop_token_type: bool = False # galactica
|
||||
|
||||
def __call__(self, batch):
|
||||
|
||||
features = []
|
||||
labels = []
|
||||
for feature, label in batch:
|
||||
features.append(feature)
|
||||
labels.append(label)
|
||||
|
||||
batch_feature = self.tokenizer.pad(
|
||||
features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
if self.drop_token_type:
|
||||
batch_feature.pop('token_type_ids')
|
||||
# batch = {k: v.view(batch_size, self.num_choices, -1) for k, v in batch.items()}
|
||||
batch_feature['labels'] = torch.from_numpy(np.array(labels)).float()
|
||||
return batch_feature
|
||||
|
||||
|
||||
class HFSummaryQuality(Dataset):
|
||||
def __init__(self, split, tokenizer, max_length=300) -> None:
|
||||
super().__init__()
|
||||
assert split in ('validation', 'test')
|
||||
dataset = load_dataset('Tristan/summarize_from_feedback', 'axis')[split]
|
||||
self.max_length = max_length
|
||||
mean_scores = defaultdict(list)
|
||||
self.contexts = []
|
||||
self.responses = []
|
||||
self.labels = []
|
||||
for data in dataset:
|
||||
|
||||
if 'article' in data['info'] and \
|
||||
data['info']['article'] is not None:
|
||||
context = data['info']['article']
|
||||
elif 'post' in data['info']:
|
||||
context = data['info']['post']
|
||||
self.contexts.append(context)
|
||||
|
||||
response = data['summary']['text']
|
||||
self.responses.append(response)
|
||||
self.labels.append(data['summary']['axes'])
|
||||
for axis, score in data['summary']['axes'].items():
|
||||
if score is not None:
|
||||
mean_scores[axis].append(score)
|
||||
|
||||
self.label2idx = { key: idx for idx, key in enumerate(mean_scores.keys()) }
|
||||
self.label2mean = { key: np.mean(scores) for key, scores in mean_scores.items() }
|
||||
self.tokenizer = tokenizer
|
||||
print(self.label2idx)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.responses)
|
||||
|
||||
def __getitem__(self, index):
|
||||
context = self.contexts[index]
|
||||
# return pairs of comparison
|
||||
response = self.responses[index]
|
||||
labels = np.zeros(len(self.label2idx))
|
||||
for key, score in self.labels[index].items():
|
||||
labels[self.label2idx[key]] = (self.label2mean[key] if score is None else score)/10
|
||||
return self.tokenizer(context, response, truncation=True, max_length=self.max_length), labels
|
||||
|
||||
|
||||
class
|
||||
@@ -19,8 +19,6 @@
|
||||
|
||||
'''
|
||||
from typing import Optional, Union
|
||||
import glob
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
@@ -145,8 +143,6 @@ class HFSummary(Dataset):
|
||||
elif 'post' in data['info']:
|
||||
context = data['info']['post']
|
||||
|
||||
if context is None:
|
||||
continue
|
||||
|
||||
if context not in self.index2summary:
|
||||
self.index2summary[len(self.index2summary)] = context
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
import os
|
||||
os.environ['WANDB_PROJECT'] = 'quality-scoring'
|
||||
import torch
|
||||
import yaml
|
||||
import evaluate
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union, Dict
|
||||
from torch import nn
|
||||
from argparse import ArgumentParser
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from transformers import Trainer, PreTrainedModel, TrainingArguments, DataCollator, EvalPrediction, TrainerCallback, PreTrainedTokenizerBase
|
||||
from experimental_dataset import HFSummaryQuality, DataCollatorForSummaryScore
|
||||
from utils import get_tokenizer, train_val_dataset, freeze_top_n_layers, argument_parsing
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('config', type=str)
|
||||
|
||||
accuracy = evaluate.load("mse")
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, labels = eval_pred
|
||||
return accuracy.compute(predictions=predictions.flatten(), references=labels.flatten())
|
||||
|
||||
|
||||
class QualityTrainer(Trainer):
|
||||
def __init__(self, model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None):
|
||||
super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer,
|
||||
model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
|
||||
self.loss_fct = nn.L1Loss()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels = inputs.pop('labels')
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = self.sigmoid(outputs.get("logits"))
|
||||
loss = self.loss_fct(logits, labels)
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
def _compute_loss(self, model, inputs):
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
labels = inputs.pop('labels')
|
||||
outputs = model(**inputs)
|
||||
logits = self.sigmoid(outputs.get("logits"))
|
||||
loss = self.loss_fct(logits, labels)
|
||||
|
||||
return loss, logits
|
||||
|
||||
def prediction_step(self, model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
with torch.no_grad():
|
||||
# compute loss on predict data
|
||||
loss, logits = self._compute_loss(model, inputs)
|
||||
|
||||
loss = loss.mean().detach()
|
||||
labels = inputs['labels']
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
if __name__ == "__main__":
|
||||
training_conf = argument_parsing(parser)
|
||||
|
||||
model_name = training_conf['model_name']
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
collate_fn = DataCollatorForSummaryScore(tokenizer,
|
||||
max_length=training_conf['max_length'],
|
||||
drop_token_type= 'galactica' in model_name
|
||||
)
|
||||
train = HFSummaryQuality(split='validation',
|
||||
tokenizer=tokenizer,
|
||||
max_length=training_conf['max_length']
|
||||
)
|
||||
eval = HFSummaryQuality(split='test',
|
||||
tokenizer=tokenizer,
|
||||
max_length=training_conf['max_length']
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name,
|
||||
num_labels=len(train.label2idx), problem_type='regression')
|
||||
|
||||
if 'freeze_layer' in training_conf:
|
||||
num_layer = training_conf['freeze_layer']
|
||||
model = freeze_top_n_layers(model, num_layer)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
print('Number of trainable : {}M'.format(int(params/1e6)))
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=training_conf['num_train_epochs'],
|
||||
warmup_steps=500,
|
||||
learning_rate=training_conf['learning_rate'],
|
||||
# half_precision_backend="apex",
|
||||
fp16=True,
|
||||
gradient_checkpointing=training_conf['gradient_checkpointing'],
|
||||
gradient_accumulation_steps=training_conf['gradient_accumulation_steps'],
|
||||
per_device_train_batch_size=training_conf['per_device_train_batch_size'],
|
||||
per_device_eval_batch_size=training_conf['per_device_eval_batch_size'],
|
||||
weight_decay=0.01,
|
||||
max_grad_norm=2.0,
|
||||
logging_steps=10,
|
||||
save_total_limit=4,
|
||||
evaluation_strategy='steps',
|
||||
eval_steps=training_conf['eval_steps'],
|
||||
save_steps=1000,
|
||||
report_to='wandb'
|
||||
)
|
||||
trainer = QualityTrainer(
|
||||
model,
|
||||
args,
|
||||
train_dataset=train,
|
||||
eval_dataset=eval,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics
|
||||
)
|
||||
trainer.train()
|
||||
@@ -1,7 +1,7 @@
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
from rank_datasets import WebGPT, HFSummary, DataCollatorForPairRank
|
||||
|
||||
from experimental_dataset import HFSummaryQuality, DataCollatorForSummaryScore
|
||||
|
||||
def test_hfsummary():
|
||||
|
||||
@@ -24,6 +24,17 @@ def test_webgpt():
|
||||
print(batch['input_ids'].shape)
|
||||
|
||||
|
||||
def test_hf_quality():
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForSummaryScore(tokenizer, max_length=200)
|
||||
dataset = HFSummaryQuality('validation', tokenizer)
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
|
||||
for batch in dataloader:
|
||||
print(batch['input_ids'].shape)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_hfsummary()
|
||||
test_hf_quality()
|
||||
# test_webgpt()
|
||||
@@ -117,13 +117,13 @@ if __name__ == "__main__":
|
||||
gradient_checkpointing=training_conf['gradient_checkpointing'],
|
||||
gradient_accumulation_steps=training_conf['gradient_accumulation_steps'],
|
||||
per_device_train_batch_size=training_conf['per_device_train_batch_size'],
|
||||
per_device_eval_batch_size=5,
|
||||
per_device_eval_batch_size=training_conf['per_device_eval_batch_size'],
|
||||
weight_decay=0.01,
|
||||
max_grad_norm=2.0,
|
||||
logging_steps=10,
|
||||
save_total_limit=4,
|
||||
evaluation_strategy='steps',
|
||||
eval_steps=500,
|
||||
eval_steps=training_conf['eval_steps'],
|
||||
save_steps=1000,
|
||||
report_to='wandb'
|
||||
)
|
||||
|
||||
@@ -69,6 +69,7 @@ def argument_parsing(parser):
|
||||
'eval_steps': 500,
|
||||
'loss': 'rank',
|
||||
'max_length': 440,
|
||||
'per_device_eval_batch_size': 5,
|
||||
'per_device_train_batch_size': 8,
|
||||
'gradient_accumulation_steps': 8,
|
||||
'gradient_checkpointing': False,
|
||||
|
||||
Reference in New Issue
Block a user