# Files
# Open-Assistant/model/supervised_finetuning/trainer.py
# T
# Sotirios Anagnostidis c8f47eef9f precommits
# 2023-01-11 22:58:17 +01:00
#
# 203 lines
# 6.8 KiB
# Python

import argparse
from distutils.util import strtobool
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import bitsandbytes
import torch
from torch import nn
from transformers import PreTrainedModel, Trainer, TrainingArguments
from transformers.training_args import OptimizerNames
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
def compute_metrics(eval_pred, preprocess_fns, metrics):
    """Run every metric on ``eval_pred`` and merge the results into one dict.

    ``metrics`` and ``preprocess_fns`` are paired positionally. Each
    preprocess function is called with ``eval_pred`` and must return a
    ``(predictions, references)`` pair, which is forwarded to the matching
    metric's ``compute`` method.

    Returns:
        A dict combining the dicts returned by every ``metric.compute`` call.
    """
    merged = {}
    for metric, preprocess_fn in zip(metrics, preprocess_fns):
        predictions, references = preprocess_fn(eval_pred)
        scores = metric.compute(predictions=predictions, references=references)
        # Keyword-expansion merge: a key reported by two metrics raises
        # TypeError instead of being silently overwritten.
        merged = dict(**merged, **scores)
    return merged
def preprocess_logits_for_metrics(logits, labels):
    """Reduce ``logits`` to the index of the largest value along the last axis.

    ``labels`` is accepted for signature compatibility but is not used.
    """
    return logits.argmax(dim=-1)
class SFTTrainer(Trainer):
    """``Trainer`` subclass whose loss comes from a configurable loss function.

    Batches must contain ``input_ids``, ``targets`` and ``label_masks``;
    ``attention_mask`` is forwarded to the model when present. The loss is
    computed by the callable returned from ``get_loss`` instead of by the
    model itself.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        loss_function: str = "CrossEntropyLoss",
        poly_eps: float = 1.0,
        **kwargs,
    ):
        """Create the trainer and build the loss function.

        Args:
            model: Model to train; passed through to ``Trainer``.
            args: Training arguments; passed through to ``Trainer``.
            loss_function: Loss name handed to ``get_loss``.
            poly_eps: Second argument handed to ``get_loss``.
            **kwargs: Any further ``Trainer`` keyword arguments.
        """
        super().__init__(model, args, **kwargs)
        # By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
        self.loss_fct = get_loss(loss_function, poly_eps)

    def _forward_with_loss(self, model, inputs):
        """Run the model on ``inputs`` and compute the masked loss.

        Removes ``label_masks`` and ``targets`` from ``inputs`` (the dict is
        mutated) and returns ``(loss, outputs, targets, labels_mask)``.
        """
        labels_mask = inputs.pop("label_masks")
        targets = inputs.pop("targets")

        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
        )

        loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)
        return loss, outputs, targets, labels_mask

    def compute_loss(self, model, inputs, return_outputs=False):
        """Return the training loss, plus the model outputs if ``return_outputs``."""
        loss, outputs, _, _ = self._forward_with_loss(model, inputs)
        return (loss, outputs) if return_outputs else loss

    def _compute_loss(self, model, inputs):
        """Prepare ``inputs`` and return ``(loss, logits, targets, labels_mask)``."""
        inputs = self._prepare_inputs(inputs)
        loss, outputs, targets, labels_mask = self._forward_with_loss(model, inputs)
        return loss, outputs.get("logits"), targets, labels_mask

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Evaluate one batch without gradients.

        Args:
            model: Model to evaluate.
            inputs: Batch dict; see the class docstring for required keys.
            prediction_loss_only: When true, only the loss is returned.
            ignore_keys: Accepted for signature compatibility; not used.

        Returns:
            ``(loss, logits, labels)``, or ``(loss, None, None)`` when only the
            loss was requested. Positions where the label mask is false are
            set to ``-100`` in ``labels`` (the targets tensor is modified in
            place).
        """
        with torch.no_grad():
            loss, logits, labels, labels_mask = self._compute_loss(model, inputs)
            labels[~labels_mask.bool()] = -100  # padding_index

        loss = loss.mean().detach()

        # Honour the per-call flag as well as the global training argument, so
        # that logits are not returned when the caller asked for the loss only.
        if prediction_loss_only or self.args.prediction_loss_only:
            return (loss, None, None)

        return (loss, logits, labels)
