mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Adding BNB 8-bit Adam
This commit is contained in:
@@ -22,7 +22,7 @@ defaults:
|
||||
loss_fn: CrossEntropyLoss
|
||||
eval_size:
|
||||
log_dir: "base"
|
||||
quantization:
|
||||
quantization: false
|
||||
|
||||
galactica-125:
|
||||
learning_rate: 5e-5
|
||||
@@ -62,4 +62,4 @@ debug:
|
||||
gradient_accumulation_steps: 1
|
||||
per_device_train_batch_size: 1
|
||||
per_device_eval_batch_size: 1
|
||||
quantization:
|
||||
quantization: false
|
||||
|
||||
@@ -7,3 +7,4 @@ PyYAML==6.0
|
||||
scikit_learn==1.2.0
|
||||
torch==1.13.1
|
||||
transformers==4.25.1
|
||||
bitsandbytes==0.36.0.post2
|
||||
@@ -6,6 +6,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel, Trainer, TrainingArguments
|
||||
from transformers.training_args import OptimizerNames
|
||||
|
||||
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
|
||||
@@ -130,12 +132,15 @@ if __name__ == "__main__":
|
||||
|
||||
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
|
||||
|
||||
# Use the bitsandbytes 8-bit AdamW when quantization is enabled; otherwise
# fall back explicitly to the Hugging Face AdamW. `optim=None` must not be
# passed on: TrainingArguments coerces the value through OptimizerNames(...),
# which has no member for None and raises ValueError.
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
|
||||
num_train_epochs=training_conf.num_train_epochs,
|
||||
warmup_steps=training_conf.warmup_steps,
|
||||
learning_rate=float(training_conf.learning_rate),
|
||||
deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
|
||||
optim=optimizer,
|
||||
fp16=True,
|
||||
local_rank=training_conf.local_rank,
|
||||
gradient_checkpointing=training_conf.gradient_checkpointing,
|
||||
|
||||
Reference in New Issue
Block a user