Files
Open-Assistant/model/supervised_finetuning/trainer.py
T
Sotirios Anagnostidis d3952354e2 pre commits
2023-01-06 22:09:24 +01:00

169 lines
5.5 KiB
Python

import argparse
import os
from distutils.util import strtobool
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PreTrainedModel, Trainer, TrainingArguments
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
# Send Weights & Biases runs from this script to a fixed project name. This runs
# at import time and unconditionally overwrites any WANDB_PROJECT already set.
os.environ.update(WANDB_PROJECT="supervised-finetuning")
def compute_metrics(eval_pred, ignore_index=-100):
    """Compute token-level accuracy over the supervised (non-ignored) positions.

    ``eval_pred.predictions`` holds predicted token ids (the argmax returned by
    ``preprocess_logits_for_metrics``). ``eval_pred.label_ids`` holds the target
    ids; ``SFTTrainer.prediction_step`` sets positions excluded from the loss
    to -100.

    Args:
        eval_pred: object exposing ``predictions`` and ``label_ids`` arrays of
            the same shape.
        ignore_index: label value marking positions that must not be scored.

    Returns:
        ``{"accuracy": float}``. The value is ``nan`` when no position is
        supervised.
    """
    pred_ids = eval_pred.predictions
    labels = eval_pred.label_ids
    # Mask on the ignore index rather than ``labels > 0``: token id 0 is a
    # legitimate target in many vocabularies and must be counted.
    valid = labels != ignore_index
    if not valid.any():
        # Explicit NaN instead of numpy's empty-mean RuntimeWarning.
        return {"accuracy": float("nan")}
    return {"accuracy": float((pred_ids[valid] == labels[valid]).mean())}
def preprocess_logits_for_metrics(logits, labels):
    """Collapse raw logits to predicted token ids before metrics are accumulated.

    Only the argmax over the last (vocabulary) axis is returned, so the
    vocabulary-sized logits need not be kept. ``labels`` is part of the
    required hook signature and is not used.
    """
    return logits.argmax(dim=-1)
class SFTTrainer(Trainer):
    """``transformers.Trainer`` subclass that trains on explicit, masked targets.

    Each batch from the data collator must contain ``input_ids``, optionally
    ``attention_mask``, plus ``targets`` and ``label_masks``. The last two are
    removed before the forward pass. They are then passed, together with the
    model's ``logits``, to the loss returned by ``get_loss(loss_function)``.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        loss_function: str = "CrossEntropyLoss",
        **kwargs,
    ):
        """Create the trainer; ``loss_function`` is the name passed to ``get_loss``.

        All other keyword arguments are forwarded to ``Trainer.__init__``.
        """
        super().__init__(model, args, **kwargs)
        # By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
        self.loss_fct = get_loss(loss_function)

    def _forward_with_loss(self, model, inputs):
        """Run ``model`` on an already-prepared batch and apply the masked loss.

        Pops ``label_masks`` and ``targets`` from ``inputs``, so the dict is
        mutated.

        Returns:
            ``(loss, outputs, targets, labels_mask)``.
        """
        labels_mask = inputs.pop("label_masks")
        targets = inputs.pop("targets")
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None))
        loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)
        return loss, outputs, targets, labels_mask

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Training-loss hook called by ``Trainer``.

        ``num_items_in_batch`` is accepted and ignored. Newer ``transformers``
        releases pass it to ``compute_loss``, and an override without it raises
        ``TypeError`` there.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        loss, outputs, _, _ = self._forward_with_loss(model, inputs)
        return (loss, outputs) if return_outputs else loss

    def _compute_loss(self, model, inputs):
        """Prepare a raw batch and return ``(loss, logits, targets, labels_mask)``."""
        inputs = self._prepare_inputs(inputs)
        loss, outputs, targets, labels_mask = self._forward_with_loss(model, inputs)
        return loss, outputs.get("logits"), targets, labels_mask

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Evaluation step: return ``(loss, logits, labels)`` computed without gradients.

        Target positions outside ``label_masks`` are set to -100 so that metric
        code can skip them. When ``prediction_loss_only`` is true, only the loss
        is returned. ``ignore_keys`` is part of the ``Trainer`` signature and is
        unused here.
        """
        with torch.no_grad():
            loss, logits, labels, labels_mask = self._compute_loss(model, inputs)
            labels[~labels_mask.bool()] = -100  # padding_index
        loss = loss.mean().detach()
        # Honour the per-call flag. The evaluation loop resolves it from both its
        # caller and self.args; reading self.args alone loses e.g. the "loss only
        # because no compute_metrics" case.
        if prediction_loss_only:
            return (loss, None, None)
        return (loss, logits, labels)
