diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 97e37121..616aa828 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -22,7 +22,7 @@ defaults: loss_fn: CrossEntropyLoss eval_size: log_dir: "base" - quantization: + quantization: false galactica-125: learning_rate: 5e-5 @@ -62,4 +62,4 @@ debug: gradient_accumulation_steps: 1 per_device_train_batch_size: 1 per_device_eval_batch_size: 1 - quantization: + quantization: false diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index 7d78f36c..c47a1218 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -1,4 +1,5 @@ accelerate==0.15.0 +bitsandbytes==0.36.0.post2 datasets==2.8.0 deepspeed==0.7.7 mpi4py==3.1.4 diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index cb55131d..450854f1 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -3,9 +3,11 @@ import os from distutils.util import strtobool from typing import Any, Dict, List, Optional, Tuple, Union +import bitsandbytes import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments +from transformers.training_args import OptimizerNames from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls os.environ["WANDB_PROJECT"] = "supervised-finetuning" @@ -130,12 +132,22 @@ if __name__ == "__main__": train, evals, collate_fn = get_dataset(training_conf, tokenizer) + optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else None + + if training_conf.quantization: + for module in model.modules(): + if isinstance(module, torch.nn.Embedding): + bitsandbytes.optim.GlobalOptimManager.get_instance().register_module_override( + module, "weight", {"optim_bits": 32} + ) + args = TrainingArguments( output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", num_train_epochs=training_conf.num_train_epochs, warmup_steps=training_conf.warmup_steps, learning_rate=float(training_conf.learning_rate), deepspeed="configs/zero_config.json" if training_conf.deepspeed else None, + optim=optimizer, fp16=True, local_rank=training_conf.local_rank, gradient_checkpointing=training_conf.gradient_checkpointing,