refactoring, now running

This commit is contained in:
Sotirios Anagnostidis
2023-01-11 22:42:04 +01:00
parent 5b77dd2e9f
commit 4a3ea0b033
5 changed files with 50 additions and 26 deletions
@@ -23,8 +23,10 @@ defaults:
eval_size:
log_dir: "base"
quantization: false
seq2seqmodel: false
poly_eps: 1.0
galactica-125:
galactica-125m:
learning_rate: 5e-5
model_name: facebook/galactica-125m
weight_decay: 0.01
@@ -58,8 +60,8 @@ codegen:
debug:
eval_steps: 20
eval_size: 100
eval_size: 20
gradient_accumulation_steps: 1
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
quantization: false
quantization: false
@@ -27,6 +27,8 @@ summarization_config_mapping = {
"reddit": (),
}
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"]
def index_squad_v2(example):
return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0]
@@ -157,4 +159,4 @@ def get_one_dataset(conf, dataset_name):
else:
raise ValueError(f"Unknown dataset {dataset_name}")
return train, eval
return train, eval
@@ -8,3 +8,5 @@ PyYAML==6.0
scikit_learn==1.2.0
torch==1.13.1
transformers==4.25.1
evaluate==0.4.0
nltk==3.8.1
+16 -10
View File
@@ -7,15 +7,15 @@ import torch
from torch import nn
from transformers import PreTrainedModel, Trainer, TrainingArguments
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
from functools import partial
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
def compute_metrics(eval_pred, preprocess_fn, metrics):
preds, labels = preprocess_fn(eval_pred)
def compute_metrics(eval_pred, preprocess_fns, metrics):
out = {}
for metric in metrics:
for metric, preprocess_fn in zip(metrics, preprocess_fns):
preds, labels = preprocess_fn(eval_pred)
out = dict(**out, **metric.compute(predictions=preds, references=labels))
return out
@@ -44,7 +44,10 @@ class SFTTrainer(Trainer):
labels_mask = inputs.pop("label_masks")
targets = inputs.pop("targets")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None))
outputs = model(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
)
loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)
@@ -56,7 +59,10 @@ class SFTTrainer(Trainer):
labels_mask = inputs.pop("label_masks")
targets = inputs.pop("targets")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None))
outputs = model(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
)
logits = outputs.get("logits")
@@ -94,8 +100,6 @@ def argument_parsing(notebook=False, notebook_args=None):
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--deepspeed", action="store_true")
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
parser.add_argument("--poly_eps", type=float, default=1.0)
parser.add_argument("--seq2seq_model", action="store_true")
parser.set_defaults(deepspeed=False)
if notebook:
@@ -135,7 +139,7 @@ if __name__ == "__main__":
model = get_model(training_conf, tokenizer)
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
metrics, preprocess_fn = get_metrics(training_conf)
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
args = TrainingArguments(
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
@@ -161,15 +165,17 @@ if __name__ == "__main__":
)
assert len(evals) > 0
trainer = SFTTrainer(
model,
args,
loss_function=training_conf.loss_fn,
poly_eps=training_conf.poly_eps,
train_dataset=train,
eval_dataset=evals,
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns),
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
trainer.train()
+24 -12
View File
@@ -6,7 +6,7 @@ import nltk
import numpy as np
import transformers
import yaml
from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset
from custom_datasets import QA_SPECIAL_TOKENS, QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
from losses import CrossEntropyLoss, PolyLoss
from models import freeze_top_n_layers, get_specific_model
@@ -36,6 +36,16 @@ def get_tokenizer(conf):
return tokenizer
def default_preprocess(eval_pred, ignote_negative_labels=True):
preds, labels = eval_pred.predictions, eval_pred.label_ids
if not ignote_negative_labels:
return preds, labels
mask = labels > 0
return preds[mask], labels[mask]
# placeholder for now
def preprocess_qa(eval_pred):
return (eval_pred.predictions, eval_pred.label_ids)
@@ -63,20 +73,22 @@ def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=Tru
def get_metrics(conf, tokenizer):
qa_datasets = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]
summarization_datasets = ["xsum", "cnn_dailymail", "samsum", "multi_news"]
# the reason behind using a list is that we might want to extend the list of our
# metrics in the future for more thorough evaluation
if any(dataset in qa_datasets for dataset in conf.datasets):
metrics, preprocess_fn = [evaluate.load("squad_v2")], preprocess_qa
elif any(dataset in summarization_datasets for dataset in conf.datasets):
metrics, preprocess_fn = [evaluate.load("rouge")], partial(
preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss
metrics, preprocess_fns = [evaluate.load("accuracy")], [default_preprocess]
if any(dataset in QA_DATASETS for dataset in conf.datasets):
raise ValueError("TODO")
metrics.append(evaluate.load("squad_v2"))
preprocess_fns.append(preprocess_qa)
if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets):
raise ValueError("TODO")
metrics.append(evaluate.load("rouge"))
preprocess_fns.append(
partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss)
)
else:
raise ValueError("Unknown dataset / task")
return metrics, preprocess_fn
return metrics, preprocess_fns
def get_model(conf, tokenizer):