def _strtobool(x):
    """Parse a command-line truth value into a ``bool``.

    Case-insensitively accepts ``y, yes, t, true, on, 1`` as True and
    ``n, no, f, false, off, 0`` as False. This is the same vocabulary as
    ``distutils.util.strtobool``, which is deprecated by PEP 632 and removed in
    Python 3.12, so it is reimplemented here.

    Raises:
        ValueError: for any other string; argparse reports it as an invalid value.
    """
    val = x.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return True
    if val in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (val,))
def argument_parsing(notebook=False, notebook_args=None):
    """Build the training configuration from named YAML configs plus CLI overrides.

    Parsing happens in two stages:

    1. ``--configs`` takes one or more names, each optionally a comma-separated
       list. These select entries from ``read_yamls("./configs")``, which are
       merged in order; later entries override earlier keys. ``--local_rank``
       and ``--deepspeed`` / ``--no-deepspeed`` are then stored in the merged
       config.
    2. Every merged key becomes a ``--<key>`` option whose default is the config
       value, so any setting can be overridden from the command line. Boolean
       keys are parsed with ``_strtobool``. List keys take zero or more values,
       each typed like the default's first element.

    Args:
        notebook: when True, parse ``notebook_args`` instead of ``sys.argv``.
        notebook_args: argument list used when ``notebook`` is True.

    Returns:
        ``argparse.Namespace`` with one attribute per merged config key.

    Raises:
        KeyError: if a requested config name is not among the loaded configs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--configs", nargs="+", required=True)
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--deepspeed", action="store_true")
    parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
    parser.set_defaults(deepspeed=False)

    if notebook:
        args, remaining = parser.parse_known_args(notebook_args)
    else:
        args, remaining = parser.parse_known_args()
    print(args)

    # Config from YAML
    conf = {}
    configs = read_yamls("./configs")
    for entry in args.configs:
        # Tolerate whitespace and stray/trailing commas ("a, b," -> ["a", "b"]).
        names = [n.strip() for n in entry.split(",") if n.strip()]
        for name in names:
            if name not in configs:
                raise KeyError(f"Unknown config {name!r}; available: {list(configs)}")
            conf.update(configs[name])

    conf["local_rank"] = args.local_rank
    conf["deepspeed"] = args.deepspeed

    def arg_type(t):
        # bool("False") is True, so booleans need an explicit string parser.
        return _strtobool if t is bool else t

    # Override config from command-line
    parser = argparse.ArgumentParser()
    for key, value in conf.items():
        if isinstance(value, list):
            # ``type=list`` would split one CLI string into characters; take a
            # sequence of values instead, typed after the default's first element.
            elem_type = type(value[0]) if value and value[0] is not None else str
            parser.add_argument(f"--{key}", nargs="*", type=arg_type(elem_type), default=value)
        else:
            type_ = type(value) if value is not None else str
            parser.add_argument(f"--{key}", type=arg_type(type_), default=value)
    return parser.parse_args(remaining)
if __name__ == "__main__":
    training_conf = argument_parsing()

    tokenizer = get_tokenizer(training_conf)
    model = get_model(training_conf, tokenizer)

    train, evals, collate_fn = get_dataset(training_conf, tokenizer)
    # Validate before building TrainingArguments. Use an explicit raise rather
    # than ``assert``, which ``python -O`` strips; step-wise evaluation below
    # needs at least one eval set.
    if len(evals) == 0:
        raise ValueError("get_dataset returned no evaluation datasets; at least one is required.")

    args = TrainingArguments(
        output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
        num_train_epochs=training_conf.num_train_epochs,
        warmup_steps=training_conf.warmup_steps,
        # Values such as 1e-5 may arrive from YAML as strings; coerce explicitly.
        learning_rate=float(training_conf.learning_rate),
        deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
        # Defaults to True (the previous hard-coded value); set `fp16` in a
        # config or pass --fp16 false to disable it.
        fp16=getattr(training_conf, "fp16", True),
        local_rank=training_conf.local_rank,
        gradient_checkpointing=training_conf.gradient_checkpointing,
        gradient_accumulation_steps=training_conf.gradient_accumulation_steps,
        per_device_train_batch_size=training_conf.per_device_train_batch_size,
        per_device_eval_batch_size=training_conf.per_device_eval_batch_size,
        weight_decay=training_conf.weight_decay,
        max_grad_norm=training_conf.max_grad_norm,
        logging_steps=training_conf.logging_steps,
        save_total_limit=training_conf.save_total_limit,
        evaluation_strategy="steps",
        eval_steps=training_conf.eval_steps,
        save_steps=training_conf.save_steps,
        eval_accumulation_steps=training_conf.eval_accumulation_steps,
        report_to="wandb",
    )

    trainer = SFTTrainer(
        model,
        args,
        loss_function=training_conf.loss_fn,
        train_dataset=train,
        eval_dataset=evals,
        data_collator=collate_fn,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    trainer.train()