diff --git a/model/reward/instructor/README.md b/model/reward/instructor/README.md index 7dbfefbc..a8b5ef33 100644 --- a/model/reward/instructor/README.md +++ b/model/reward/instructor/README.md @@ -1,5 +1,8 @@ +# Sections to train Reward Model (RM) +Currently we format + ```bash diff --git a/model/reward/instructor/rank_datasets.py b/model/reward/instructor/rank_datasets.py index 41740dcf..aa77089c 100644 --- a/model/reward/instructor/rank_datasets.py +++ b/model/reward/instructor/rank_datasets.py @@ -9,13 +9,11 @@ ''' from typing import Optional, Union -import os import glob import json from dataclasses import dataclass import numpy as np from torch.utils.data import Dataset -import torch from datasets import load_dataset from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy diff --git a/model/reward/instructor/trainer.py b/model/reward/instructor/trainer.py index 586c8d47..06bb8098 100644 --- a/model/reward/instructor/trainer.py +++ b/model/reward/instructor/trainer.py @@ -1,18 +1,22 @@ import os os.environ['WANDB_PROJECT'] = 'reward-model' -from typing import Any, Callable, List, Optional, Tuple, Union, Dict import torch -from torch import nn -import numpy as np +import yaml import evaluate +from typing import Any, Callable, List, Optional, Tuple, Union, Dict +from torch import nn +from argparse import ArgumentParser +import numpy as np from dataclasses import dataclass from torch.utils.data import Dataset, ConcatDataset -from transformers import AutoModelForSequenceClassification, AutoModelForMultipleChoice +from transformers import AutoModelForSequenceClassification from transformers import Trainer, PreTrainedModel, TrainingArguments, DataCollator, EvalPrediction, TrainerCallback, PreTrainedTokenizerBase from rank_datasets import DataCollatorForPairRank, WebGPT, HFSummary -from utils import get_tokenizer, train_val_dataset +from utils import get_tokenizer, train_val_dataset, freeze_top_n_layers, argument_parsing accuracy = evaluate.load("accuracy") +parser = ArgumentParser() +parser.add_argument('config', type=str) @dataclass class CustomTrainingArguments(TrainingArguments): @@ -87,21 +91,26 @@ class RankTrainer(Trainer): return (loss, logits, labels) if __name__ == "__main__": - model_name = 'bigscience/bloomz-560m' - model_name = 'google/electra-large-discriminator' + training_conf = argument_parsing(parser) + + model_name = training_conf['model_name'] model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type='regression') + if 'freeze_layer' in training_conf: + num_layer = training_conf['freeze_layer'] + model = freeze_top_n_layers(model, num_layer) + tokenizer = get_tokenizer(model_name) args = CustomTrainingArguments( output_dir=f"{model_name}-finetuned", - num_train_epochs=4, + num_train_epochs=training_conf['num_train_epochs'], warmup_steps=500, - loss_function='rank', - learning_rate=3e-5, + loss_function=training_conf['loss'], + learning_rate=training_conf['learning_rate'], # half_precision_backend="apex", fp16=True, - gradient_checkpointing=True, - gradient_accumulation_steps=8, - per_device_train_batch_size=8, + gradient_checkpointing=training_conf['gradient_checkpointing'], + gradient_accumulation_steps=training_conf['gradient_checkpointing'], + per_device_train_batch_size=training_conf['per_device_train_batch_size'], per_device_eval_batch_size=5, weight_decay=0.01, max_grad_norm=2.0, @@ -112,10 +121,19 @@ if __name__ == "__main__": save_steps=1000, report_to='wandb' ) - dataset = WebGPT() - train, eval = train_val_dataset(dataset) - train = ConcatDataset([train, HFSummary()]) - collate_fn = DataCollatorForPairRank(tokenizer, max_length=440) + train_datasets, evals = [], {} + if 'webgpt' in training_conf['datasets']: + web_dataset = WebGPT() + train, eval = train_val_dataset(web_dataset) + train_datasets.append(train) + evals['webgpt'] = eval + if 'hfsummary' in training_conf['datasets']: + summary_dataset = HFSummary() + sum_train, sum_eval = train_val_dataset(summary_dataset) + train_datasets.append(sum_train) + evals['hfsummary'] = sum_eval + + collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf['max_length']) trainer = RankTrainer( model, args, diff --git a/model/reward/instructor/utils.py b/model/reward/instructor/utils.py index 10f84193..4867087c 100644 --- a/model/reward/instructor/utils.py +++ b/model/reward/instructor/utils.py @@ -1,4 +1,5 @@ import re +import yaml from torch.utils.data import Subset from sklearn.model_selection import train_test_split from transformers import AutoTokenizer @@ -39,3 +40,40 @@ def train_val_dataset(dataset, val_split=0.2): print(train_idx[:10]) return Subset(dataset, train_idx), Subset(dataset, val_idx) +def freeze_top_n_layers(model, target_layers): + for name, param in model.name_parameters(): + if 'embed' in name: + param.requires_grad = False + elif 'layer' in name: + tokens = name.split('.') + idx = 0 + for token in tokens: + if 'layer' in token: + break + idx += 1 + + layer_ = int(tokens[idx+1]) + if layer_ < target_layers: + param.requires_grad = False + return model + + +def argument_parsing(parser): + default_params = { + 'num_train_epochs': 4, + 'learning_rate': 3e-5, + 'eval_steps': 500, + 'loss': 'rank', + 'max_length': 440, + 'per_device_train_batch_size': 8, + 'gradient_accumulation_steps': 8, + 'gradient_checkpointing': False, + 'datasets': ['webgpt'] + } + args = parser.parse_args() + with open(args.config, 'r', encoding='utf-8') as f: + training_conf = yaml.safe_load(f.read()) + + return { **default_params, **training_conf } + +