From 5b77dd2e9f21e24553173bf5c7f5221ea0bbd6f8 Mon Sep 17 00:00:00 2001 From: ekurtulus Date: Wed, 11 Jan 2023 11:37:27 +0300 Subject: [PATCH 1/5] better --- .../custom_datasets/__init__.py | 103 ++++++++++++++++-- model/supervised_finetuning/losses.py | 30 +++++ .../supervised_finetuning/models/__init__.py | 17 +-- model/supervised_finetuning/trainer.py | 31 +++--- model/supervised_finetuning/utils.py | 59 +++++++++- 5 files changed, 202 insertions(+), 38 deletions(-) diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 5706bfa7..c6fd3ada 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,3 +1,4 @@ +import numpy as np from datasets import load_dataset from sklearn.model_selection import train_test_split from torch.utils.data import Dataset, Subset @@ -5,19 +6,100 @@ from torch.utils.data import Dataset, Subset from .prompt_dialogue import PromptGeneratedDataset QA_SPECIAL_TOKENS = {"Question": "", "Answer": ""} +SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": "TL;DR:"} + +summarization_name_mapping = { + "cnn_dailymail": ("article", "highlights"), + "samsum": ("dialogue", "summary"), + "xsum": ("document", "summary"), + "multi_news": ("document", "summary"), + "scitldr": ("source", "target"), + "billsum": ("text", "summary"), + "reddit": ("content", "summary"), +} +summarization_config_mapping = { + "cnn_dailymail": ("3.0.0",), + "samsum": (), + "xsum": (), + "multi_news": (), + "scitldr": ("AIC",), + "billsum": (), + "reddit": (), +} -class SquadV2Dataset(Dataset): - def __init__(self, cache_dir, split): - self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) +def index_squad_v2(example): + return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] + + +def index_trivia_qa_nocontext(example): + # dummy return one randomly + return example["question"], example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] + + +def index_trivia_qa_context(example): + question = example["question"] + title = example["title"][np.random.randint(len(example["title"]))] + context = example["search_context"][np.random.randint(len(example["search_context"]))] + answer = example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] + + return title + ". " + context + " " + question, answer + + +def index_adversarial_qa(example): + return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] + + +class QADataset(Dataset): + def __init__(self, dataset, cache_dir, split): + if dataset == "squad_v2": + self.index_fn = index_squad_v2 + self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) + elif dataset == "trivia_qa_nocontext": + self.index_fn = index_trivia_qa_nocontext + self.dataset = load_dataset("trivia_qa", "rc.nocontext") + elif dataset == "trivia_qa_context": + self.index_fn = index_trivia_qa_context + self.dataset = load_dataset("trivia_qa", "rc") + elif dataset == "adversarial_qa": + self.index_fn = index_adversarial_qa + self.dataset = load_dataset("adversarial_qa", "adversarialQA") + else: + raise ValueError("Unknown dataset : " + dataset) def __len__(self): return len(self.dataset) def __getitem__(self, idx): data = self.dataset[idx] - # return first answer form list of possible answers - return data["title"] + ". " + data["context"] + " " + data["question"], data["answers"]["text"][0] + return self.index_fn(data) + + +def index_summary_default(text, summary): + return text, summary + + +def index_summary_merge(text, summary): + return " ".join(text), " ".join(summary) + + +class SummarizationDataset(Dataset): + def __init__(self, dataset, cache_dir, split): + self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) + self.summary_column, self.text_column = summarization_name_mapping[dataset] + self.preprocess_fn = index_summary_merge if dataset == "scitdlr" else index_summary_merge + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + data = self.dataset[idx] + text, summary = data[self.text_column], data[self.summary_column] + text, summary = self.preprocess_fn(text, summary) + + return "".join( + SUMMARIZATION_SPECIAL_TOKENS["Text"], text, " ", SUMMARIZATION_SPECIAL_TOKENS["Summary"], summary + ) class WebGPT(Dataset): @@ -58,9 +140,14 @@ def train_val_dataset(dataset, val_split=0.2): def get_one_dataset(conf, dataset_name): dataset_name = dataset_name.lower() - if dataset_name == "squadv2": - train = SquadV2Dataset(conf.cache_dir, "train") - eval = SquadV2Dataset(conf.cache_dir, "validation") + if dataset_name in ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]: + train = QADataset(dataset_name, conf.cache_dir, "train") + eval = QADataset(dataset_name, conf.cache_dir, "validation") + + elif dataset_name in ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"]: + train = SummarizationDataset(dataset_name, conf.cache_dir, "train") + eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation") + elif dataset_name == "webgpt": dataset = WebGPT() train, eval = train_val_dataset(dataset, val_split=0.2) diff --git a/model/supervised_finetuning/losses.py b/model/supervised_finetuning/losses.py index 0cc639cf..88902ed7 100644 --- a/model/supervised_finetuning/losses.py +++ b/model/supervised_finetuning/losses.py @@ -1,3 +1,5 @@ +import torch +import torch.nn.functional as F from torch import nn @@ -13,3 +15,31 @@ class CrossEntropyLoss(nn.CrossEntropyLoss): input = input[mask] target = target[mask] return super(CrossEntropyLoss, self).forward(input, target) + + +class PolyLoss(nn.Module): + def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", epsilon=1.0): + super(PolyLoss, self).__init__() + self.weight = torch.tensor(weight) + self.ignore_index = ignore_index + self.reduction = reduction + self.cross_entropy = CrossEntropyLoss(weight, size_average, ignore_index, reduce, "none") + self.epsilon = epsilon + + def forward(self, input, target, mask=None): + if mask is not None: + mask = mask.view(-1).bool() + input = input.view(-1, input.size(-1)) + target = target.view(-1) + input = input[mask] + target = target[mask] + + onehot_target = F.one_hot(target, num_classes=input.size(-1)).to(device=input.device, dtype=input.dtype) + pt = torch.sum(onehot_target * F.softmax(input, -1), -1) + CE = self.cross_entropy(input, target) + poly1 = CE + self.epsilon * (1 - pt) + if self.reduction == "mean": + poly1 = poly1.mean() + elif self.reduction == "sum": + poly1 = poly1.sum() + return poly1 diff --git a/model/supervised_finetuning/models/__init__.py b/model/supervised_finetuning/models/__init__.py index 21ddab9d..510a2738 100644 --- a/model/supervised_finetuning/models/__init__.py +++ b/model/supervised_finetuning/models/__init__.py @@ -1,4 +1,4 @@ -from transformers import AutoModelForCausalLM +import transformers # from .gptj import get_model as get_gptj_model @@ -25,9 +25,12 @@ def freeze_top_n_layers(model, target_layers): return model -def get_specific_model(model_name, cache_dir, quantization): - return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) - # if "gpt-j" in model_name.lower(): - # return get_gptj_model(model_name, cache_dir, quantization) - # else: - # return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) +def get_specific_model(model_name, cache_dir, quantization, seq2seqmodel): + # encoder-decoder support for Flan-T5 like models + # for now, we can use an argument but in the future, + # we can automate this + if seq2seqmodel: + model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir) + else: + model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) + return model diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 450854f1..54a45408 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -3,21 +3,22 @@ import os from distutils.util import strtobool from typing import Any, Dict, List, Optional, Tuple, Union -import bitsandbytes import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments -from transformers.training_args import OptimizerNames -from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls +from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls os.environ["WANDB_PROJECT"] = "supervised-finetuning" -def compute_metrics(eval_pred): - pred_ids = eval_pred.predictions - labels = eval_pred.label_ids +def compute_metrics(eval_pred, preprocess_fn, metrics): + preds, labels = preprocess_fn(eval_pred) - return {"accuracy": (pred_ids[labels > 0] == labels[labels > 0]).mean()} + out = {} + for metric in metrics: + out = dict(**out, **metric.compute(predictions=preds, references=labels)) + + return out def preprocess_logits_for_metrics(logits, labels): @@ -31,12 +32,13 @@ class SFTTrainer(Trainer): model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, loss_function: str = "CrossEntropyLoss", + poly_eps: float = 1.0, **kwargs, ): super().__init__(model, args, **kwargs) # By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct - self.loss_fct = get_loss(loss_function) + self.loss_fct = get_loss(loss_function, poly_eps) def compute_loss(self, model, inputs, return_outputs=False): labels_mask = inputs.pop("label_masks") @@ -92,6 +94,8 @@ def argument_parsing(notebook=False, notebook_args=None): parser.add_argument("--local_rank", type=int, default=-1) parser.add_argument("--deepspeed", action="store_true") parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false") + parser.add_argument("--poly_eps", type=float, default=1.0) + parser.add_argument("--seq2seq_model", action="store_true") parser.set_defaults(deepspeed=False) if notebook: @@ -131,15 +135,7 @@ if __name__ == "__main__": model = get_model(training_conf, tokenizer) train, evals, collate_fn = get_dataset(training_conf, tokenizer) - - optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else None - - if training_conf.quantization: - for module in model.modules(): - if isinstance(module, torch.nn.Embedding): - bitsandbytes.optim.GlobalOptimManager.get_instance().register_module_override( - module, "weight", {"optim_bits": 32} - ) + metrics, preprocess_fn = get_metrics(training_conf) args = TrainingArguments( output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", @@ -147,7 +143,6 @@ if __name__ == "__main__": warmup_steps=training_conf.warmup_steps, learning_rate=float(training_conf.learning_rate), deepspeed="configs/zero_config.json" if training_conf.deepspeed else None, - optim=optimizer, fp16=True, local_rank=training_conf.local_rank, gradient_checkpointing=training_conf.gradient_checkpointing, diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index d6abcff2..8dd21773 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -1,17 +1,21 @@ +from functools import partial from pathlib import Path +import evaluate +import nltk +import numpy as np +import transformers import yaml from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset from custom_datasets.dialogue_collator import DialogueDataCollator -from losses import CrossEntropyLoss +from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model from sklearn.model_selection import train_test_split from torch.utils.data import ConcatDataset, Subset -from transformers import AutoTokenizer def get_tokenizer(conf): - tokenizer = AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir) + tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir) if "galactica" in conf.model_name: tokenizer.add_special_tokens({"pad_token": "", "eos_token": ""}) @@ -32,8 +36,51 @@ def get_tokenizer(conf): return tokenizer +# placeholder for now +def preprocess_qa(eval_pred): + return (eval_pred.predictions, eval_pred.label_ids) + + +def postprocess_summarization(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + + return preds, labels + + +def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True): + preds, labels = eval_pred + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + if ignore_pad_token_for_loss: + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + decoded_preds, decoded_labels = postprocess_summarization(decoded_preds, decoded_labels) + return decoded_preds, decoded_labels + + +def get_metrics(conf, tokenizer): + qa_datasets = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"] + summarization_datasets = ["xsum", "cnn_dailymail", "samsum", "multi_news"] + + # the reason behind using a list is that we might want to extend the list of our + # metrics in the future for more thorough evaluation + if any(dataset in qa_datasets for dataset in conf.datasets): + metrics, preprocess_fn = [evaluate.load("squad_v2")], preprocess_qa + elif any(dataset in summarization_datasets for dataset in conf.datasets): + metrics, preprocess_fn = [evaluate.load("rouge")], partial( + preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss + ) + else: + raise ValueError("Unknown dataset / task") + return metrics, preprocess_fn + + def get_model(conf, tokenizer): - model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization) + model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization, conf.seq2seqmodel) if len(tokenizer) != model.get_input_embeddings().num_embeddings: assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen." @@ -65,9 +112,11 @@ def get_dataset(conf, tokenizer): return train, evals, collate_fn -def get_loss(loss): +def get_loss(loss, poly_eps): if loss == "CrossEntropyLoss": return CrossEntropyLoss() + elif loss == "Poly": + return PolyLoss(epsilon=poly_eps) else: raise ValueError(f"Loss {loss} not supported") From 4a3ea0b0331604b03e637eeee2adf2850cb8ae7f Mon Sep 17 00:00:00 2001 From: Sotirios Anagnostidis Date: Wed, 11 Jan 2023 22:42:04 +0100 Subject: [PATCH 2/5] refactoring, now running --- .../supervised_finetuning/configs/config.yaml | 8 +++-- .../custom_datasets/__init__.py | 4 ++- model/supervised_finetuning/requirements.txt | 2 ++ model/supervised_finetuning/trainer.py | 26 ++++++++------ model/supervised_finetuning/utils.py | 36 ++++++++++++------- 5 files changed, 50 insertions(+), 26 deletions(-) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 616aa828..59912d09 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -23,8 +23,10 @@ defaults: eval_size: log_dir: "base" quantization: false + seq2seqmodel: false + poly_eps: 1.0 -galactica-125: +galactica-125m: learning_rate: 5e-5 model_name: facebook/galactica-125m weight_decay: 0.01 @@ -58,8 +60,8 @@ codegen: debug: eval_steps: 20 - eval_size: 100 + eval_size: 20 gradient_accumulation_steps: 1 per_device_train_batch_size: 1 per_device_eval_batch_size: 1 - quantization: false + quantization: false \ No newline at end of file diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index c6fd3ada..0bc62bc8 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -27,6 +27,8 @@ summarization_config_mapping = { "reddit": (), } +QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"] +SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"] def index_squad_v2(example): return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] @@ -157,4 +159,4 @@ def get_one_dataset(conf, dataset_name): else: raise ValueError(f"Unknown dataset {dataset_name}") - return train, eval + return train, eval \ No newline at end of file diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index c47a1218..d79adf92 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -8,3 +8,5 @@ PyYAML==6.0 scikit_learn==1.2.0 torch==1.13.1 transformers==4.25.1 +evaluate==0.4.0 +nltk==3.8.1 \ No newline at end of file diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 54a45408..88a725b3 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -7,15 +7,15 @@ import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls +from functools import partial os.environ["WANDB_PROJECT"] = "supervised-finetuning" -def compute_metrics(eval_pred, preprocess_fn, metrics): - preds, labels = preprocess_fn(eval_pred) - +def compute_metrics(eval_pred, preprocess_fns, metrics): out = {} - for metric in metrics: + for metric, preprocess_fn in zip(metrics, preprocess_fns): + preds, labels = preprocess_fn(eval_pred) out = dict(**out, **metric.compute(predictions=preds, references=labels)) return out @@ -44,7 +44,10 @@ class SFTTrainer(Trainer): labels_mask = inputs.pop("label_masks") targets = inputs.pop("targets") - outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None)) + outputs = model( + input_ids=inputs["input_ids"], + attention_mask=inputs.get("attention_mask", None), + ) loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask) @@ -56,7 +59,10 @@ class SFTTrainer(Trainer): labels_mask = inputs.pop("label_masks") targets = inputs.pop("targets") - outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None)) + outputs = model( + input_ids=inputs["input_ids"], + attention_mask=inputs.get("attention_mask", None), + ) logits = outputs.get("logits") @@ -94,8 +100,6 @@ def argument_parsing(notebook=False, notebook_args=None): parser.add_argument("--local_rank", type=int, default=-1) parser.add_argument("--deepspeed", action="store_true") parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false") - parser.add_argument("--poly_eps", type=float, default=1.0) - parser.add_argument("--seq2seq_model", action="store_true") parser.set_defaults(deepspeed=False) if notebook: @@ -135,7 +139,7 @@ if __name__ == "__main__": model = get_model(training_conf, tokenizer) train, evals, collate_fn = get_dataset(training_conf, tokenizer) - metrics, preprocess_fn = get_metrics(training_conf) + metrics, preprocess_fns = get_metrics(training_conf, tokenizer) args = TrainingArguments( output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", @@ -161,15 +165,17 @@ if __name__ == "__main__": ) assert len(evals) > 0 + trainer = SFTTrainer( model, args, loss_function=training_conf.loss_fn, + poly_eps=training_conf.poly_eps, train_dataset=train, eval_dataset=evals, data_collator=collate_fn, tokenizer=tokenizer, - compute_metrics=compute_metrics, + compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns), preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) trainer.train() diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 8dd21773..368cd188 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -6,7 +6,7 @@ import nltk import numpy as np import transformers import yaml -from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset +from custom_datasets import QA_SPECIAL_TOKENS, QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset from custom_datasets.dialogue_collator import DialogueDataCollator from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model @@ -36,6 +36,16 @@ def get_tokenizer(conf): return tokenizer +def default_preprocess(eval_pred, ignote_negative_labels=True): + preds, labels = eval_pred.predictions, eval_pred.label_ids + + if not ignote_negative_labels: + return preds, labels + + mask = labels > 0 + return preds[mask], labels[mask] + + # placeholder for now def preprocess_qa(eval_pred): return (eval_pred.predictions, eval_pred.label_ids) @@ -63,20 +73,22 @@ def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=Tru def get_metrics(conf, tokenizer): - qa_datasets = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"] - summarization_datasets = ["xsum", "cnn_dailymail", "samsum", "multi_news"] - # the reason behind using a list is that we might want to extend the list of our # metrics in the future for more thorough evaluation - if any(dataset in qa_datasets for dataset in conf.datasets): - metrics, preprocess_fn = [evaluate.load("squad_v2")], preprocess_qa - elif any(dataset in summarization_datasets for dataset in conf.datasets): - metrics, preprocess_fn = [evaluate.load("rouge")], partial( - preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss + metrics, preprocess_fns = [evaluate.load("accuracy")], [default_preprocess] + + if any(dataset in QA_DATASETS for dataset in conf.datasets): + raise ValueError("TODO") + metrics.append(evaluate.load("squad_v2")) + preprocess_fns.append(preprocess_qa) + if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets): + raise ValueError("TODO") + metrics.append(evaluate.load("rouge")) + preprocess_fns.append( + partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss) ) - else: - raise ValueError("Unknown dataset / task") - return metrics, preprocess_fn + + return metrics, preprocess_fns def get_model(conf, tokenizer): From 6438fdbe2c7b387173b2c604177c5faee3894003 Mon Sep 17 00:00:00 2001 From: Sotirios Anagnostidis Date: Wed, 11 Jan 2023 22:44:20 +0100 Subject: [PATCH 3/5] quantization from #582 --- model/supervised_finetuning/trainer.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 88a725b3..9a6bf148 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -3,9 +3,11 @@ import os from distutils.util import strtobool from typing import Any, Dict, List, Optional, Tuple, Union +import bitsandbytes import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments +from transformers.training_args import OptimizerNames from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls from functools import partial @@ -141,12 +143,22 @@ if __name__ == "__main__": train, evals, collate_fn = get_dataset(training_conf, tokenizer) metrics, preprocess_fns = get_metrics(training_conf, tokenizer) + optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF + + if training_conf.quantization: + for module in model.modules(): + if isinstance(module, torch.nn.Embedding): + bitsandbytes.optim.GlobalOptimManager.get_instance().register_module_override( + module, "weight", {"optim_bits": 32} + ) + args = TrainingArguments( output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", num_train_epochs=training_conf.num_train_epochs, warmup_steps=training_conf.warmup_steps, learning_rate=float(training_conf.learning_rate), deepspeed="configs/zero_config.json" if training_conf.deepspeed else None, + optim=optimizer, fp16=True, local_rank=training_conf.local_rank, gradient_checkpointing=training_conf.gradient_checkpointing, From d46ff8c4ee5d93e5df46d99b2914c94456a27254 Mon Sep 17 00:00:00 2001 From: Sotirios Anagnostidis Date: Wed, 11 Jan 2023 22:48:02 +0100 Subject: [PATCH 4/5] better logging with deepspeed --- model/supervised_finetuning/trainer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 9a6bf148..517ba830 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -11,8 +11,6 @@ from transformers.training_args import OptimizerNames from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls from functools import partial -os.environ["WANDB_PROJECT"] = "supervised-finetuning" - def compute_metrics(eval_pred, preprocess_fns, metrics): out = {} @@ -102,6 +100,7 @@ def argument_parsing(notebook=False, notebook_args=None): parser.add_argument("--local_rank", type=int, default=-1) parser.add_argument("--deepspeed", action="store_true") parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false") + parser.add_argument("--wandb-entity", type=str, default="open-assistant") parser.set_defaults(deepspeed=False) if notebook: @@ -121,8 +120,10 @@ def argument_parsing(notebook=False, notebook_args=None): else: conf.update(configs[name]) + conf["wandb_entity"] = args.wandb_entity conf["local_rank"] = args.local_rank conf["deepspeed"] = args.deepspeed + # Override config from command-line parser = argparse.ArgumentParser() for key, value in conf.items(): @@ -178,6 +179,15 @@ if __name__ == "__main__": assert len(evals) > 0 + if not training_conf.deepspeed or training_conf.local_rank == 0: + import wandb + + wandb.init( + project="supervised-finetuning", + entity=training_conf.wandb_entity, + name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", + ) + trainer = SFTTrainer( model, args, From c8f47eef9fb2f9c2edd629714d79b06fd61e4e95 Mon Sep 17 00:00:00 2001 From: Sotirios Anagnostidis Date: Wed, 11 Jan 2023 22:58:17 +0100 Subject: [PATCH 5/5] precommits --- model/supervised_finetuning/configs/config.yaml | 2 +- model/supervised_finetuning/custom_datasets/__init__.py | 3 ++- model/supervised_finetuning/requirements.txt | 4 ++-- model/supervised_finetuning/trainer.py | 3 +-- model/supervised_finetuning/utils.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 59912d09..bd35f168 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -64,4 +64,4 @@ debug: gradient_accumulation_steps: 1 per_device_train_batch_size: 1 per_device_eval_batch_size: 1 - quantization: false \ No newline at end of file + quantization: false diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 0bc62bc8..c0cd424b 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -30,6 +30,7 @@ summarization_config_mapping = { QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"] SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"] + def index_squad_v2(example): return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] @@ -159,4 +160,4 @@ def get_one_dataset(conf, dataset_name): else: raise ValueError(f"Unknown dataset {dataset_name}") - return train, eval \ No newline at end of file + return train, eval diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index d79adf92..0e6eeb51 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -2,11 +2,11 @@ accelerate==0.15.0 bitsandbytes==0.36.0.post2 datasets==2.8.0 deepspeed==0.7.7 +evaluate==0.4.0 mpi4py==3.1.4 +nltk==3.8.1 numpy==1.23.0 PyYAML==6.0 scikit_learn==1.2.0 torch==1.13.1 transformers==4.25.1 -evaluate==0.4.0 -nltk==3.8.1 \ No newline at end of file diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 517ba830..0acb10dd 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -1,6 +1,6 @@ import argparse -import os from distutils.util import strtobool +from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import bitsandbytes @@ -9,7 +9,6 @@ from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments from transformers.training_args import OptimizerNames from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls -from functools import partial def compute_metrics(eval_pred, preprocess_fns, metrics): diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 368cd188..f598dde1 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -6,7 +6,7 @@ import nltk import numpy as np import transformers import yaml -from custom_datasets import QA_SPECIAL_TOKENS, QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset +from custom_datasets import QA_DATASETS, QA_SPECIAL_TOKENS, SUMMARIZATION_DATASETS, get_one_dataset from custom_datasets.dialogue_collator import DialogueDataCollator from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model