mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
Merge pull request #582 from mrcabbage972/main
Adding support for 8-bit training with bitsandbytes
This commit is contained in:
@@ -22,7 +22,7 @@ defaults:
|
||||
loss_fn: CrossEntropyLoss
|
||||
eval_size:
|
||||
log_dir: "base"
|
||||
quantization:
|
||||
quantization: false
|
||||
|
||||
galactica-125:
|
||||
learning_rate: 5e-5
|
||||
@@ -62,4 +62,4 @@ debug:
|
||||
gradient_accumulation_steps: 1
|
||||
per_device_train_batch_size: 1
|
||||
per_device_eval_batch_size: 1
|
||||
quantization:
|
||||
quantization: false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
accelerate==0.15.0
|
||||
bitsandbytes==0.36.0.post2
|
||||
datasets==2.8.0
|
||||
deepspeed==0.7.7
|
||||
mpi4py==3.1.4
|
||||
|
||||
@@ -3,9 +3,11 @@ import os
|
||||
from distutils.util import strtobool
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import bitsandbytes
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel, Trainer, TrainingArguments
|
||||
from transformers.training_args import OptimizerNames
|
||||
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
|
||||
@@ -130,12 +132,22 @@ if __name__ == "__main__":
|
||||
|
||||
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
|
||||
|
||||
# Use the 8-bit bitsandbytes AdamW when quantization is enabled; otherwise fall
# back explicitly to the Hugging Face AdamW. `None` must not be passed as `optim`:
# TrainingArguments converts the value via OptimizerNames(...), which rejects None.
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
|
||||
|
||||
if training_conf.quantization:
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.Embedding):
|
||||
bitsandbytes.optim.GlobalOptimManager.get_instance().register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
|
||||
num_train_epochs=training_conf.num_train_epochs,
|
||||
warmup_steps=training_conf.warmup_steps,
|
||||
learning_rate=float(training_conf.learning_rate),
|
||||
deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
|
||||
optim=optimizer,
|
||||
fp16=True,
|
||||
local_rank=training_conf.local_rank,
|
||||
gradient_checkpointing=training_conf.gradient_checkpointing,
|
||||
|
||||
Reference in New Issue
Block a user