[fix] Custom collate_fn for training

This commit is contained in:
theblackcat102
2023-02-03 06:08:01 +00:00
parent 1041564db7
commit 8b2080559c
3 changed files with 131 additions and 6 deletions
@@ -14,6 +14,77 @@ class DialogueDataCollator:
Expects a list of texts corresponding to a sequence of [question, answer, question, answer, ...] pairs.
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
def __call__(self, features):
    """Tokenize, pad and mask a batch of dialogues.

    Each element of ``features`` is a sequence of strings laid out as
    [question, answer, question, answer, ...]. The tokenizer's EOS token is
    appended to every dialogue, the messages are concatenated and tokenized
    as one string, and a label mask is derived from the character offsets of
    the tokens.

    Illustration (one column per token)::

        TEXT:            Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John. Question:
        MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3 -2
        LABEL_MASK:      0 0 0 0 0 1 1 1 1 0 0 0 0 0 1 1 1 1 1 0

    Returns the padded tokenizer batch with two extra entries:

    * ``label_masks``: boolean tensor, True at positions whose *next* token
      belongs to an answer (odd message index) or is the first termination
      token.
    * ``targets``: ``input_ids`` rolled left by one position.
    """
    flatten_messages = []
    label_masks = []

    for messages in features:
        messages = list(messages)
        # Add a way for the model to terminate generation: when the start of
        # a new question would be predicted, generation should stop instead.
        messages.append(self.tokenizer.eos_token)

        flatten_message = self.tokenizer(
            "".join(messages),
            truncation=True,
            max_length=self.max_length,
            return_offsets_mapping=True,
        )

        # Character position at which each real message ends in the joined
        # text (the appended EOS string is deliberately left out).
        message_change_indices = np.cumsum([len(x) for x in messages[:-1]])

        # For every token, the index of the message it belongs to: the first
        # boundary that is >= the token's end offset. The boundaries are
        # non-decreasing, so a left-sided binary search finds exactly that
        # index without scanning all boundaries for every token.
        token_ends = [span[1] for span in flatten_message["offset_mapping"]]
        positions = np.searchsorted(message_change_indices, token_ends, side="left")
        # Tokens past the last boundary are termination token(s): mark as -2.
        message_indices = np.where(positions < len(message_change_indices), positions, -2)

        # Answers sit at odd message indices. Shift by one so the mask at
        # position t describes the token being predicted at t + 1, matching
        # the rolled ``targets`` below.
        label_mask = np.roll(message_indices % 2 == 1, -1, -1)

        termination = np.flatnonzero(message_indices == -2)
        if termination.size:
            # Also learn to predict the first termination token.
            label_mask[termination[0] - 1] = True
        elif label_mask.size:
            # Due to truncation the termination token may be missing; the
            # last position then only holds the value wrapped around by roll.
            label_mask[-1] = False

        label_masks.append(label_mask)
        flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"})

    batch = self.tokenizer.pad(
        flatten_messages,
        padding=self.padding,
        max_length=self.max_length,
        pad_to_multiple_of=self.pad_to_multiple_of,
        return_tensors="pt",
    )
    dim = batch["input_ids"].shape[-1]

    # Pad the masks on the same side the tokenizer padded ``input_ids`` so
    # both stay aligned (tokenizers without the attribute are treated as
    # right-padding, which is the previous behaviour).
    pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"
    batch["label_masks"] = torch.stack(
        [
            F.pad(
                torch.tensor(x),
                (dim - len(x), 0) if pad_left else (0, dim - len(x)),
                value=False,
            )
            for x in label_masks
        ]
    )
    batch["targets"] = torch.roll(batch["input_ids"], -1, -1)

    return batch
