mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
[fix] pre-commit update
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
|
||||
Trainer code based on huggingface. Compatible with deepspeed or accelerate
|
||||
|
||||
|
||||
Requirements
|
||||
|
||||
```
|
||||
@@ -15,12 +14,10 @@ torch==1.12
|
||||
|
||||
Start training reward model
|
||||
|
||||
|
||||
```bash
|
||||
python trainer.py configs/electra-base-dis-webgpt.yml
|
||||
```
|
||||
|
||||
|
||||
Additional axis labeling: this outputs four summary-quality evaluation metrics (scores are normalized to 0-1)
|
||||
|
||||
```bash
|
||||
@@ -29,13 +26,13 @@ python summary_quality_trainer.py configs/test-bloomz-560m-quality.yml
|
||||
|
||||
The four summary metrics are:
|
||||
|
||||
* overall
|
||||
- overall
|
||||
|
||||
* accuracy
|
||||
- accuracy
|
||||
|
||||
* coverage
|
||||
- coverage
|
||||
|
||||
* coherence
|
||||
- coherence
|
||||
|
||||
## Dataset
|
||||
|
||||
|
||||
@@ -1,23 +1,19 @@
|
||||
|
||||
Some other reward features we can use
|
||||
|
||||
0. Finish classification feature
|
||||
0. Finish classification feature
|
||||
|
||||
1. Summaries from human feedback
|
||||
|
||||
* incorporate the `confidence` score into the RM learning, and ensure the output rank score correlates with confidence
|
||||
- incorporate the `confidence` score into the RM learning, and ensure the output rank score correlates with confidence
|
||||
|
||||
* each labeling has a labeling `note`, basically comments by labeler, not sure what else we can use
|
||||
- each labeling has a labeling `note`, basically comments by labeler, not sure what else we can use
|
||||
|
||||
* ~~Use the score for "overall", "accuracy", "coverage", "coherence" from axis/evals to train an addition model (rank additional aspect of the policy model)~~
|
||||
|
||||
* this should be placed under experimental_dataset.py
|
||||
- ~~Use the score for "overall", "accuracy", "coverage", "coherence" from axis/evals to train an addition model (rank additional aspect of the policy model)~~
|
||||
|
||||
- this should be placed under experimental_dataset.py
|
||||
|
||||
2. Add support for anthropic dataset
|
||||
|
||||
* anthropic dataset is more like a conversation tree, which is much more complex than a simple question-answer schema
|
||||
|
||||
* this is basically a MCTS from alphazero.
|
||||
|
||||
- anthropic dataset is more like a conversation tree, which is much more complex than a simple question-answer schema
|
||||
|
||||
- this is basically a MCTS from alphazero.
|
||||
|
||||
@@ -1,32 +1,34 @@
|
||||
'''
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
|
||||
classification based ranking
|
||||
|
||||
'''
|
||||
import os
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from dataset import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from .utils import webgpt_return_format
|
||||
|
||||
|
||||
class WebGPTDataset(Dataset):
|
||||
def __init__(self, mode='train', index_cache='dataset/webgpt_train_idx.pt', additional_dataset=None) -> None:
|
||||
def __init__(self, mode="train", index_cache="dataset/webgpt_train_idx.pt", additional_dataset=None) -> None:
|
||||
super().__init__()
|
||||
'''
|
||||
"""
|
||||
mode : train or val, used for validation purpose, has nothing to do with original split
|
||||
additional_dataset : a list of jsonline format with idx, question and texts (generate candidates)
|
||||
idx : must match the index you iterate from comparison enumerate order
|
||||
question : for validation purpose
|
||||
texts : list of K generate results from the question prompt
|
||||
'''
|
||||
os.makedirs('dataset', exist_ok=True)
|
||||
"""
|
||||
os.makedirs("dataset", exist_ok=True)
|
||||
dataset = load_dataset("openai/webgpt_comparisons")
|
||||
self.dataset = []
|
||||
self.dataset_index = []
|
||||
for idx, row in enumerate(dataset['train']):
|
||||
for idx, row in enumerate(dataset["train"]):
|
||||
self.dataset.append(webgpt_return_format(row))
|
||||
|
||||
# since this dataset was generated from 176B GPT-3
|
||||
@@ -36,17 +38,17 @@ class WebGPTDataset(Dataset):
|
||||
if additional_dataset is not None:
|
||||
self.sample_additional = True
|
||||
self.additional = {}
|
||||
with open(additional_dataset, 'r') as f:
|
||||
with open(additional_dataset, "r") as f:
|
||||
for line in f:
|
||||
row = json.loads(line)
|
||||
if row['idx'] in self.dataset_index:
|
||||
self.additional[row['idx']] = row['negatives']
|
||||
if row["idx"] in self.dataset_index:
|
||||
self.additional[row["idx"]] = row["negatives"]
|
||||
if len(self.additional) != len(self.dataset_index):
|
||||
for match_idx in self.dataset_index:
|
||||
if match_idx in self.additional:
|
||||
continue
|
||||
|
||||
idx = match_idx-900
|
||||
idx = match_idx - 900
|
||||
while idx not in self.additional:
|
||||
idx -= 1
|
||||
self.additional[match_idx] = self.additional[idx]
|
||||
@@ -57,10 +59,7 @@ class WebGPTDataset(Dataset):
|
||||
def __getitem__(self, index):
|
||||
row = self.dataset[index]
|
||||
if not self.sample_additional:
|
||||
return row['question'], row['pos'], row['neg']
|
||||
return row["question"], row["pos"], row["neg"]
|
||||
|
||||
gen_neg = random.choice(self.additional[self.dataset_index[index]])
|
||||
return row['question'], row['pos'], row['neg'], gen_neg
|
||||
|
||||
|
||||
|
||||
return row["question"], row["pos"], row["neg"], gen_neg
|
||||
|
||||
@@ -6,4 +6,4 @@ max_length: 600
|
||||
freeze_layer: 12
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- hfsummary
|
||||
- hfsummary
|
||||
|
||||
@@ -7,4 +7,4 @@ freeze_layer: 12
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
- hfsummary
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
model_name: google/electra-large-discriminator
|
||||
learning_rate: 3e-5
|
||||
max_length: 300
|
||||
max_length: 300
|
||||
|
||||
@@ -10,4 +10,4 @@ max_length: 512
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
- hfsummary
|
||||
|
||||
@@ -11,4 +11,4 @@ max_length: 400
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
- hfsummary
|
||||
|
||||
@@ -11,4 +11,4 @@ max_length: 128
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
- hfsummary
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
'''
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
HFSummary
|
||||
|
||||
I want to train a multi regression model on axis_evals dataset mainly we can estimate the score of these score
|
||||
@@ -7,15 +8,16 @@
|
||||
|
||||
Should be better than just a preference score
|
||||
|
||||
'''
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
import numpy as np
|
||||
"""
|
||||
from collections import defaultdict
|
||||
from datasets import load_dataset
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -25,12 +27,13 @@ class DataCollatorForSummaryScore:
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
num_choices: int = 2
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
drop_token_type: bool = False # galactica
|
||||
drop_token_type: bool = False # galactica
|
||||
|
||||
def __call__(self, batch):
|
||||
|
||||
@@ -48,17 +51,17 @@ class DataCollatorForSummaryScore:
|
||||
return_tensors="pt",
|
||||
)
|
||||
if self.drop_token_type:
|
||||
batch_feature.pop('token_type_ids')
|
||||
batch_feature.pop("token_type_ids")
|
||||
# batch = {k: v.view(batch_size, self.num_choices, -1) for k, v in batch.items()}
|
||||
batch_feature['labels'] = torch.from_numpy(np.array(labels)).float()
|
||||
batch_feature["labels"] = torch.from_numpy(np.array(labels)).float()
|
||||
return batch_feature
|
||||
|
||||
|
||||
class HFSummaryQuality(Dataset):
|
||||
def __init__(self, split, tokenizer, max_length=300) -> None:
|
||||
super().__init__()
|
||||
assert split in ('validation', 'test')
|
||||
dataset = load_dataset('Tristan/summarize_from_feedback', 'axis')[split]
|
||||
assert split in ("validation", "test")
|
||||
dataset = load_dataset("Tristan/summarize_from_feedback", "axis")[split]
|
||||
self.max_length = max_length
|
||||
mean_scores = defaultdict(list)
|
||||
self.contexts = []
|
||||
@@ -66,22 +69,21 @@ class HFSummaryQuality(Dataset):
|
||||
self.labels = []
|
||||
for data in dataset:
|
||||
|
||||
if 'article' in data['info'] and \
|
||||
data['info']['article'] is not None:
|
||||
context = data['info']['article']
|
||||
elif 'post' in data['info']:
|
||||
context = data['info']['post']
|
||||
if "article" in data["info"] and data["info"]["article"] is not None:
|
||||
context = data["info"]["article"]
|
||||
elif "post" in data["info"]:
|
||||
context = data["info"]["post"]
|
||||
self.contexts.append(context)
|
||||
|
||||
response = data['summary']['text']
|
||||
response = data["summary"]["text"]
|
||||
self.responses.append(response)
|
||||
self.labels.append(data['summary']['axes'])
|
||||
for axis, score in data['summary']['axes'].items():
|
||||
self.labels.append(data["summary"]["axes"])
|
||||
for axis, score in data["summary"]["axes"].items():
|
||||
if score is not None:
|
||||
mean_scores[axis].append(score)
|
||||
|
||||
self.label2idx = { key: idx for idx, key in enumerate(mean_scores.keys()) }
|
||||
self.label2mean = { key: np.mean(scores) for key, scores in mean_scores.items() }
|
||||
self.label2idx = {key: idx for idx, key in enumerate(mean_scores.keys())}
|
||||
self.label2mean = {key: np.mean(scores) for key, scores in mean_scores.items()}
|
||||
self.tokenizer = tokenizer
|
||||
print(self.label2idx)
|
||||
|
||||
@@ -94,7 +96,5 @@ class HFSummaryQuality(Dataset):
|
||||
response = self.responses[index]
|
||||
labels = np.zeros(len(self.label2idx))
|
||||
for key, score in self.labels[index].items():
|
||||
labels[self.label2idx[key]] = (self.label2mean[key] if score is None else score)/10
|
||||
labels[self.label2idx[key]] = (self.label2mean[key] if score is None else score) / 10
|
||||
return self.tokenizer(context, response, truncation=True, max_length=self.max_length), labels
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
'''
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
author: theblackcat102
|
||||
|
||||
Dataset output format from __getitem__
|
||||
@@ -17,13 +18,15 @@
|
||||
inferior than the human perference one
|
||||
|
||||
|
||||
'''
|
||||
from typing import Optional, Union
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
from torch.utils.data import Dataset
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForPairRank:
|
||||
@@ -32,12 +35,13 @@ class DataCollatorForPairRank:
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
num_choices: int = 2
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
drop_token_type: bool = False # galactica
|
||||
drop_token_type: bool = False # galactica
|
||||
|
||||
def __call__(self, features):
|
||||
|
||||
@@ -45,12 +49,10 @@ class DataCollatorForPairRank:
|
||||
batch_size = 0
|
||||
for question, pairs in features:
|
||||
for (pos, neg) in pairs:
|
||||
flatten_features.append(self.tokenizer(question, pos,
|
||||
truncation=True, max_length=self.max_length))
|
||||
flatten_features.append(self.tokenizer(question, neg,
|
||||
truncation=True, max_length=self.max_length))
|
||||
flatten_features.append(self.tokenizer(question, pos, truncation=True, max_length=self.max_length))
|
||||
flatten_features.append(self.tokenizer(question, neg, truncation=True, max_length=self.max_length))
|
||||
batch_size += 1
|
||||
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flatten_features,
|
||||
padding=self.padding,
|
||||
@@ -59,13 +61,12 @@ class DataCollatorForPairRank:
|
||||
return_tensors="pt",
|
||||
)
|
||||
if self.drop_token_type:
|
||||
batch.pop('token_type_ids')
|
||||
batch.pop("token_type_ids")
|
||||
# batch = {k: v.view(batch_size, self.num_choices, -1) for k, v in batch.items()}
|
||||
return batch
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -74,23 +75,19 @@ class WebGPT(Dataset):
|
||||
# using prompt as our index will allows us
|
||||
# to add additional generated prompt later
|
||||
self.index2question = {}
|
||||
for row in dataset['train']:
|
||||
question = row['question']['full_text']
|
||||
for row in dataset["train"]:
|
||||
question = row["question"]["full_text"]
|
||||
if question not in self.index2question:
|
||||
self.index2question[len(self.index2question)] = question
|
||||
|
||||
if question not in questions:
|
||||
questions[question] = []
|
||||
|
||||
if row['score_0'] > row['score_1']:
|
||||
if row["score_0"] > row["score_1"]:
|
||||
# not going to risk it
|
||||
questions[question].append((
|
||||
row['answer_0'], row['answer_1']
|
||||
))
|
||||
questions[question].append((row["answer_0"], row["answer_1"]))
|
||||
else:
|
||||
questions[question].append((
|
||||
row['answer_1'], row['answer_0']
|
||||
))
|
||||
questions[question].append((row["answer_1"], row["answer_0"]))
|
||||
|
||||
self.questions = questions
|
||||
|
||||
@@ -104,61 +101,55 @@ class WebGPT(Dataset):
|
||||
return question, rows
|
||||
|
||||
|
||||
|
||||
|
||||
class HFSummary(Dataset):
|
||||
'''
|
||||
Human feedback data from OpenAI
|
||||
https://github.com/openai/summarize-from-feedback
|
||||
|
||||
labeling method : pair comparison, 0 or 1
|
||||
"""
|
||||
Human feedback data from OpenAI
|
||||
https://github.com/openai/summarize-from-feedback
|
||||
|
||||
'''
|
||||
def __init__(self, split='train',
|
||||
conf_threshold=-1,
|
||||
max_comparison_per_sample=3) -> None:
|
||||
labeling method : pair comparison, 0 or 1
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, split="train", conf_threshold=-1, max_comparison_per_sample=3) -> None:
|
||||
super().__init__()
|
||||
assert split in ('train', 'valid1', 'valid2', 'test')
|
||||
assert split in ("train", "valid1", "valid2", "test")
|
||||
summaries = {}
|
||||
# using prompt as our index will allows us
|
||||
# to add additional generated prompt later
|
||||
self.index2summary = {}
|
||||
self.max_comparison_per_sample = max_comparison_per_sample
|
||||
major_split = split if 'train' == split else 'validation'
|
||||
dataset = load_dataset('Tristan/summarize_from_feedback', 'comparisons')[major_split]
|
||||
major_split = split if "train" == split else "validation"
|
||||
dataset = load_dataset("Tristan/summarize_from_feedback", "comparisons")[major_split]
|
||||
for data in dataset:
|
||||
if 'extra' in data and \
|
||||
'confidence' in data['extra'] and \
|
||||
data['extra']['confidence'] is not None and \
|
||||
conf_threshold > data['extra']['confidence']:
|
||||
print('skipping {}'.format(data['info']['id']))
|
||||
if (
|
||||
"extra" in data
|
||||
and "confidence" in data["extra"]
|
||||
and data["extra"]["confidence"] is not None
|
||||
and conf_threshold > data["extra"]["confidence"]
|
||||
):
|
||||
print("skipping {}".format(data["info"]["id"]))
|
||||
continue
|
||||
|
||||
if split != 'train' and split != data['split']:
|
||||
if split != "train" and split != data["split"]:
|
||||
continue
|
||||
|
||||
if 'article' in data['info'] and \
|
||||
data['info']['article'] is not None:
|
||||
context = data['info']['article']
|
||||
elif 'post' in data['info']:
|
||||
context = data['info']['post']
|
||||
|
||||
if "article" in data["info"] and data["info"]["article"] is not None:
|
||||
context = data["info"]["article"]
|
||||
elif "post" in data["info"]:
|
||||
context = data["info"]["post"]
|
||||
|
||||
if context not in self.index2summary:
|
||||
self.index2summary[len(self.index2summary)] = context
|
||||
|
||||
|
||||
if context not in summaries:
|
||||
summaries[context] = []
|
||||
|
||||
pos, neg = (0, 1) if data['choice'] == 0 else (1, 0)
|
||||
summaries[context].append((
|
||||
data['summaries'][pos]['text'],
|
||||
data['summaries'][neg]['text']
|
||||
))
|
||||
pos, neg = (0, 1) if data["choice"] == 0 else (1, 0)
|
||||
summaries[context].append((data["summaries"][pos]["text"], data["summaries"][neg]["text"]))
|
||||
|
||||
self.summaries = summaries
|
||||
|
||||
self.postfix_prompt = ' TLDR;'
|
||||
self.postfix_prompt = " TLDR;"
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index2summary)
|
||||
@@ -172,5 +163,4 @@ class HFSummary(Dataset):
|
||||
# not optimal but good for now
|
||||
valid_idx = np.random.choice(len(rows), self.max_comparison_per_sample)
|
||||
# optimize the format later
|
||||
return context+self.postfix_prompt, [ r for idx, r in enumerate(rows) if idx in valid_idx ]
|
||||
|
||||
return context + self.postfix_prompt, [r for idx, r in enumerate(rows) if idx in valid_idx]
|
||||
|
||||
@@ -1,46 +1,72 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
os.environ['WANDB_PROJECT'] = 'quality-scoring'
|
||||
import torch
|
||||
import yaml
|
||||
import evaluate
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union, Dict
|
||||
from torch import nn
|
||||
from argparse import ArgumentParser
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from transformers import Trainer, PreTrainedModel, TrainingArguments, DataCollator, EvalPrediction, TrainerCallback, PreTrainedTokenizerBase
|
||||
from experimental_dataset import HFSummaryQuality, DataCollatorForSummaryScore
|
||||
from utils import get_tokenizer, train_val_dataset, freeze_top_n_layers, argument_parsing
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
)
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "quality-scoring"
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('config', type=str)
|
||||
parser.add_argument("config", type=str)
|
||||
|
||||
accuracy = evaluate.load("mse")
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, labels = eval_pred
|
||||
return accuracy.compute(predictions=predictions.flatten(), references=labels.flatten())
|
||||
|
||||
|
||||
class QualityTrainer(Trainer):
|
||||
def __init__(self, model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None):
|
||||
super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer,
|
||||
model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
):
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
tokenizer,
|
||||
model_init,
|
||||
compute_metrics,
|
||||
callbacks,
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
self.loss_fct = nn.L1Loss()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels = inputs.pop('labels')
|
||||
labels = inputs.pop("labels")
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = self.sigmoid(outputs.get("logits"))
|
||||
@@ -50,75 +76,73 @@ class QualityTrainer(Trainer):
|
||||
|
||||
def _compute_loss(self, model, inputs):
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
labels = inputs.pop('labels')
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs)
|
||||
logits = self.sigmoid(outputs.get("logits"))
|
||||
loss = self.loss_fct(logits, labels)
|
||||
|
||||
return loss, logits
|
||||
|
||||
def prediction_step(self, model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
with torch.no_grad():
|
||||
# compute loss on predict data
|
||||
loss, logits = self._compute_loss(model, inputs)
|
||||
|
||||
loss = loss.mean().detach()
|
||||
labels = inputs['labels']
|
||||
labels = inputs["labels"]
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
training_conf = argument_parsing(parser)
|
||||
|
||||
model_name = training_conf['model_name']
|
||||
model_name = training_conf["model_name"]
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
collate_fn = DataCollatorForSummaryScore(tokenizer,
|
||||
max_length=training_conf['max_length'],
|
||||
drop_token_type= 'galactica' in model_name
|
||||
collate_fn = DataCollatorForSummaryScore(
|
||||
tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
|
||||
)
|
||||
train = HFSummaryQuality(split="validation", tokenizer=tokenizer, max_length=training_conf["max_length"])
|
||||
eval = HFSummaryQuality(split="test", tokenizer=tokenizer, max_length=training_conf["max_length"])
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, num_labels=len(train.label2idx), problem_type="regression"
|
||||
)
|
||||
train = HFSummaryQuality(split='validation',
|
||||
tokenizer=tokenizer,
|
||||
max_length=training_conf['max_length']
|
||||
)
|
||||
eval = HFSummaryQuality(split='test',
|
||||
tokenizer=tokenizer,
|
||||
max_length=training_conf['max_length']
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name,
|
||||
num_labels=len(train.label2idx), problem_type='regression')
|
||||
|
||||
if 'freeze_layer' in training_conf:
|
||||
num_layer = training_conf['freeze_layer']
|
||||
if "freeze_layer" in training_conf:
|
||||
num_layer = training_conf["freeze_layer"]
|
||||
model = freeze_top_n_layers(model, num_layer)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
print('Number of trainable : {}M'.format(int(params/1e6)))
|
||||
print("Number of trainable : {}M".format(int(params / 1e6)))
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=training_conf['num_train_epochs'],
|
||||
num_train_epochs=training_conf["num_train_epochs"],
|
||||
warmup_steps=500,
|
||||
learning_rate=training_conf['learning_rate'],
|
||||
learning_rate=training_conf["learning_rate"],
|
||||
# half_precision_backend="apex",
|
||||
fp16=True,
|
||||
gradient_checkpointing=training_conf['gradient_checkpointing'],
|
||||
gradient_accumulation_steps=training_conf['gradient_accumulation_steps'],
|
||||
per_device_train_batch_size=training_conf['per_device_train_batch_size'],
|
||||
per_device_eval_batch_size=training_conf['per_device_eval_batch_size'],
|
||||
gradient_checkpointing=training_conf["gradient_checkpointing"],
|
||||
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
|
||||
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
|
||||
per_device_eval_batch_size=training_conf["per_device_eval_batch_size"],
|
||||
weight_decay=0.01,
|
||||
max_grad_norm=2.0,
|
||||
logging_steps=10,
|
||||
save_total_limit=4,
|
||||
evaluation_strategy='steps',
|
||||
eval_steps=training_conf['eval_steps'],
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=training_conf["eval_steps"],
|
||||
save_steps=1000,
|
||||
report_to='wandb'
|
||||
report_to="wandb",
|
||||
)
|
||||
trainer = QualityTrainer(
|
||||
model,
|
||||
@@ -127,6 +151,6 @@ if __name__ == "__main__":
|
||||
eval_dataset=eval,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
@@ -1,40 +1,41 @@
|
||||
from transformers import AutoTokenizer
|
||||
# -*- coding: utf-8 -*-
|
||||
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
|
||||
from torch.utils.data import DataLoader
|
||||
from rank_datasets import WebGPT, HFSummary, DataCollatorForPairRank
|
||||
from experimental_dataset import HFSummaryQuality, DataCollatorForSummaryScore
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def test_hfsummary():
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=200)
|
||||
dataset = HFSummary('train')
|
||||
dataset = HFSummary("train")
|
||||
print(len(dataset))
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=8)
|
||||
for batch in dataloader:
|
||||
batch['input_ids'].shape
|
||||
|
||||
batch["input_ids"].shape
|
||||
|
||||
|
||||
def test_webgpt():
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=200)
|
||||
dataset = WebGPT()
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
|
||||
for batch in dataloader:
|
||||
print(batch['input_ids'].shape)
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
def test_hf_quality():
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForSummaryScore(tokenizer, max_length=200)
|
||||
dataset = HFSummaryQuality('validation', tokenizer)
|
||||
dataset = HFSummaryQuality("validation", tokenizer)
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
|
||||
for batch in dataloader:
|
||||
print(batch['input_ids'].shape)
|
||||
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_hf_quality()
|
||||
# test_webgpt()
|
||||
# test_webgpt()
|
||||
|
||||
@@ -1,32 +1,44 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
os.environ['WANDB_PROJECT'] = 'reward-model'
|
||||
import torch
|
||||
import yaml
|
||||
import evaluate
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union, Dict
|
||||
from torch import nn
|
||||
from argparse import ArgumentParser
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from transformers import Trainer, PreTrainedModel, TrainingArguments, DataCollator, EvalPrediction, TrainerCallback, PreTrainedTokenizerBase
|
||||
from rank_datasets import DataCollatorForPairRank, WebGPT, HFSummary
|
||||
from utils import get_tokenizer, train_val_dataset, freeze_top_n_layers, argument_parsing
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
|
||||
from torch import nn
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
)
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "reward-model"
|
||||
|
||||
accuracy = evaluate.load("accuracy")
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('config', type=str)
|
||||
parser.add_argument("config", type=str)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomTrainingArguments(TrainingArguments):
|
||||
loss_function: str='rank'
|
||||
loss_function: str = "rank"
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, _ = eval_pred
|
||||
predictions = np.argmax(predictions, axis=1)
|
||||
return accuracy.compute(predictions=predictions, references=[0]*predictions.shape[0])
|
||||
return accuracy.compute(predictions=predictions, references=[0] * predictions.shape[0])
|
||||
|
||||
|
||||
class RankLoss(nn.Module):
|
||||
def __init__(self, eps=1e-8) -> None:
|
||||
@@ -39,27 +51,41 @@ class RankLoss(nn.Module):
|
||||
|
||||
|
||||
class RankTrainer(Trainer):
|
||||
def __init__(self, model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None):
|
||||
super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer,
|
||||
model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
|
||||
self.loss_fct = RankLoss() if args.loss_function == 'rank' else nn.CrossEntropyLoss()
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
):
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
tokenizer,
|
||||
model_init,
|
||||
compute_metrics,
|
||||
callbacks,
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
|
||||
self.loss_function = args.loss_function
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits").view(-1, 2)
|
||||
if self.loss_function == 'rank':
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(logits[:, 0], logits[:, 1])
|
||||
else:
|
||||
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
|
||||
@@ -70,17 +96,20 @@ class RankTrainer(Trainer):
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits").view(-1, 2)
|
||||
if self.loss_function == 'rank':
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(logits[:, 0], logits[:, 1])
|
||||
else:
|
||||
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
|
||||
|
||||
return loss, logits
|
||||
|
||||
def prediction_step(self, model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
with torch.no_grad():
|
||||
# compute loss on predict data
|
||||
@@ -93,54 +122,57 @@ class RankTrainer(Trainer):
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
training_conf = argument_parsing(parser)
|
||||
|
||||
model_name = training_conf['model_name']
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type='regression')
|
||||
if 'freeze_layer' in training_conf:
|
||||
num_layer = training_conf['freeze_layer']
|
||||
model_name = training_conf["model_name"]
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
|
||||
if "freeze_layer" in training_conf:
|
||||
num_layer = training_conf["freeze_layer"]
|
||||
model = freeze_top_n_layers(model, num_layer)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
print('Number of trainable : {}M'.format(int(params/1e6)))
|
||||
print("Number of trainable : {}M".format(int(params / 1e6)))
|
||||
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
args = CustomTrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=training_conf['num_train_epochs'],
|
||||
num_train_epochs=training_conf["num_train_epochs"],
|
||||
warmup_steps=500,
|
||||
loss_function=training_conf['loss'],
|
||||
learning_rate=training_conf['learning_rate'],
|
||||
loss_function=training_conf["loss"],
|
||||
learning_rate=training_conf["learning_rate"],
|
||||
# half_precision_backend="apex",
|
||||
fp16=True,
|
||||
gradient_checkpointing=training_conf['gradient_checkpointing'],
|
||||
gradient_accumulation_steps=training_conf['gradient_accumulation_steps'],
|
||||
per_device_train_batch_size=training_conf['per_device_train_batch_size'],
|
||||
per_device_eval_batch_size=training_conf['per_device_eval_batch_size'],
|
||||
gradient_checkpointing=training_conf["gradient_checkpointing"],
|
||||
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
|
||||
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
|
||||
per_device_eval_batch_size=training_conf["per_device_eval_batch_size"],
|
||||
weight_decay=0.01,
|
||||
max_grad_norm=2.0,
|
||||
logging_steps=10,
|
||||
save_total_limit=4,
|
||||
evaluation_strategy='steps',
|
||||
eval_steps=training_conf['eval_steps'],
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=training_conf["eval_steps"],
|
||||
save_steps=1000,
|
||||
report_to='wandb'
|
||||
report_to="wandb",
|
||||
)
|
||||
train_datasets, evals = [], {}
|
||||
if 'webgpt' in training_conf['datasets']:
|
||||
if "webgpt" in training_conf["datasets"]:
|
||||
web_dataset = WebGPT()
|
||||
train, eval = train_val_dataset(web_dataset)
|
||||
train_datasets.append(train)
|
||||
evals['webgpt'] = eval
|
||||
if 'hfsummary' in training_conf['datasets']:
|
||||
sum_train = HFSummary(split='train')
|
||||
evals["webgpt"] = eval
|
||||
if "hfsummary" in training_conf["datasets"]:
|
||||
sum_train = HFSummary(split="train")
|
||||
train_datasets.append(sum_train)
|
||||
sum_eval = HFSummary(split='valid1')
|
||||
sum_eval = HFSummary(split="valid1")
|
||||
assert len(sum_eval) > 0
|
||||
evals['hfsummary'] = sum_eval
|
||||
evals["hfsummary"] = sum_eval
|
||||
train = ConcatDataset(train_datasets)
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf['max_length'], drop_token_type= 'galactica' in model_name)
|
||||
collate_fn = DataCollatorForPairRank(
|
||||
tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
|
||||
)
|
||||
assert len(evals) > 0
|
||||
trainer = RankTrainer(
|
||||
model,
|
||||
@@ -149,6 +181,6 @@ if __name__ == "__main__":
|
||||
eval_dataset=eval,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
@@ -1,96 +1,100 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
|
||||
import yaml
|
||||
from torch.utils.data import Subset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
re_reference_remove = re.compile(r'\[([0-9])+\]|\[([0-9])+,([0-9])+\]')
|
||||
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
|
||||
|
||||
|
||||
def webgpt_return_format(row):
|
||||
if row['score_0'] >= row['score_1']:
|
||||
if row["score_0"] >= row["score_1"]:
|
||||
# remove this to prevent information leak, since we are not using reference
|
||||
return {
|
||||
'question': row['question']['full_text'],
|
||||
'pos': re_reference_remove.sub('', row['answer_0']),
|
||||
'neg': re_reference_remove.sub('', row['answer_1'])
|
||||
}
|
||||
"question": row["question"]["full_text"],
|
||||
"pos": re_reference_remove.sub("", row["answer_0"]),
|
||||
"neg": re_reference_remove.sub("", row["answer_1"]),
|
||||
}
|
||||
|
||||
return {
|
||||
'question': row['question']['full_text'],
|
||||
'pos': re_reference_remove.sub('', row['answer_1']),
|
||||
'neg': re_reference_remove.sub('', row['answer_0'])
|
||||
}
|
||||
"question": row["question"]["full_text"],
|
||||
"pos": re_reference_remove.sub("", row["answer_1"]),
|
||||
"neg": re_reference_remove.sub("", row["answer_0"]),
|
||||
}
|
||||
|
||||
|
||||
def get_tokenizer(tokenizer_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
if 'galactica' in tokenizer_name:
|
||||
tokenizer.add_special_tokens({'pad_token':'<pad>', 'eos_token': '</s>' })
|
||||
if "galactica" in tokenizer_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
|
||||
train_idx, val_idx = train_test_split(list(range(len(dataset))),
|
||||
test_size=val_split, random_state=666, shuffle=True)
|
||||
train_idx, val_idx = train_test_split(
|
||||
list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True
|
||||
)
|
||||
# [3879, 11479, 8341, 9177, 10798, 18177, 5735, 15669, 4837, 2760]
|
||||
print(val_idx[:10])
|
||||
# [13582, 5919, 11875, 7373, 19135, 13706, 8555, 15788, 15005, 15209]
|
||||
print(train_idx[:10])
|
||||
return Subset(dataset, train_idx), Subset(dataset, val_idx)
|
||||
|
||||
|
||||
def freeze_top_n_layers(model, target_layers):
|
||||
# its possible we can simply detect which module is a ModuleList
|
||||
# and simply freeze the module without doing string parsing
|
||||
for name, param in model.named_parameters():
|
||||
if 'embed' in name:
|
||||
if "embed" in name:
|
||||
param.requires_grad = False
|
||||
elif '.layer' in name or '.h.' in name:
|
||||
tokens = name.split('.')
|
||||
elif ".layer" in name or ".h." in name:
|
||||
tokens = name.split(".")
|
||||
idx = 0
|
||||
for token in tokens:
|
||||
if 'layer' in token or token == 'h':
|
||||
if "layer" in token or token == "h":
|
||||
break
|
||||
idx += 1
|
||||
if idx >= len(tokens):
|
||||
continue
|
||||
|
||||
layer_ = int(tokens[idx+1])
|
||||
layer_ = int(tokens[idx + 1])
|
||||
if layer_ < target_layers:
|
||||
# print('freeze ', layer_, name)
|
||||
param.requires_grad = False
|
||||
return model
|
||||
|
||||
|
||||
def argument_parsing(parser):
|
||||
default_params = {
|
||||
'num_train_epochs': 4,
|
||||
'learning_rate': 3e-5,
|
||||
'eval_steps': 500,
|
||||
'loss': 'rank',
|
||||
'max_length': 440,
|
||||
'per_device_eval_batch_size': 5,
|
||||
'per_device_train_batch_size': 8,
|
||||
'gradient_accumulation_steps': 8,
|
||||
'gradient_checkpointing': False,
|
||||
'datasets': ['webgpt']
|
||||
"num_train_epochs": 4,
|
||||
"learning_rate": 3e-5,
|
||||
"eval_steps": 500,
|
||||
"loss": "rank",
|
||||
"max_length": 440,
|
||||
"per_device_eval_batch_size": 5,
|
||||
"per_device_train_batch_size": 8,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"gradient_checkpointing": False,
|
||||
"datasets": ["webgpt"],
|
||||
}
|
||||
args = parser.parse_args()
|
||||
with open(args.config, 'r', encoding='utf-8') as f:
|
||||
with open(args.config, "r", encoding="utf-8") as f:
|
||||
training_conf = yaml.safe_load(f.read())
|
||||
|
||||
params = { **default_params, **training_conf }
|
||||
params['gradient_accumulation_steps'] = int(params['gradient_accumulation_steps'])
|
||||
params['num_train_epochs'] = int(params['num_train_epochs'])
|
||||
params['per_device_train_batch_size'] = int(params['per_device_train_batch_size'])
|
||||
params['learning_rate'] = float(params['learning_rate'])
|
||||
params = {**default_params, **training_conf}
|
||||
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])
|
||||
params["num_train_epochs"] = int(params["num_train_epochs"])
|
||||
params["per_device_train_batch_size"] = int(params["per_device_train_batch_size"])
|
||||
params["learning_rate"] = float(params["learning_rate"])
|
||||
return params
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained('bigscience/bloomz-560m')
|
||||
model = AutoModelForSequenceClassification.from_pretrained("bigscience/bloomz-560m")
|
||||
freeze_top_n_layers(model, 10)
|
||||
print(model.state_dict().keys())
|
||||
print(model.state_dict().keys())
|
||||
|
||||
Reference in New Issue
Block a user