From dfaa00dccc1038e1ad8aa00d5868d80d1598fa4a Mon Sep 17 00:00:00 2001
From: Sotirios Anagnostidis
Date: Thu, 5 Jan 2023 00:33:16 +0100
Subject: [PATCH] gptj 8bit

---
Review notes (text in this section is ignored by `git am`):
- models/__init__.py: the fallback branch read an undefined `conf`; it now
  uses the function's own `model_name` / `cache_dir` arguments.
- models/gptj.py: `get_model` called `add_adapters(gpt)` with an undefined
  name; it now passes `model`. Adapter checks use `is not None`, and the
  per-layer debug print in `convert_to_int8` is gone.
- trainer.py: the 8-bit optimizer now takes `self.model.parameters()` instead
  of a bare `model`; the leftover "lr sheduler" print is removed.
- New files end with a newline; hunk sizes and the diffstat reflect the edits.
- Whitespace of context lines was reconstructed from a copy with collapsed
  whitespace; confirm it against the target tree before applying.

 .../supervised_finetuning/configs/config.yaml |   5 +
 .../supervised_finetuning/models/__init__.py  |  14 +
 model/supervised_finetuning/models/gptj.py    | 200 ++++++++++++++++++
 model/supervised_finetuning/trainer.py        |  19 +-
 model/supervised_finetuning/utils.py          |   5 +-
 5 files changed, 238 insertions(+), 5 deletions(-)
 create mode 100644 model/supervised_finetuning/models/__init__.py
 create mode 100644 model/supervised_finetuning/models/gptj.py

diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml
index 29086395..ded4363e 100644
--- a/model/supervised_finetuning/configs/config.yaml
+++ b/model/supervised_finetuning/configs/config.yaml
@@ -21,6 +21,7 @@ defaults:
   loss_fn: CrossEntropyLoss
   eval_size:
   log_dir: "base"
+  quantization:
 
 galactica-125:
   learning_rate: 5e-5
@@ -46,3 +47,7 @@ gpt-jt:
 debug:
   eval_steps: 20
   eval_size: 100
