Files
Open-Assistant/model/supervised_finetuning/trainer.py
T
2023-02-09 09:19:17 +01:00

235 lines
8.1 KiB
Python

import argparse
from distutils.util import strtobool
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import bitsandbytes
import torch
from efficiency_utils import fuse_gelu
from torch import nn
from transformers import PreTrainedModel, Trainer, TrainingArguments
from transformers.training_args import OptimizerNames
from utils import PerDatasetSampler, get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
def compute_metrics(eval_pred, preprocess_fns, metrics):
out = {}
for metric, preprocess_fn in zip(metrics, preprocess_fns):
preds, labels = preprocess_fn(eval_pred)
out = dict(**out, **metric.compute(predictions=preds, references=labels))
return out
def preprocess_logits_for_metrics(logits, labels):
pred_ids = torch.argmax(logits, dim=-1)
return pred_ids
class SFTTrainer(Trainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
sampler: torch.utils.data.sampler.Sampler = None,
loss_function: str = "CrossEntropyLoss",
poly_eps: float = 1.0,
**kwargs,
):
super().__init__(model, args, **kwargs)
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
self.loss_fct = get_loss(loss_function, poly_eps)
self.sampler = sampler
def compute_loss(self, model, inputs, return_outputs=False):
labels_mask = inputs.pop("label_masks")
targets = inputs.pop("targets")
outputs = model(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
)
loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)
return (loss, outputs) if return_outputs else loss
def _compute_loss(self, model, inputs):
inputs = self._prepare_inputs(inputs)
labels_mask = inputs.pop("label_masks")
targets = inputs.pop("targets")
outputs = model(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
)
logits = outputs.get("logits")
loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)
return loss, logits, targets, labels_mask
def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
with torch.no_grad():
loss, logits, labels, labels_mask = self._compute_loss(model, inputs)
labels[~labels_mask.bool()] = -100 # padding_index
loss = loss.mean().detach()
if self.args.prediction_loss_only:
return (loss, None, None)
return (loss, logits, labels)
def get_train_dataloader(self):
"""Inject custom data sampling behaviour into training loop"""
if self.sampler is None:
torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.args.per_device_train_batch_size,
shuffle=True,
collate_fn=self.data_collator,
)
else:
dataloader = torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.args.per_device_train_batch_size,
sampler=self.sampler,
collate_fn=self.data_collator,
)
if torch.cuda.device_count() <= 1:
return dataloader
else:
# Not strictly necessary to use accelerate, currently just
# ensures batches are padded to be divisible by # devices
from accelerate import Accelerator
accelerator = Accelerator()
return accelerator.prepare(dataloader)
def _strtobool(x):
return bool(strtobool(x))
def argument_parsing(notebook=False, notebook_args=None):
parser = argparse.ArgumentParser()
parser.add_argument("--configs", nargs="+", required=True)
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--deepspeed", action="store_true")
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
parser.add_argument("--wandb-entity", type=str, default="open-assistant")
parser.set_defaults(deepspeed=False)
if notebook:
args, remaining = parser.parse_known_args(notebook_args)
else:
args, remaining = parser.parse_known_args()
print(args)
# Config from YAML
conf = {}
configs = read_yamls("./configs")
for name in args.configs:
if "," in name:
for n in name.split(","):
conf.update(configs[n])
else:
conf.update(configs[name])
conf["wandb_entity"] = args.wandb_entity
conf["local_rank"] = args.local_rank
conf["deepspeed"] = args.deepspeed
# Override config from command-line
parser = argparse.ArgumentParser()
for key, value in conf.items():
type_ = type(value) if value is not None else str
if type_ == bool:
type_ = _strtobool
parser.add_argument(f"--{key}", type=type_, default=value)
return parser.parse_args(remaining)
if __name__ == "__main__":
training_conf = argument_parsing()
tokenizer = get_tokenizer(training_conf)
model = get_model(training_conf, tokenizer)
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
sampler = PerDatasetSampler.build_sampler_from_config(training_conf, train.datasets)
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
if training_conf.quantization:
for module in model.modules():
if isinstance(module, torch.nn.Embedding):
bitsandbytes.optim.GlobalOptimManager.get_instance().register_module_override(
module, "weight", {"optim_bits": 32}
)
if training_conf.fuse_gelu:
model = fuse_gelu(model)
args = TrainingArguments(
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
num_train_epochs=training_conf.num_train_epochs,
warmup_steps=training_conf.warmup_steps,
learning_rate=float(training_conf.learning_rate),
deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
optim=optimizer,
fp16=True,
local_rank=training_conf.local_rank,
gradient_checkpointing=training_conf.gradient_checkpointing,
gradient_accumulation_steps=training_conf.gradient_accumulation_steps,
per_device_train_batch_size=training_conf.per_device_train_batch_size,
per_device_eval_batch_size=training_conf.per_device_eval_batch_size,
weight_decay=training_conf.weight_decay,
max_grad_norm=training_conf.max_grad_norm,
logging_steps=training_conf.logging_steps,
save_total_limit=training_conf.save_total_limit,
evaluation_strategy="steps",
eval_steps=training_conf.eval_steps,
save_steps=training_conf.save_steps,
eval_accumulation_steps=training_conf.eval_accumulation_steps,
report_to="wandb",
)
assert len(evals) > 0
if not training_conf.deepspeed or training_conf.local_rank == 0:
import wandb
wandb.init(
project="supervised-finetuning",
entity=training_conf.wandb_entity,
name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
)
trainer = SFTTrainer(
model=model,
args=args,
sampler=sampler,
loss_function=training_conf.loss_fn,
poly_eps=training_conf.poly_eps,
train_dataset=train,
eval_dataset=evals,
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns),
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
trainer.train()