Files
alignment-handbook/src/alignment/model_utils.py
T
Lewis Tunstall 967eab4cfb Add skeleton
2023-11-08 13:21:57 +00:00

80 lines
2.7 KiB
Python

from typing import Dict, Union
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
from accelerate import Accelerator
from peft import LoraConfig, PeftConfig
from .configs import DataArguments, ModelArguments
from .data import DEFAULT_CHAT_TEMPLATE
def get_current_device() -> Union[int, str]:
    """Get the current device.

    Returns:
        The local process index reported by `Accelerator` when CUDA is
        available (so each process in a multi-GPU run targets its own GPU),
        otherwise the string `"cpu"`.
    """
    if torch.cuda.is_available():
        return Accelerator().local_process_index
    return "cpu"
def get_kbit_device_map() -> Dict[str, int] | None:
"""Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`"""
return {"": get_current_device()} if torch.cuda.is_available() else None
def get_quantization_config(model_args) -> BitsAndBytesConfig | None:
    """Translate the quantization flags on `model_args` into a `BitsAndBytesConfig`.

    Args:
        model_args: Object exposing `load_in_4bit` and `load_in_8bit`; when
            4-bit loading is requested, `bnb_4bit_quant_type` and
            `use_bnb_nested_quant` are read as well.

    Returns:
        A 4-bit config if `load_in_4bit` is set, else an 8-bit config if
        `load_in_8bit` is set, else `None` (no quantization). 4-bit takes
        precedence when both flags are set.
    """
    if model_args.load_in_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            # The compute dtype is pinned to float16; the original author's note says this
            # is meant to match the `torch_dtype` used for PEFT models.
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )

    if model_args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)

    return None
def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
    """Load the tokenizer for the configured model and apply data-level overrides.

    Args:
        model_args: Supplies `model_name_or_path` and `model_revision` for loading.
        data_args: Supplies the optional `truncation_side` and `chat_template` overrides.

    Returns:
        The loaded tokenizer, with a pad token id, a bounded `model_max_length`
        and a chat template guaranteed to be set as described below.
    """
    tok = AutoTokenizer.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)

    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id

    side = data_args.truncation_side
    if side is not None:
        tok.truncation_side = side

    # Set a reasonable default for models that do not declare a real max length.
    if tok.model_max_length > 100_000:
        tok.model_max_length = 2048

    # An explicit template from `data_args` always wins; the project default is only
    # used when the tokenizer has no template of its own.
    template = data_args.chat_template
    if template is not None:
        tok.chat_template = template
    elif tok.chat_template is None:
        tok.chat_template = DEFAULT_CHAT_TEMPLATE

    return tok
def get_peft_config(model_args: ModelArguments) -> Union[PeftConfig, None]:
    """Build the LoRA configuration described by `model_args`.

    Args:
        model_args: Supplies `use_peft` plus the LoRA hyperparameters
            (`lora_r`, `lora_alpha`, `lora_dropout`, `lora_target_modules`,
            `lora_modules_to_save`).

    Returns:
        A `LoraConfig` for causal language modelling with no bias training,
        or `None` when PEFT is not enabled.
    """
    # Truthiness rather than `is False`, so an unset flag (e.g. `None`) also
    # disables PEFT instead of silently building a LoRA config.
    if not model_args.use_peft:
        return None

    return LoraConfig(
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=model_args.lora_target_modules,
        modules_to_save=model_args.lora_modules_to_save,
    )