mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-30 16:40:05 +08:00
[fix] Custom collate_fn for training
This commit is contained in:
@@ -14,6 +14,77 @@ class DialogueDataCollator:
|
||||
Expects a list of texts corresponding to a sequence of [question, answer, question, answer, ...] pairs.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features):
|
||||
flatten_messages = []
|
||||
label_masks = []
|
||||
|
||||
for messages in features:
|
||||
messages = list(messages)
|
||||
|
||||
# Add a way for the model to terminate generation
|
||||
# When we predict the start of a new expected question, we want to be able to stop generation
|
||||
messages.append(self.tokenizer.eos_token)
|
||||
|
||||
flatten_message = self.tokenizer(
|
||||
"".join(messages),
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_offsets_mapping=True,
|
||||
)
|
||||
|
||||
message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
|
||||
# for each token an integer indicating the index of the message it belongs to. Just to create the label mask.
|
||||
# Label mask is true when predicting a token that is part of the answer, false otherwise.
|
||||
# TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John. Question:
|
||||
# MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3 -2
|
||||
# LABEL_MASK: 0 0 0 0 0 1 1 1 1 0 0 0 0 0 1 1 1 1 1 0
|
||||
|
||||
# If no result in next, we are predicting the last termination token(s)
|
||||
message_indices = list(
|
||||
map(
|
||||
lambda x: next((i for i, val in enumerate(message_change_indices) if val >= x), -2),
|
||||
list(map(lambda x: x[1], flatten_message["offset_mapping"])),
|
||||
)
|
||||
)
|
||||
label_mask = np.roll(list(map(lambda x: x % 2 == 1, message_indices)), -1, -1)
|
||||
try:
|
||||
label_mask[[i for i in range(len(message_indices)) if message_indices[i] == -2][0] - 1] = True
|
||||
except IndexError:
|
||||
# due to truncation, we might not have the last termination token
|
||||
label_mask[-1] = False
|
||||
|
||||
label_masks.append(label_mask)
|
||||
|
||||
flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"})
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flatten_messages,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
dim = batch["input_ids"].shape[-1]
|
||||
|
||||
batch["label_masks"] = torch.stack(
|
||||
[F.pad(torch.tensor(x), (0, dim - len(x)), value=False) for x in label_masks]
|
||||
)
|
||||
batch["targets"] = torch.roll(batch["input_ids"], -1, -1)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainDialogueDataCollator:
|
||||
"""
|
||||
Expects a list of texts corresponding to a sequence of [question, answer, question, answer, ...] pairs.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import argparse
|
||||
from distutils.util import strtobool
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import bitsandbytes
|
||||
import datasets
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import PreTrainedModel, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import IterableDatasetShard, seed_worker
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import is_datasets_available
|
||||
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
|
||||
|
||||
|
||||
@@ -30,12 +34,13 @@ class SFTTrainer(Trainer):
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
train_collate_fn: Callable = None,
|
||||
loss_function: str = "CrossEntropyLoss",
|
||||
poly_eps: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, args, **kwargs)
|
||||
|
||||
self.train_collate_fn = train_collate_fn
|
||||
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
|
||||
self.loss_fct = get_loss(loss_function, poly_eps)
|
||||
|
||||
@@ -88,6 +93,54 @@ class SFTTrainer(Trainer):
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
"""
|
||||
Returns the training [`~torch.utils.data.DataLoader`].
|
||||
Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
|
||||
training if necessary) otherwise.
|
||||
Subclass and override this method if you want to inject some custom behavior.
|
||||
"""
|
||||
if self.train_dataset is None:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
|
||||
train_dataset = self.train_dataset
|
||||
data_collator = self.train_collate_fn
|
||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
||||
|
||||
if isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||
if self.args.world_size > 1:
|
||||
train_dataset = IterableDatasetShard(
|
||||
train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_processes=self.args.world_size,
|
||||
process_index=self.args.process_index,
|
||||
)
|
||||
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
batch_size=self.args.per_device_train_batch_size,
|
||||
collate_fn=data_collator,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
)
|
||||
|
||||
train_sampler = self._get_train_sampler()
|
||||
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
sampler=train_sampler,
|
||||
collate_fn=data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
worker_init_fn=seed_worker,
|
||||
)
|
||||
|
||||
|
||||
def _strtobool(x):
|
||||
return bool(strtobool(x))
|
||||
@@ -140,7 +193,7 @@ if __name__ == "__main__":
|
||||
tokenizer = get_tokenizer(training_conf)
|
||||
model = get_model(training_conf, tokenizer)
|
||||
|
||||
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
|
||||
train, evals, collate_fn, train_collate_fn = get_dataset(training_conf, tokenizer)
|
||||
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
|
||||
|
||||
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
|
||||
@@ -190,6 +243,7 @@ if __name__ == "__main__":
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args,
|
||||
train_collate_fn=train_collate_fn,
|
||||
loss_function=training_conf.loss_fn,
|
||||
poly_eps=training_conf.poly_eps,
|
||||
train_dataset=train,
|
||||
|
||||
@@ -8,7 +8,7 @@ import evaluate
|
||||
import transformers
|
||||
import yaml
|
||||
from custom_datasets import get_one_dataset
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator, TrainDialogueDataCollator
|
||||
from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS
|
||||
from losses import CrossEntropyLoss, PolyLoss
|
||||
from models import freeze_top_n_layers, get_specific_model
|
||||
@@ -126,8 +126,8 @@ def get_dataset(conf, tokenizer):
|
||||
train = ConcatDataset(train_datasets)
|
||||
|
||||
collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length)
|
||||
|
||||
return train, evals, collate_fn
|
||||
train_collate_fn = TrainDialogueDataCollator(tokenizer, max_length=conf.max_length)
|
||||
return train, evals, collate_fn, train_collate_fn
|
||||
|
||||
|
||||
def get_loss(loss, poly_eps):
|
||||
|
||||
Reference in New Issue
Block a user