diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 616aa828..59912d09 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -23,8 +23,10 @@ defaults: eval_size: log_dir: "base" quantization: false + seq2seqmodel: false + poly_eps: 1.0 -galactica-125: +galactica-125m: learning_rate: 5e-5 model_name: facebook/galactica-125m weight_decay: 0.01 @@ -58,8 +60,8 @@ codegen: debug: eval_steps: 20 - eval_size: 100 + eval_size: 20 gradient_accumulation_steps: 1 per_device_train_batch_size: 1 per_device_eval_batch_size: 1 - quantization: false + quantization: false \ No newline at end of file diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index c6fd3ada..0bc62bc8 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -27,6 +27,8 @@ summarization_config_mapping = { "reddit": (), } +QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"] +SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"] def index_squad_v2(example): return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] @@ -157,4 +159,4 @@ def get_one_dataset(conf, dataset_name): else: raise ValueError(f"Unknown dataset {dataset_name}") - return train, eval + return train, eval \ No newline at end of file diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index c47a1218..d79adf92 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -8,3 +8,5 @@ PyYAML==6.0 scikit_learn==1.2.0 torch==1.13.1 transformers==4.25.1 +evaluate==0.4.0 +nltk==3.8.1 \ No newline at end of file diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 54a45408..88a725b3 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -7,15 +7,15 @@ import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls +from functools import partial os.environ["WANDB_PROJECT"] = "supervised-finetuning" -def compute_metrics(eval_pred, preprocess_fn, metrics): - preds, labels = preprocess_fn(eval_pred) - +def compute_metrics(eval_pred, preprocess_fns, metrics): out = {} - for metric in metrics: + for metric, preprocess_fn in zip(metrics, preprocess_fns): + preds, labels = preprocess_fn(eval_pred) out = dict(**out, **metric.compute(predictions=preds, references=labels)) return out @@ -44,7 +44,10 @@ class SFTTrainer(Trainer): labels_mask = inputs.pop("label_masks") targets = inputs.pop("targets") - outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None)) + outputs = model( + input_ids=inputs["input_ids"], + attention_mask=inputs.get("attention_mask", None), + ) loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask) @@ -56,7 +59,10 @@ class SFTTrainer(Trainer): labels_mask = inputs.pop("label_masks") targets = inputs.pop("targets") - outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None)) + outputs = model( + input_ids=inputs["input_ids"], + attention_mask=inputs.get("attention_mask", None), + ) logits = outputs.get("logits") @@ -94,8 +100,6 @@ def argument_parsing(notebook=False, notebook_args=None): parser.add_argument("--local_rank", type=int, default=-1) parser.add_argument("--deepspeed", action="store_true") parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false") - parser.add_argument("--poly_eps", type=float, default=1.0) - parser.add_argument("--seq2seq_model", action="store_true") parser.set_defaults(deepspeed=False) if notebook: @@ -135,7 +139,7 @@ if __name__ == "__main__": model = get_model(training_conf, tokenizer) train, evals, collate_fn = get_dataset(training_conf, tokenizer) - metrics, preprocess_fn = get_metrics(training_conf) + metrics, preprocess_fns = get_metrics(training_conf, tokenizer) args = TrainingArguments( output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", @@ -161,15 +165,17 @@ if __name__ == "__main__": ) assert len(evals) > 0 + trainer = SFTTrainer( model, args, loss_function=training_conf.loss_fn, + poly_eps=training_conf.poly_eps, train_dataset=train, eval_dataset=evals, data_collator=collate_fn, tokenizer=tokenizer, - compute_metrics=compute_metrics, + compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns), preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) trainer.train() diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 8dd21773..368cd188 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -6,7 +6,7 @@ import nltk import numpy as np import transformers import yaml -from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset +from custom_datasets import QA_SPECIAL_TOKENS, QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset from custom_datasets.dialogue_collator import DialogueDataCollator from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model @@ -36,6 +36,16 @@ def get_tokenizer(conf): return tokenizer +def default_preprocess(eval_pred, ignote_negative_labels=True): + preds, labels = eval_pred.predictions, eval_pred.label_ids + + if not ignote_negative_labels: + return preds, labels + + mask = labels > 0 + return preds[mask], labels[mask] + + # placeholder for now def preprocess_qa(eval_pred): return (eval_pred.predictions, eval_pred.label_ids) @@ -63,20 +73,22 @@ def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=Tru def get_metrics(conf, tokenizer): - qa_datasets = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"] - summarization_datasets = ["xsum", "cnn_dailymail", "samsum", "multi_news"] - # the reason behind using a list is that we might want to extend the list of our # metrics in the future for more thorough evaluation - if any(dataset in qa_datasets for dataset in conf.datasets): - metrics, preprocess_fn = [evaluate.load("squad_v2")], preprocess_qa - elif any(dataset in summarization_datasets for dataset in conf.datasets): - metrics, preprocess_fn = [evaluate.load("rouge")], partial( - preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss + metrics, preprocess_fns = [evaluate.load("accuracy")], [default_preprocess] + + if any(dataset in QA_DATASETS for dataset in conf.datasets): + raise ValueError("TODO") + metrics.append(evaluate.load("squad_v2")) + preprocess_fns.append(preprocess_qa) + if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets): + raise ValueError("TODO") + metrics.append(evaluate.load("rouge")) + preprocess_fns.append( + partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss) ) - else: - raise ValueError("Unknown dataset / task") - return metrics, preprocess_fn + + return metrics, preprocess_fns def get_model(conf, tokenizer):