@dataclass
class DialogueDataCollator:
    """
    Expects a list of texts corresponding to a sequence of [question, answer, question, answer, ...] pairs.

    Each conversation is joined into one string, terminated with the tokenizer's
    EOS token, tokenized, and padded into a batch. Two entries are added to the
    padded batch:

    * ``label_masks``: boolean tensor, True at positions whose *next* token
      belongs to an answer (odd message index) or is the first termination token.
    * ``targets``: ``input_ids`` rolled left by one position (next-token targets).
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate ``features`` (an iterable of message sequences) into a padded batch.

        Returns whatever ``self.tokenizer.pad`` returns, extended with the
        ``label_masks`` and ``targets`` entries described on the class.
        """
        encodings = []
        label_masks = []

        for messages in features:
            # Copy so the caller's sequence is not mutated by the append below.
            messages = list(messages)

            # Add a way for the model to terminate generation:
            # when we would predict the start of a new question, we want to stop instead.
            messages.append(self.tokenizer.eos_token)

            encoding = self.tokenizer(
                "".join(messages),
                truncation=True,
                max_length=self.max_length,
                return_offsets_mapping=True,
            )

            # The termination token is excluded from the message boundaries, so every
            # token past the last real message is recognised as "termination".
            label_masks.append(self._build_label_mask(messages[:-1], encoding["offset_mapping"]))

            # Offsets are only needed for the mask; `pad` must not see them.
            encodings.append({k: v for k, v in encoding.items() if k != "offset_mapping"})

        batch = self.tokenizer.pad(
            encodings,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        width = batch["input_ids"].shape[-1]

        # Pad the masks on the same side the tokenizer pads `input_ids`, otherwise
        # mask positions no longer line up with the tokens they describe.
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"
        batch["label_masks"] = torch.stack(
            [
                F.pad(
                    torch.tensor(mask),
                    (width - len(mask), 0) if pad_left else (0, width - len(mask)),
                    value=False,
                )
                for mask in label_masks
            ]
        )
        batch["targets"] = torch.roll(batch["input_ids"], -1, -1)

        return batch

    @staticmethod
    def _build_label_mask(messages, offset_mapping):
        """Return a boolean array marking positions whose next token must be predicted.

        ``messages`` are the conversation turns without the termination token and
        ``offset_mapping`` holds one ``(start, end)`` character span per token.

        TEXT:            Q: Hello, how are you? A: I am fine. Q: What is your name? A: My name is John. <eos>
        MESSAGE_INDICES: 0  0      0   0   0    1  1 1  1     2  2    2  2    2     3  3  3    3  3     -2
        LABEL_MASK:      0  0      0   0   1    1  1 1  0     0  0    0  0    1     1  1  1    1  1     0
        """
        # Cumulative character length of the messages == end offset of each message.
        boundaries = np.cumsum([len(message) for message in messages])
        token_ends = np.array([end for _start, end in offset_mapping], dtype=np.int64)

        # Index of the first message whose end offset is >= the token's end offset.
        message_indices = np.searchsorted(boundaries, token_ends, side="left")
        # Tokens ending past the last message are the termination token(s).
        is_termination = message_indices == len(boundaries)

        # Answers are the odd-numbered messages; shift left by one because the mask
        # applies to the position that *predicts* the token.
        in_answer = (message_indices % 2 == 1) & ~is_termination
        label_mask = np.roll(in_answer, -1)

        termination_positions = np.flatnonzero(is_termination)
        if termination_positions.size:
            first = termination_positions[0]
            # Predict the first termination token; with no preceding position
            # there is nothing to predict it from (and -1 would wrap around).
            if first > 0:
                label_mask[first - 1] = True
        elif label_mask.size:
            # Due to truncation we might not have the termination token; the last
            # slot only holds the value wrapped around by np.roll, so clear it.
            label_mask[-1] = False

        return label_mask
def get_train_dataloader(self) -> DataLoader:
    """
    Returns the training [`~torch.utils.data.DataLoader`], collating batches with `train_collate_fn`.

    Will use no sampler if `train_dataset` is a `torch.utils.data.IterableDataset`, the sampler returned by
    `_get_train_sampler()` otherwise. When no `train_collate_fn` was given to the trainer, the trainer's regular
    `data_collator` is used instead, so training batches are never silently collated with `None`.

    Raises:
        ValueError: if the trainer has no `train_dataset`.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    train_dataset = self.train_dataset
    # `train_collate_fn` defaults to None in `__init__`; fall back to the standard collator in that case.
    data_collator = self.train_collate_fn if self.train_collate_fn is not None else self.data_collator

    # Drop inputs the model does not accept: on the dataset itself for `datasets.Dataset`,
    # otherwise by wrapping the collator.
    if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
        train_dataset = self._remove_unused_columns(train_dataset, description="training")
    else:
        data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

    if isinstance(train_dataset, torch.utils.data.IterableDataset):
        # Iterable datasets cannot be sampled; in distributed runs each process reads its own shard.
        if self.args.world_size > 1:
            train_dataset = IterableDatasetShard(
                train_dataset,
                batch_size=self._train_batch_size,
                drop_last=self.args.dataloader_drop_last,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )

        return DataLoader(
            train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    train_sampler = self._get_train_sampler()

    return DataLoader(
        train_dataset,
        batch_size=self._train_batch_size,
        sampler=train_sampler,
        collate_fn=data_collator,
        drop_last=self.args.dataloader_drop_last,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        worker_init_fn=seed_worker,
    )
poly_eps=training_conf.poly_eps, train_dataset=train, diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index f7a0ab15..ba0e5539 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -8,7 +8,7 @@ import evaluate import transformers import yaml from custom_datasets import get_one_dataset -from custom_datasets.dialogue_collator import DialogueDataCollator +from custom_datasets.dialogue_collator import DialogueDataCollator, TrainDialogueDataCollator from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model @@ -126,8 +126,8 @@ def get_dataset(conf, tokenizer): train = ConcatDataset(train_datasets) collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length) - - return train, evals, collate_fn + train_collate_fn = TrainDialogueDataCollator(tokenizer, max_length=conf.max_length) + return train, evals, collate_fn, train_collate_fn def get_loss(loss, poly_eps):