Adding BNB 8-bit Adam

This commit is contained in:
mrcabbage972
2023-01-09 22:07:06 -05:00
parent cc4c008933
commit 08bdadf222
3 changed files with 8 additions and 2 deletions
@@ -22,7 +22,7 @@ defaults:
loss_fn: CrossEntropyLoss
eval_size:
log_dir: "base"
quantization:
quantization: false
galactica-125:
learning_rate: 5e-5
@@ -62,4 +62,4 @@ debug:
gradient_accumulation_steps: 1
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
quantization:
quantization: false
@@ -7,3 +7,4 @@ PyYAML==6.0
scikit_learn==1.2.0
torch==1.13.1
transformers==4.25.1
bitsandbytes==0.36.0.post2
+5
View File
@@ -6,6 +6,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PreTrainedModel, Trainer, TrainingArguments
from transformers.training_args import OptimizerNames
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
@@ -130,12 +132,15 @@ if __name__ == "__main__":
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else None
args = TrainingArguments(
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
num_train_epochs=training_conf.num_train_epochs,
warmup_steps=training_conf.warmup_steps,
learning_rate=float(training_conf.learning_rate),
deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
optim=optimizer,
fp16=True,
local_rank=training_conf.local_rank,
gradient_checkpointing=training_conf.gradient_checkpointing,