mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
refactoring, now running
This commit is contained in:
@@ -23,8 +23,10 @@ defaults:
|
||||
eval_size:
|
||||
log_dir: "base"
|
||||
quantization: false
|
||||
seq2seqmodel: false
|
||||
poly_eps: 1.0
|
||||
|
||||
galactica-125:
|
||||
galactica-125m:
|
||||
learning_rate: 5e-5
|
||||
model_name: facebook/galactica-125m
|
||||
weight_decay: 0.01
|
||||
@@ -58,8 +60,8 @@ codegen:
|
||||
|
||||
debug:
|
||||
eval_steps: 20
|
||||
eval_size: 100
|
||||
eval_size: 20
|
||||
gradient_accumulation_steps: 1
|
||||
per_device_train_batch_size: 1
|
||||
per_device_eval_batch_size: 1
|
||||
quantization: false
|
||||
quantization: false
|
||||
@@ -27,6 +27,8 @@ summarization_config_mapping = {
|
||||
"reddit": (),
|
||||
}
|
||||
|
||||
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]
|
||||
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"]
|
||||
|
||||
def index_squad_v2(example):
|
||||
return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0]
|
||||
@@ -157,4 +159,4 @@ def get_one_dataset(conf, dataset_name):
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset {dataset_name}")
|
||||
|
||||
return train, eval
|
||||
return train, eval
|
||||
@@ -8,3 +8,5 @@ PyYAML==6.0
|
||||
scikit_learn==1.2.0
|
||||
torch==1.13.1
|
||||
transformers==4.25.1
|
||||
evaluate==0.4.0
|
||||
nltk==3.8.1
|
||||
@@ -7,15 +7,15 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel, Trainer, TrainingArguments
|
||||
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
|
||||
from functools import partial
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
|
||||
|
||||
|
||||
def compute_metrics(eval_pred, preprocess_fn, metrics):
|
||||
preds, labels = preprocess_fn(eval_pred)
|
||||
|
||||
def compute_metrics(eval_pred, preprocess_fns, metrics):
|
||||
out = {}
|
||||
for metric in metrics:
|
||||
for metric, preprocess_fn in zip(metrics, preprocess_fns):
|
||||
preds, labels = preprocess_fn(eval_pred)
|
||||
out = dict(**out, **metric.compute(predictions=preds, references=labels))
|
||||
|
||||
return out
|
||||
@@ -44,7 +44,10 @@ class SFTTrainer(Trainer):
|
||||
labels_mask = inputs.pop("label_masks")
|
||||
targets = inputs.pop("targets")
|
||||
|
||||
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None))
|
||||
outputs = model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs.get("attention_mask", None),
|
||||
)
|
||||
|
||||
loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)
|
||||
|
||||
@@ -56,7 +59,10 @@ class SFTTrainer(Trainer):
|
||||
labels_mask = inputs.pop("label_masks")
|
||||
targets = inputs.pop("targets")
|
||||
|
||||
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None))
|
||||
outputs = model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs.get("attention_mask", None),
|
||||
)
|
||||
|
||||
logits = outputs.get("logits")
|
||||
|
||||
@@ -94,8 +100,6 @@ def argument_parsing(notebook=False, notebook_args=None):
|
||||
parser.add_argument("--local_rank", type=int, default=-1)
|
||||
parser.add_argument("--deepspeed", action="store_true")
|
||||
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
|
||||
parser.add_argument("--poly_eps", type=float, default=1.0)
|
||||
parser.add_argument("--seq2seq_model", action="store_true")
|
||||
parser.set_defaults(deepspeed=False)
|
||||
|
||||
if notebook:
|
||||
@@ -135,7 +139,7 @@ if __name__ == "__main__":
|
||||
model = get_model(training_conf, tokenizer)
|
||||
|
||||
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
|
||||
metrics, preprocess_fn = get_metrics(training_conf)
|
||||
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
|
||||
@@ -161,15 +165,17 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
assert len(evals) > 0
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args,
|
||||
loss_function=training_conf.loss_fn,
|
||||
poly_eps=training_conf.poly_eps,
|
||||
train_dataset=train,
|
||||
eval_dataset=evals,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns),
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
@@ -6,7 +6,7 @@ import nltk
|
||||
import numpy as np
|
||||
import transformers
|
||||
import yaml
|
||||
from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset
|
||||
from custom_datasets import QA_SPECIAL_TOKENS, QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
from losses import CrossEntropyLoss, PolyLoss
|
||||
from models import freeze_top_n_layers, get_specific_model
|
||||
@@ -36,6 +36,16 @@ def get_tokenizer(conf):
|
||||
return tokenizer
|
||||
|
||||
|
||||
def default_preprocess(eval_pred, ignote_negative_labels=True):
|
||||
preds, labels = eval_pred.predictions, eval_pred.label_ids
|
||||
|
||||
if not ignote_negative_labels:
|
||||
return preds, labels
|
||||
|
||||
mask = labels > 0
|
||||
return preds[mask], labels[mask]
|
||||
|
||||
|
||||
# placeholder for now
|
||||
def preprocess_qa(eval_pred):
|
||||
return (eval_pred.predictions, eval_pred.label_ids)
|
||||
@@ -63,20 +73,22 @@ def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=Tru
|
||||
|
||||
|
||||
def get_metrics(conf, tokenizer):
|
||||
qa_datasets = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]
|
||||
summarization_datasets = ["xsum", "cnn_dailymail", "samsum", "multi_news"]
|
||||
|
||||
# the reason behind using a list is that we might want to extend the list of our
|
||||
# metrics in the future for more thorough evaluation
|
||||
if any(dataset in qa_datasets for dataset in conf.datasets):
|
||||
metrics, preprocess_fn = [evaluate.load("squad_v2")], preprocess_qa
|
||||
elif any(dataset in summarization_datasets for dataset in conf.datasets):
|
||||
metrics, preprocess_fn = [evaluate.load("rouge")], partial(
|
||||
preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss
|
||||
metrics, preprocess_fns = [evaluate.load("accuracy")], [default_preprocess]
|
||||
|
||||
if any(dataset in QA_DATASETS for dataset in conf.datasets):
|
||||
raise ValueError("TODO")
|
||||
metrics.append(evaluate.load("squad_v2"))
|
||||
preprocess_fns.append(preprocess_qa)
|
||||
if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets):
|
||||
raise ValueError("TODO")
|
||||
metrics.append(evaluate.load("rouge"))
|
||||
preprocess_fns.append(
|
||||
partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss)
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown dataset / task")
|
||||
return metrics, preprocess_fn
|
||||
|
||||
return metrics, preprocess_fns
|
||||
|
||||
|
||||
def get_model(conf, tokenizer):
|
||||
|
||||
Reference in New Issue
Block a user