Merge pull request #582 from mrcabbage972/main

Adding support for 8-bit training with bitsandbytes
This commit is contained in:
sanagnos
2023-01-11 10:14:00 +02:00
committed by GitHub
3 changed files with 15 additions and 2 deletions
@@ -22,7 +22,7 @@ defaults:
loss_fn: CrossEntropyLoss
eval_size:
log_dir: "base"
quantization:
quantization: false
galactica-125:
learning_rate: 5e-5
@@ -62,4 +62,4 @@ debug:
gradient_accumulation_steps: 1
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
quantization:
quantization: false
@@ -1,4 +1,5 @@
accelerate==0.15.0
bitsandbytes==0.36.0.post2
datasets==2.8.0
deepspeed==0.7.7
mpi4py==3.1.4
+12
View File
@@ -3,9 +3,11 @@ import os
from distutils.util import strtobool
from typing import Any, Dict, List, Optional, Tuple, Union
import bitsandbytes
import torch
from torch import nn
from transformers import PreTrainedModel, Trainer, TrainingArguments
from transformers.training_args import OptimizerNames
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
@@ -130,12 +132,22 @@ if __name__ == "__main__":
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
# Pick the optimizer handed to TrainingArguments(optim=...).
# With quantization enabled, use the bitsandbytes 8-bit AdamW. Otherwise name
# the Hugging Face AdamW explicitly instead of passing None: TrainingArguments
# coerces `optim` through OptimizerNames(...), which raises ValueError for None.
# NOTE(review): ADAMW_HF is assumed to match the TrainingArguments default for
# the transformers version this script targets - confirm against the pinned version.
optimizer = (
    OptimizerNames.ADAMW_BNB
    if training_conf.quantization
    else OptimizerNames.ADAMW_HF
)

if training_conf.quantization:
    # Register a per-module override so the optimizer state for every embedding
    # layer's "weight" parameter uses 32 bits rather than the 8-bit default.
    optim_manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
    for module in model.modules():
        if isinstance(module, torch.nn.Embedding):
            optim_manager.register_module_override(
                module, "weight", {"optim_bits": 32}
            )
args = TrainingArguments(
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
num_train_epochs=training_conf.num_train_epochs,
warmup_steps=training_conf.warmup_steps,
learning_rate=float(training_conf.learning_rate),
deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
optim=optimizer,
fp16=True,
local_rank=training_conf.local_rank,
gradient_checkpointing=training_conf.gradient_checkpointing,