def _strtobool(x):
return bool(strtobool(x))
def argument_parsing(notebook=False, notebook_args=None):
    """Build the training configuration from YAML presets plus CLI overrides.

    Parsing happens in two passes. The first pass reads the fixed options
    (``--configs``, ``--local_rank``, ``--deepspeed`` / ``--no-deepspeed``,
    ``--wandb-entity``) and leaves everything else unparsed. The presets named
    in ``--configs`` are then merged in order, and a second parser exposes
    every merged key as ``--<key>`` so it can be overridden on the command
    line.

    Args:
        notebook: When true, parse ``notebook_args`` instead of ``sys.argv``.
        notebook_args: Argument list used only when ``notebook`` is true.

    Returns:
        The ``argparse.Namespace`` produced by the second (override) parser.
    """
    base_parser = argparse.ArgumentParser()
    base_parser.add_argument("--configs", nargs="+", required=True)
    base_parser.add_argument("--local_rank", type=int, default=-1)
    base_parser.add_argument("--deepspeed", action="store_true")
    base_parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
    base_parser.add_argument("--wandb-entity", type=str, default="open-assistant")
    base_parser.set_defaults(deepspeed=False)

    # Passing None makes argparse fall back to sys.argv, as in the non-notebook case.
    argv = notebook_args if notebook else None
    args, remaining = base_parser.parse_known_args(argv)

    # NOTE(review): indentation was lost in SOURCE; this print is assumed to be unconditional.
    print(args)

    # Config from YAML
    conf = {}
    available = read_yamls("./configs")
    for entry in args.configs:
        # An entry names one preset, or several separated by commas.
        for preset in entry.split(","):
            conf.update(available[preset])

    conf["wandb_entity"] = args.wandb_entity
    conf["local_rank"] = args.local_rank
    conf["deepspeed"] = args.deepspeed

    # Override config from command-line
    override_parser = argparse.ArgumentParser()
    for key, default in conf.items():
        if default is None:
            converter = str
        elif type(default) is bool:
            # bool("False") is True, so booleans need a real string parser.
            converter = _strtobool
        else:
            converter = type(default)
        override_parser.add_argument(f"--{key}", type=converter, default=default)

    return override_parser.parse_args(remaining)
if __name__ == "__main__":
    # Script entry point: build tokenizer, model and datasets from the parsed
    # configuration, then fine-tune with SFTTrainer and log to Weights & Biases.
    training_conf = argument_parsing()

    tokenizer = get_tokenizer(training_conf)
    model = get_model(training_conf, tokenizer)

    train, evals, collate_fn = get_dataset(training_conf, tokenizer)
    metrics, preprocess_fns = get_metrics(training_conf, tokenizer)

    # Quantized runs use the bitsandbytes AdamW optimizer, others the HF AdamW.
    optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF

    if training_conf.quantization:
        # Register a 32-bit optimizer-state override for every embedding weight.
        for module in model.modules():
            if isinstance(module, torch.nn.Embedding):
                bitsandbytes.optim.GlobalOptimManager.get_instance().register_module_override(
                    module, "weight", {"optim_bits": 32}
                )

    # One name shared by the output directory and the wandb run.
    run_name = f"{training_conf.model_name}-{training_conf.log_dir}-finetuned"

    args = TrainingArguments(
        output_dir=run_name,
        num_train_epochs=training_conf.num_train_epochs,
        warmup_steps=training_conf.warmup_steps,
        learning_rate=float(training_conf.learning_rate),
        deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
        optim=optimizer,
        fp16=True,
        local_rank=training_conf.local_rank,
        gradient_checkpointing=training_conf.gradient_checkpointing,
        gradient_accumulation_steps=training_conf.gradient_accumulation_steps,
        per_device_train_batch_size=training_conf.per_device_train_batch_size,
        per_device_eval_batch_size=training_conf.per_device_eval_batch_size,
        weight_decay=training_conf.weight_decay,
        max_grad_norm=training_conf.max_grad_norm,
        logging_steps=training_conf.logging_steps,
        save_total_limit=training_conf.save_total_limit,
        evaluation_strategy="steps",
        eval_steps=training_conf.eval_steps,
        save_steps=training_conf.save_steps,
        eval_accumulation_steps=training_conf.eval_accumulation_steps,
        report_to="wandb",
    )

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(evals) == 0:
        raise ValueError("`evals` must not be empty: evaluation runs every `eval_steps` steps")

    # Without deepspeed every process initialises wandb; with it, only local rank 0.
    if not training_conf.deepspeed or training_conf.local_rank == 0:
        import wandb

        wandb.init(
            project="supervised-finetuning",
            entity=training_conf.wandb_entity,
            name=run_name,
        )

    trainer = SFTTrainer(
        model,
        args,
        loss_function=training_conf.loss_fn,
        poly_eps=training_conf.poly_eps,
        train_dataset=train,
        eval_dataset=evals,
        data_collator=collate_fn,
        tokenizer=tokenizer,
        compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    trainer.train()