@dataclass
class TrainDialogueDataCollator:
"""
Expects a list of texts corresponding to a sequence of [question, answer, question, answer, ...] pairs.
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
+57 -3
View File
@@ -1,13 +1,17 @@
import argparse
from distutils.util import strtobool
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import bitsandbytes
import datasets
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, Trainer, TrainingArguments
from transformers.trainer_pt_utils import IterableDatasetShard, seed_worker
from transformers.training_args import OptimizerNames
from transformers.utils import is_datasets_available
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
@@ -30,12 +34,13 @@ class SFTTrainer(Trainer):
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
train_collate_fn: Callable = None,
loss_function: str = "CrossEntropyLoss",
poly_eps: float = 1.0,
**kwargs,
):
super().__init__(model, args, **kwargs)
self.train_collate_fn = train_collate_fn
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
self.loss_fct = get_loss(loss_function, poly_eps)
@@ -88,6 +93,54 @@ class SFTTrainer(Trainer):
return (loss, logits, labels)
def get_train_dataloader(self) -> DataLoader:
    """
    Returns the training [`~torch.utils.data.DataLoader`].

    Batches are collated with `self.train_collate_fn`. When no training
    collator was supplied (it is `None`), `self.data_collator` is used
    instead so that the loader never silently runs with a `None` collator.

    No sampler is used when `train_dataset` is an
    `torch.utils.data.IterableDataset`; otherwise the sampler returned by
    `self._get_train_sampler()` is used.

    Raises:
        ValueError: if `self.train_dataset` is `None`.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    train_dataset = self.train_dataset

    data_collator = self.train_collate_fn
    if data_collator is None:
        # NOTE(review): relies on the base Trainer having set
        # `self.data_collator` in `__init__` - confirm for the pinned
        # transformers version.
        data_collator = self.data_collator

    if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
        train_dataset = self._remove_unused_columns(train_dataset, description="training")
    else:
        data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

    if isinstance(train_dataset, torch.utils.data.IterableDataset):
        if self.args.world_size > 1:
            # Give each process its own shard of the stream.
            train_dataset = IterableDatasetShard(
                train_dataset,
                batch_size=self._train_batch_size,
                drop_last=self.args.dataloader_drop_last,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )

        return DataLoader(
            train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    train_sampler = self._get_train_sampler()

    return DataLoader(
        train_dataset,
        batch_size=self._train_batch_size,
        sampler=train_sampler,
        collate_fn=data_collator,
        drop_last=self.args.dataloader_drop_last,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        worker_init_fn=seed_worker,
    )
def _strtobool(x):
    """Convert a textual truth value to a ``bool``.

    Accepts, case-insensitively, ``y``, ``yes``, ``t``, ``true``, ``on`` and
    ``1`` as True and ``n``, ``no``, ``f``, ``false``, ``off`` and ``0`` as
    False. This mirrors ``distutils.util.strtobool`` without calling it,
    because ``distutils`` is deprecated and no longer ships with Python 3.12+.

    Raises:
        ValueError: if ``x`` is not one of the recognised spellings.
    """
    value = x.lower()
    if value in ("y", "yes", "t", "true", "on", "1"):
        return True
    if value in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (value,))
@@ -140,7 +193,7 @@ if __name__ == "__main__":
tokenizer = get_tokenizer(training_conf)
model = get_model(training_conf, tokenizer)
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
train, evals, collate_fn, train_collate_fn = get_dataset(training_conf, tokenizer)
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
@@ -190,6 +243,7 @@ if __name__ == "__main__":
trainer = SFTTrainer(
model,
args,
train_collate_fn=train_collate_fn,
loss_function=training_conf.loss_fn,
poly_eps=training_conf.poly_eps,
train_dataset=train,
+3 -3
View File
@@ -8,7 +8,7 @@ import evaluate
import transformers
import yaml
from custom_datasets import get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
from custom_datasets.dialogue_collator import DialogueDataCollator, TrainDialogueDataCollator
from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS
from losses import CrossEntropyLoss, PolyLoss
from models import freeze_top_n_layers, get_specific_model
@@ -126,8 +126,8 @@ def get_dataset(conf, tokenizer):
train = ConcatDataset(train_datasets)
collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length)
return train, evals, collate_fn
train_collate_fn = TrainDialogueDataCollator(tokenizer, max_length=conf.max_length)
return train, evals, collate_fn, train_collate_fn
def get_loss(loss, poly_eps):