From 08bdadf222749e59a5ed4386d5ead74e45c19bf9 Mon Sep 17 00:00:00 2001
From: mrcabbage972
Date: Mon, 9 Jan 2023 22:07:06 -0500
Subject: [PATCH] Adding BNB 8-bit Adam

---
Review notes (text in this section is ignored when the patch is applied):

* `quantization` now defaults to an explicit `false` in config.yaml and
  selects the optimizer in trainer.py: true -> bitsandbytes 8-bit AdamW
  (OptimizerNames.ADAMW_BNB), false -> OptimizerNames.ADAMW_HF.
* The non-quantized branch must not be `None`: TrainingArguments coerces
  `optim` through the OptimizerNames enum, so `optim=None` fails with a
  ValueError before training starts. ADAMW_HF is the library default, so
  runs with `quantization: false` behave as they did before this patch.
* Line structure and indentation were restored from a whitespace-collapsed
  copy; the trainer.py post-image hash on its `index` line predates the
  fallback fix and is informational only.

 model/supervised_finetuning/configs/config.yaml | 4 ++--
 model/supervised_finetuning/requirements.txt    | 1 +
 model/supervised_finetuning/trainer.py          | 5 +++++
 3 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml
index 97e37121..616aa828 100644
--- a/model/supervised_finetuning/configs/config.yaml
+++ b/model/supervised_finetuning/configs/config.yaml
@@ -22,7 +22,7 @@ defaults:
   loss_fn: CrossEntropyLoss
   eval_size:
   log_dir: "base"
-  quantization:
+  quantization: false
 
 galactica-125:
   learning_rate: 5e-5
@@ -62,4 +62,4 @@ debug:
   gradient_accumulation_steps: 1
   per_device_train_batch_size: 1
   per_device_eval_batch_size: 1
-  quantization:
+  quantization: false
diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt
index 7d78f36c..6338614d 100644
--- a/model/supervised_finetuning/requirements.txt
+++ b/model/supervised_finetuning/requirements.txt
@@ -7,3 +7,4 @@ PyYAML==6.0
 scikit_learn==1.2.0
 torch==1.13.1
 transformers==4.25.1
+bitsandbytes==0.36.0.post2
\ No newline at end of file
diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py
index cb55131d..ae7fb3c3 100644
--- a/model/supervised_finetuning/trainer.py
+++ b/model/supervised_finetuning/trainer.py
@@ -6,6 +6,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from torch import nn
 from transformers import PreTrainedModel, Trainer, TrainingArguments
+from transformers.training_args import OptimizerNames
+
 from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
 
 os.environ["WANDB_PROJECT"] = "supervised-finetuning"
@@ -130,12 +132,15 @@ if __name__ == "__main__":
 
     train, evals, collate_fn = get_dataset(training_conf, tokenizer)
 
+    optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
+
     args = TrainingArguments(
         output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
         num_train_epochs=training_conf.num_train_epochs,
         warmup_steps=training_conf.warmup_steps,
         learning_rate=float(training_conf.learning_rate),
         deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
+        optim=optimizer,
         fp16=True,
         local_rank=training_conf.local_rank,
         gradient_checkpointing=training_conf.gradient_checkpointing,