"""Supervised fine-tuning script built on the Hugging Face ``Trainer``.

Configuration is read from YAML files (see ``argument_parsing``) and every
key can be overridden from the command line.
"""
import argparse
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import (
    DataCollator,
    EvalPrediction,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    get_cosine_schedule_with_warmup,
)
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls

try:
    from distutils.util import strtobool
except ImportError:
    # distutils was removed from the standard library in Python 3.12
    # (PEP 632); provide an equivalent so the script keeps working there.
    def strtobool(val):
        """Convert a truth-value string to 1 or 0, like ``distutils.util.strtobool``."""
        val = val.lower()
        if val in ("y", "yes", "t", "true", "on", "1"):
            return 1
        if val in ("n", "no", "f", "false", "off", "0"):
            return 0
        raise ValueError("invalid truth value %r" % (val,))


os.environ["WANDB_PROJECT"] = "supervised-finetuning"


@dataclass
class CustomTrainingArguments(TrainingArguments):
    """``TrainingArguments`` extended with the name of the loss to use."""

    # Name handed to ``utils.get_loss`` by ``SFTTrainer.__init__``.
    loss_function: str = "CrossEntropyLoss"


def compute_metrics(eval_pred):
    """Return token-level accuracy over positions whose label is > 0.

    ``eval_pred.predictions`` are expected to already be token ids (see
    ``preprocess_logits_for_metrics``). Positions that ``prediction_step``
    marked with -1 are excluded by the ``> 0`` filter.
    NOTE(review): the filter also drops label id 0 -- presumably the padding
    id; confirm against the tokenizer.
    """
    pred_ids = eval_pred.predictions
    labels = eval_pred.label_ids
    keep = labels > 0
    return {"accuracy": (pred_ids[keep] == labels[keep]).mean()}


def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted ids (argmax over the last dimension).

    ``labels`` is unused; it is part of the callback signature expected by
    the ``Trainer``.
    """
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids


class SFTTrainer(Trainer):
    """``Trainer`` that computes a masked loss on next-token targets.

    Targets are obtained by rolling ``input_ids`` one position to the left,
    and the batch entry ``label_masks`` is forwarded to the loss function as
    ``mask``.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    ):
        """Forward all arguments to ``Trainer`` and build the loss function.

        ``args`` must expose ``loss_function`` (see
        ``CustomTrainingArguments``); it is passed to ``utils.get_loss``.
        """
        super().__init__(
            model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            tokenizer,
            model_init,
            compute_metrics,
            callbacks,
            optimizers,
            preprocess_logits_for_metrics,
        )
        self.loss_fct = get_loss(args.loss_function)

    def fetch_scheduler(self):
        """Build a cosine learning-rate schedule with warmup for ``self.optimizer``.

        NOTE(review): ``self.num_train_steps`` is not assigned anywhere in
        this file -- confirm it is set before this method is used.
        """
        return get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=self.num_train_steps,
            num_cycles=1,
            last_epoch=-1,
        )

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the masked next-token loss for a training batch.

        Args:
            model: Model to run the forward pass on.
            inputs: Batch dict; must contain ``label_masks`` and
                ``input_ids``. ``label_masks`` is removed from the dict
                before the model is called.
            return_outputs: If true, also return the model outputs.

        Returns:
            ``loss`` or ``(loss, outputs)`` depending on ``return_outputs``.
        """
        labels_mask = inputs.pop("label_masks")
        outputs = model(**inputs)
        # Target at position i is the token at position i + 1.
        targets = torch.roll(inputs["input_ids"], -1, -1)
        loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)
        return (loss, outputs) if return_outputs else loss

    def _compute_loss(self, model, inputs):
        """Run a forward pass and return ``(loss, logits, targets, labels_mask)``.

        Inputs are prepared *before* the mask is popped so that the mask goes
        through the same ``Trainer._prepare_inputs`` handling (device
        placement) as the tensors it is combined with.
        """
        inputs = self._prepare_inputs(inputs)
        labels_mask = inputs.pop("label_masks")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Target at position i is the token at position i + 1.
        targets = torch.roll(inputs["input_ids"], -1, -1)
        loss = self.loss_fct(logits, targets, mask=labels_mask)
        return loss, logits, targets, labels_mask

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Evaluate one batch without gradients.

        Args:
            model: Model to evaluate.
            inputs: Batch dict containing ``label_masks`` and ``input_ids``.
            prediction_loss_only: If true, return only the loss.
            ignore_keys: Unused; accepted for signature compatibility.

        Returns:
            ``(loss, logits, labels)``; ``logits`` and ``labels`` are ``None``
            when ``prediction_loss_only`` is true. Label positions outside
            the mask are set to -1.
        """
        with torch.no_grad():
            loss, logits, labels, labels_mask = self._compute_loss(model, inputs)
            # Cast so that `~` is a logical NOT even if the collator yields a
            # 0/1 integer mask rather than a boolean one.
            labels[~labels_mask.bool()] = -1

        loss = loss.mean().detach()

        # Honour the value passed by the caller instead of re-reading args.
        if prediction_loss_only:
            return (loss, None, None)

        return (loss, logits, labels)


