mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-01 16:50:12 +08:00
[feature] Added a config-file argument for setting and recording training parameters
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
# Sections to train Reward Model (RM)
|
||||
|
||||
|
||||
Currently we format
|
||||
|
||||
|
||||
```bash
|
||||
|
||||
|
||||
@@ -9,13 +9,11 @@
|
||||
|
||||
'''
|
||||
from typing import Optional, Union
|
||||
import os
|
||||
import glob
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
import os
|
||||
os.environ['WANDB_PROJECT'] = 'reward-model'
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union, Dict
|
||||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import yaml
|
||||
import evaluate
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union, Dict
|
||||
from torch import nn
|
||||
from argparse import ArgumentParser
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
from transformers import AutoModelForSequenceClassification, AutoModelForMultipleChoice
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from transformers import Trainer, PreTrainedModel, TrainingArguments, DataCollator, EvalPrediction, TrainerCallback, PreTrainedTokenizerBase
|
||||
from rank_datasets import DataCollatorForPairRank, WebGPT, HFSummary
|
||||
from utils import get_tokenizer, train_val_dataset
|
||||
from utils import get_tokenizer, train_val_dataset, freeze_top_n_layers, argument_parsing
|
||||
|
||||
accuracy = evaluate.load("accuracy")
|
||||
# Command-line interface: a single positional argument, `config`, holding the
# path of the YAML file that argument_parsing() opens and loads.
parser = ArgumentParser()
|
||||
parser.add_argument('config', type=str)
|
||||
|
||||
@dataclass
|
||||
class CustomTrainingArguments(TrainingArguments):
|
||||
@@ -87,21 +91,26 @@ class RankTrainer(Trainer):
|
||||
return (loss, logits, labels)
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_name = 'bigscience/bloomz-560m'
|
||||
model_name = 'google/electra-large-discriminator'
|
||||
training_conf = argument_parsing(parser)
|
||||
|
||||
model_name = training_conf['model_name']
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type='regression')
|
||||
if 'freeze_layer' in training_conf:
|
||||
num_layer = training_conf['freeze_layer']
|
||||
model = freeze_top_n_layers(model, num_layer)
|
||||
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
args = CustomTrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=4,
|
||||
num_train_epochs=training_conf['num_train_epochs'],
|
||||
warmup_steps=500,
|
||||
loss_function='rank',
|
||||
learning_rate=3e-5,
|
||||
loss_function=training_conf['loss'],
|
||||
learning_rate=training_conf['learning_rate'],
|
||||
# half_precision_backend="apex",
|
||||
fp16=True,
|
||||
gradient_checkpointing=True,
|
||||
gradient_accumulation_steps=8,
|
||||
per_device_train_batch_size=8,
|
||||
gradient_checkpointing=training_conf['gradient_checkpointing'],
|
||||
# Read the step count from its own config key; the 'gradient_checkpointing'
# key holds a boolean flag, not a number of accumulation steps.
gradient_accumulation_steps=training_conf['gradient_accumulation_steps'],
|
||||
per_device_train_batch_size=training_conf['per_device_train_batch_size'],
|
||||
per_device_eval_batch_size=5,
|
||||
weight_decay=0.01,
|
||||
max_grad_norm=2.0,
|
||||
@@ -112,10 +121,19 @@ if __name__ == "__main__":
|
||||
save_steps=1000,
|
||||
report_to='wandb'
|
||||
)
|
||||
dataset = WebGPT()
|
||||
train, eval = train_val_dataset(dataset)
|
||||
train = ConcatDataset([train, HFSummary()])
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=440)
|
||||
train_datasets, evals = [], {}
|
||||
if 'webgpt' in training_conf['datasets']:
|
||||
web_dataset = WebGPT()
|
||||
train, eval = train_val_dataset(web_dataset)
|
||||
train_datasets.append(train)
|
||||
evals['webgpt'] = eval
|
||||
if 'hfsummary' in training_conf['datasets']:
|
||||
summary_dataset = HFSummary()
|
||||
sum_train, sum_eval = train_val_dataset(summary_dataset)
|
||||
train_datasets.append(sum_train)
|
||||
evals['hfsummary'] = sum_eval
|
||||
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf['max_length'])
|
||||
trainer = RankTrainer(
|
||||
model,
|
||||
args,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
import yaml
|
||||
from torch.utils.data import Subset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from transformers import AutoTokenizer
|
||||
@@ -39,3 +40,40 @@ def train_val_dataset(dataset, val_split=0.2):
|
||||
print(train_idx[:10])
|
||||
return Subset(dataset, train_idx), Subset(dataset, val_idx)
|
||||
|
||||
def _layer_index(name):
    """Return the layer number encoded in a dotted parameter name, or None.

    The index is taken from the first dot-separated token that contains
    ``'layer'`` and is immediately followed by an integer token, e.g.
    ``'encoder.layer.3.output.dense.weight'`` -> ``3``.  Names such as
    ``'decoder.final_layer_norm.weight'`` have no such pair and yield None.
    """
    tokens = name.split('.')
    for token, following in zip(tokens, tokens[1:]):
        if 'layer' not in token:
            continue
        try:
            return int(following)
        except ValueError:
            # e.g. a layer-norm parameter: "...layer_norm.weight"; keep looking.
            continue
    return None


def freeze_top_n_layers(model, target_layers):
    """Freeze the embeddings and the lowest-indexed layers of ``model``.

    Every parameter whose name contains ``'embed'`` is frozen.  Every
    parameter whose name carries a layer index (see ``_layer_index``) that is
    strictly less than ``target_layers`` is frozen as well.  All other
    parameters are left untouched, including those whose name mentions
    "layer" without a numeric index.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs; each parameter must have a
            writable ``requires_grad`` attribute.
        target_layers: number of leading layers to freeze; layers with index
            ``0 .. target_layers - 1`` stop receiving gradients.

    Returns:
        The same ``model`` object, modified in place.
    """
    for name, param in model.named_parameters():
        if 'embed' in name:
            param.requires_grad = False
            continue

        layer_idx = _layer_index(name)
        if layer_idx is not None and layer_idx < target_layers:
            param.requires_grad = False
    return model
|
||||
|
||||
|
||||
def argument_parsing(parser):
    """Parse the command line and build the training configuration.

    The parser must define a ``config`` argument holding the path of a YAML
    file.  Values from that file override the built-in defaults below; keys
    that have no default are passed through unchanged.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object) whose
            ``parse_args()`` result has a ``config`` attribute.

    Returns:
        A new dict containing the defaults updated with the file's values.

    Raises:
        TypeError: if the top level of the YAML document is not a mapping.
    """
    default_params = {
        'num_train_epochs': 4,
        'learning_rate': 3e-5,
        'eval_steps': 500,
        'loss': 'rank',
        'max_length': 440,
        'per_device_train_batch_size': 8,
        'gradient_accumulation_steps': 8,
        'gradient_checkpointing': False,
        'datasets': ['webgpt'],
    }
    args = parser.parse_args()
    with open(args.config, 'r', encoding='utf-8') as f:
        training_conf = yaml.safe_load(f)

    # An empty file loads as None; treat it as "no overrides" rather than
    # failing when the dicts are merged.
    if training_conf is None:
        training_conf = {}
    if not isinstance(training_conf, dict):
        raise TypeError(
            f"config file {args.config!r} must contain a mapping at the top "
            f"level, got {type(training_conf).__name__}"
        )

    return {**default_params, **training_conf}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user