+  gradient_accumulation_steps: 2
+  per_device_train_batch_size: 1
+  per_device_eval_batch_size: 1
+  quantization: 8bit
diff --git a/model/supervised_finetuning/models/__init__.py b/model/supervised_finetuning/models/__init__.py
new file mode 100644
index 00000000..673f4085
--- /dev/null
+++ b/model/supervised_finetuning/models/__init__.py
@@ -0,0 +1,14 @@
+from transformers import AutoModelForCausalLM
+
+from .gptj import get_model as get_gptj_model
+
+
+def get_specific_model(model_name, cache_dir, quantization):
+    """Load a causal LM, using the GPT-J loader for GPT-J style checkpoints.
+
+    `quantization` is only passed to the GPT-J loader; every other model is
+    loaded through `AutoModelForCausalLM` without it.
+    """
+    if "gpt-j" in model_name.lower():
+        return get_gptj_model(model_name, cache_dir, quantization)
+    return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
diff --git a/model/supervised_finetuning/models/gptj.py b/model/supervised_finetuning/models/gptj.py
new file mode 100644
index 00000000..11234abd
--- /dev/null
+++ b/model/supervised_finetuning/models/gptj.py
@@ -0,0 +1,200 @@
+# Taken from https://github.com/sleekmike/Finetune_GPT-J_6B_8-bit/blob/master/gpt-j-6b-8-bit.py
+
+import transformers
+from transformers import AutoModelForCausalLM
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.cuda.amp import custom_fwd, custom_bwd
+from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+
+
+class FrozenBNBLinear(nn.Module):
+    """Linear layer whose weight is stored as frozen, blockwise-quantized buffers."""
+
+    def __init__(self, weight, absmax, code, bias=None):
+        assert isinstance(bias, nn.Parameter) or bias is None
+        super().__init__()
+        self.out_features, self.in_features = weight.shape
+        self.register_buffer("weight", weight.requires_grad_(False))
+        self.register_buffer("absmax", absmax.requires_grad_(False))
+        self.register_buffer("code", code.requires_grad_(False))
+        self.adapter = None
+        self.bias = bias
+
+    def forward(self, input):
+        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
+        # Compare against None explicitly instead of relying on the truthiness of an nn.Module.
+        if self.adapter is not None:
+            output += self.adapter(input)
+        return output
+
+    @classmethod
+    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
+        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
+        return cls(weights_int8, *state, linear.bias)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
+
+
+class DequantizeAndLinear(torch.autograd.Function):
+    """Autograd function: dequantize the weight on the fly, then apply F.linear."""
+
+    @staticmethod
+    @custom_fwd
+    def forward(
+        ctx,
+        input: torch.Tensor,
+        weights_quantized: torch.ByteTensor,
+        absmax: torch.FloatTensor,
+        code: torch.FloatTensor,
+        bias: torch.FloatTensor,
+    ):
+        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+        ctx.save_for_backward(input, weights_quantized, absmax, code)
+        ctx._has_bias = bias is not None
+        return F.linear(input, weights_deq, bias)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output: torch.Tensor):
+        # The quantized weight and its quantization state are frozen: no gradients for them.
+        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
+        input, weights_quantized, absmax, code = ctx.saved_tensors
+        # grad_output: [*batch, out_features]
+        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+        grad_input = grad_output @ weights_deq
+        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
+        return grad_input, None, None, None, grad_bias
+
+
+class FrozenBNBEmbedding(nn.Module):
+    """Embedding layer whose table is stored as frozen, blockwise-quantized buffers."""
+
+    def __init__(self, weight, absmax, code):
+        super().__init__()
+        self.num_embeddings, self.embedding_dim = weight.shape
+        self.register_buffer("weight", weight.requires_grad_(False))
+        self.register_buffer("absmax", absmax.requires_grad_(False))
+        self.register_buffer("code", code.requires_grad_(False))
+        self.adapter = None
+
+    def forward(self, input, **kwargs):
+        with torch.no_grad():
+            # note: both quantized weights and input indices are *not* differentiable
+            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
+            output = F.embedding(input, weight_deq, **kwargs)
+        if self.adapter is not None:
+            output += self.adapter(input)
+        return output
+
+    @classmethod
+    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
+        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
+        return cls(weights_int8, *state)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
+
+
+def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2**20):
+    """Quantize `matrix` blockwise, one `chunk_size`-element slice at a time."""
+    assert chunk_size % 4096 == 0
+    code = None
+    chunks = []
+    absmaxes = []
+    flat_tensor = matrix.view(-1)
+    for i in range((matrix.numel() - 1) // chunk_size + 1):
+        input_chunk = flat_tensor[i * chunk_size : (i + 1) * chunk_size].clone()
+        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
+        chunks.append(quantized_chunk)
+        absmaxes.append(absmax_chunk)
+
+    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
+    absmax = torch.cat(absmaxes)
+    return matrix_i8, (absmax, code)
+
+
+def convert_to_int8(model):
+    """Replace linear and embedding submodules with zero-initialized 8-bit placeholders."""
+    for module in list(model.modules()):
+        for name, child in module.named_children():
+            if isinstance(child, nn.Linear):
+                setattr(
+                    module,
+                    name,
+                    FrozenBNBLinear(
+                        weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
+                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                        code=torch.zeros(256),
+                        bias=child.bias,
+                    ),
+                )
+            elif isinstance(child, nn.Embedding):
+                setattr(
+                    module,
+                    name,
+                    FrozenBNBEmbedding(
+                        weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
+                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                        code=torch.zeros(256),
+                    ),
+                )
+
+
+class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
+    def __init__(self, config):
+        super().__init__(config)
+
+        convert_to_int8(self.attn)
+        convert_to_int8(self.mlp)
+
+
+class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
+    def __init__(self, config):
+        super().__init__(config)
+        convert_to_int8(self)
+
+
+class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+        convert_to_int8(self)
+
+
+def add_adapters(model, adapter_dim=16):
+    """Attach trainable adapters (bottleneck width `adapter_dim`) to every frozen 8-bit layer."""
+    assert adapter_dim > 0
+
+    for module in model.modules():
+        if isinstance(module, FrozenBNBLinear):
+            module.adapter = nn.Sequential(
+                nn.Linear(module.in_features, adapter_dim, bias=False),
+                nn.Linear(adapter_dim, module.out_features, bias=False),
+            )
+            nn.init.zeros_(module.adapter[1].weight)
+        elif isinstance(module, FrozenBNBEmbedding):
+            module.adapter = nn.Sequential(
+                nn.Embedding(module.num_embeddings, adapter_dim),
+                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
+            )
+            nn.init.zeros_(module.adapter[1].weight)
+
+
+def get_model(model_name, cache_dir, quantization):
+    """Load `model_name`, optionally as an 8-bit model with trainable adapters.
+
+    Raises ValueError for an unsupported `quantization` value.
+    """
+    if quantization is None:
+        model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
+    elif quantization == "8bit":
+        # Swap in the 8-bit block class before the model is instantiated.
+        transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock
+        model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
+        add_adapters(model)
+    else:
+        raise ValueError(f"Unknown quantization {quantization}")
+
+    return model
diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py
index dc7b5934..0db0f443 100644
--- a/model/supervised_finetuning/trainer.py
+++ b/model/supervised_finetuning/trainer.py
@@ -17,6 +17,7 @@ from transformers import (
     TrainingArguments,
     get_cosine_schedule_with_warmup,
 )
+import bitsandbytes as bnb
 from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
 
 os.environ["WANDB_PROJECT"] = "supervised-finetuning"
@@ -25,6 +26,7 @@ os.environ["WANDB_PROJECT"] = "supervised-finetuning"
 @dataclass
 class CustomTrainingArguments(TrainingArguments):
     loss_function: str = "CrossEntropyLoss"
+    quantization: str = None
 
 
 def compute_metrics(eval_pred):
@@ -71,8 +73,18 @@ class SFTTrainer(Trainer):
         # By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
         self.loss_fct = get_loss(args.loss_function)
 
-    def fetch_scheduler(self):
-        return get_cosine_schedule_with_warmup(
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
+        """Build the optimizer (8-bit Adam when quantized) and the cosine LR schedule."""
+        if self.args.quantization == "8bit":
+            # Use the trainer's own model; a bare `model` only resolves via a module-level global.
+            # NOTE(review): lr is hard-coded here and ignores args.learning_rate - confirm intended.
+            self.optimizer = bnb.optim.Adam8bit(self.model.parameters(), lr=0.001, betas=(0.9, 0.995))
+        else:
+            self.optimizer = torch.optim.AdamW(
+                self.model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay
+            )
+
+        self.lr_scheduler = get_cosine_schedule_with_warmup(
             self.optimizer,
             num_warmup_steps=self.args.warmup_steps,
             num_training_steps=self.num_train_steps,
@@ -165,6 +177,7 @@ if __name__ == "__main__":
     model = get_model(training_conf, tokenizer)
 
     train, evals, collate_fn = get_dataset(training_conf, tokenizer)
+    assert len(evals) > 0
 
     args = CustomTrainingArguments(
         output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
@@ -186,9 +199,9 @@ if __name__ == "__main__":
         save_steps=training_conf.save_steps,
         eval_accumulation_steps=training_conf.eval_accumulation_steps,
         report_to="wandb",
+        quantization=training_conf.quantization,
     )
 
-    assert len(evals) > 0
     trainer = SFTTrainer(
         model,
         args,
diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py
index a31f74d3..318aeb5e 100644
--- a/model/supervised_finetuning/utils.py
+++ b/model/supervised_finetuning/utils.py
@@ -6,7 +6,8 @@ from custom_datasets.dialogue_collator import DialogueDataCollator
 from losses import CrossEntropyLoss
 from sklearn.model_selection import train_test_split
 from torch.utils.data import ConcatDataset, Subset
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer
+from models import get_specific_model
 
 SUPPORTED_MODELS = ["galactica", "GPT-JT"]  # deprecated ..
 
@@ -36,7 +37,7 @@ def get_model(conf, tokenizer):
             "To include more make sure the masking is dne correctly... (decoder only supported for now)"
         )
 
-    model = AutoModelForCausalLM.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
+    model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization)
 
     if len(tokenizer) != model.get_input_embeddings().num_embeddings:
         assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."