def _strtobool(x):
    """Parse a truth-value string (e.g. ``"true"``, ``"0"``) into a ``bool``."""
    return bool(strtobool(x))


def argument_parsing(notebook=False, notebook_args=None):
    """Build the training configuration from YAML configs and CLI overrides.

    ``--configs`` names one or more entries (space- or comma-separated) of
    the mapping returned by ``read_yamls("./configs")``; later entries
    override earlier ones. Every resulting key becomes a ``--key`` option
    whose default is the YAML value.

    Args:
        notebook: If true, parse ``notebook_args`` instead of ``sys.argv``.
        notebook_args: Argument list used when ``notebook`` is true.

    Returns:
        ``argparse.Namespace`` with the merged configuration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--configs", nargs="+", required=True)

    if notebook:
        args, remaining = parser.parse_known_args(notebook_args)
    else:
        args, remaining = parser.parse_known_args()

    # Config from YAML
    conf = {}
    configs = read_yamls("./configs")
    for name in args.configs:
        # split() yields [name] when there is no comma, so one loop covers
        # both the single-name and the comma-separated form.
        for n in name.split(","):
            conf.update(configs[n])

    # Override config from command-line
    parser = argparse.ArgumentParser()
    for key, value in conf.items():
        type_ = type(value) if value is not None else str
        if type_ is bool:
            # bool("false") is True, so parse the text explicitly.
            type_ = _strtobool
        parser.add_argument(f"--{key}", type=type_, default=value)

    return parser.parse_args(remaining)


if __name__ == "__main__":
    training_conf = argument_parsing()

    model = get_model(training_conf)
    tokenizer = get_tokenizer(training_conf)

    train, evals, collate_fn = get_dataset(training_conf, tokenizer)

    args = CustomTrainingArguments(
        output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
        num_train_epochs=training_conf.num_train_epochs,
        warmup_steps=training_conf.warmup_steps,
        loss_function=training_conf.loss_fn,
        learning_rate=float(training_conf.learning_rate),
        fp16=True,
        gradient_checkpointing=training_conf.gradient_checkpointing,
        gradient_accumulation_steps=training_conf.gradient_accumulation_steps,
        per_device_train_batch_size=training_conf.per_device_train_batch_size,
        per_device_eval_batch_size=training_conf.per_device_eval_batch_size,
        weight_decay=training_conf.weight_decay,
        max_grad_norm=training_conf.max_grad_norm,
        logging_steps=training_conf.logging_steps,
        save_total_limit=training_conf.save_total_limit,
        evaluation_strategy="steps",
        eval_steps=training_conf.eval_steps,
        save_steps=training_conf.save_steps,
        eval_accumulation_steps=training_conf.eval_accumulation_steps,
        report_to="wandb",
    )

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(evals) == 0:
        raise ValueError("evaluation dataset is empty")

    trainer = SFTTrainer(
        model,
        args,
        train_dataset=train,
        eval_dataset=evals,
        data_collator=collate_fn,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    trainer.train()