class OAPrivate(Dataset):
    """Ranking pairs read from private Open-Assistant reward-model JSONL files.

    Each line of the split file is one JSON object::

        {
            "prompt": <current user prompt>,
            "history": [["prompt1", "answer1"], ["prompt2", "answer2"]],
            "pos": <preferred answer>,
            "neg_replies": [<list of worse answers>]
        }

    Every item is ``(context, [(pos, neg), ...])`` where ``context`` is the
    last two history exchanges followed by the prompt, joined by ``sep_token``.
    """

    # Maps the split name to the file expected inside ``data_path``.
    split_name_mapping = {
        "train": "rm_train.jsonl",
        "test": "rm_test.jsonl",
        "val": "rm_val.jsonl",
    }

    def __init__(self, split="train", sep_token="", data_path=".cache") -> None:
        """Load all ranking pairs of one split into memory.

        Args:
            split: one of the keys of ``split_name_mapping``.
            sep_token: string placed between consecutive dialogue turns.
            data_path: directory that contains the ``rm_*.jsonl`` files.

        Raises:
            KeyError: if ``split`` is unknown or a record misses a field.
            OSError: if the split file cannot be opened.
        """
        super().__init__()
        import json

        jsonl_file = os.path.join(data_path, self.split_name_mapping[split])
        self.pairs = []
        with open(jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # tolerate blank lines (e.g. a trailing newline at EOF)
                    continue
                data = json.loads(line)
                # Flatten the last two exchanges and append the prompt, then
                # join once: with an empty history this yields the bare
                # prompt instead of a context starting with a separator.
                turns = [text for exchange in data["history"][-2:] for text in exchange]
                turns.append(data["prompt"])
                prefix = sep_token.join(turns)
                pair = [(data["pos"], neg_text) for neg_text in data["neg_replies"]]
                if pair:
                    # a record without negatives carries no ranking signal
                    self.pairs.append((prefix, pair))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        context, pair = self.pairs[index]

        return context, pair
get_datasets(dataset_list: List[AnyStr], tokenizer): eval = AnthropicRLHF("test", tokenizer.sep_token) train_datasets.append(train) evals["anthropic_rlhf"] = eval + elif "oa_private" == dataset_name: + train = OAPrivate(split="train", sep_token=tokenizer.sep_token) + eval = OAPrivate(split="val", sep_token=tokenizer.sep_token) + train_datasets.append(train) + evals["oa_private"] = eval + train = ConcatDataset(train_datasets) return train, evals diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index ee061f04..0e5b9a91 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,8 +1,8 @@ """ High level functions for model training """ -from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset -from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT +from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset +from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT from custom_datasets.summarization import SummarizationDataset from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination from custom_datasets.translation import WMT2019, DiveMT, TEDTalk @@ -32,7 +32,7 @@ SUMMARIZATION_DATASETS = [ "debate_sum", "tldr_news", ] -OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"] +OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"] def train_val_dataset(dataset, val_split=0.2): @@ -92,6 +92,12 @@ def get_one_dataset(conf, dataset_name): elif dataset_name == "instruct_tuning": dataset = InstructionTuning(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) + elif dataset_name == "private_tuning": + dataset = 
@dataclass
class TrainDialogueDataCollator:
    """Training-time collator that tokenizes dialogues and packs short ones.

    Expects a list of texts corresponding to a sequence of
    [question, answer, question, answer, ...] pairs. For every sample it
    builds a label mask that is true on tokens belonging to answers, and it
    randomly concatenates pairs of short samples into a single sequence.

    Returns a padded batch with the extra keys ``label_masks`` and ``targets``.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    # samples with fewer tokens than this are candidates for packing
    mix_length_threshold: Optional[int] = 256
    # a short sample enters packing when a uniform draw in [0, 1) exceeds this
    mix_probability: Optional[float] = 0.6
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        flatten_messages = []
        label_masks = []
        total_short_context = 0
        for messages in features:
            messages = list(messages)

            # Add a way for the model to terminate generation
            # When we predict the start of a new expected question, we want to be able to stop generation
            messages.append(self.tokenizer.eos_token)

            flatten_message = self.tokenizer(
                "".join(messages),
                truncation=True,
                max_length=self.max_length,
                return_offsets_mapping=True,
            )

            # character offset at which each real message (eos excluded) ends
            message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
            # for each token an integer indicating the index of the message it belongs to. Just to create the label mask.
            # Label mask is true when predicting a token that is part of the answer, false otherwise.
            # TEXT:             Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John. Question:
            # MESSAGE_INDICES:  0         0      0   0   0    1       1 1  1     2         2    2  2    2     3       3  3    3  3     -2
            # LABEL_MASK:       0         0      0   0   0    1       1 1  1     0         0    0  0    0     1       1  1    1  1     0

            # First message whose end offset is >= the token's end offset;
            # tokens past the last message (termination tokens) get -2.
            token_ends = [span[1] for span in flatten_message["offset_mapping"]]
            positions = np.searchsorted(message_change_indices, token_ends, side="left")
            message_indices = np.where(positions < len(message_change_indices), positions, -2)

            # odd message index == answer; shift by one because we predict the next token
            label_mask = np.roll(message_indices % 2 == 1, -1, -1)
            termination = np.flatnonzero(message_indices == -2)
            if termination.size:
                # the token before the first termination token must predict it
                label_mask[termination[0] - 1] = True
            else:
                # due to truncation, we might not have the last termination token
                label_mask[-1] = False

            label_masks.append(label_mask)
            if len(flatten_message["input_ids"]) < self.mix_length_threshold:
                total_short_context += len(flatten_message["input_ids"])
            flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"})

        # packing
        # NOTE(review): this sums token counts of short samples, not the number
        # of short samples, so the condition is almost always true - confirm intent.
        if total_short_context > 2:
            flatten_messages, label_masks = self._pack_short_samples(flatten_messages, label_masks)

        batch = self.tokenizer.pad(
            flatten_messages,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        dim = batch["input_ids"].shape[-1]

        batch["label_masks"] = torch.stack(
            [F.pad(torch.tensor(x), (0, dim - len(x)), value=False) for x in label_masks]
        )
        batch["targets"] = torch.roll(batch["input_ids"], -1, -1)

        return batch

    def _pack_short_samples(self, flatten_messages, label_masks):
        """Randomly merge pairs of short samples; return new (messages, masks) lists."""
        packed_messages, packed_masks = [], []
        pending_msg, pending_mask = None, None
        for msg, mask in zip(flatten_messages, label_masks):
            is_short = len(msg["input_ids"]) < self.mix_length_threshold
            if not (is_short and random.random() > self.mix_probability):
                packed_messages.append(msg)
                packed_masks.append(mask)
                continue
            if pending_msg is None:
                # prime: hold this sample until a partner shows up
                pending_msg, pending_mask = msg, mask
                continue
            # Build fresh lists instead of `+=` so samples that are already
            # part of the output are never mutated.
            merged = {key: (list(msg[key]) + list(pending_msg[key]))[: self.max_length] for key in msg}
            packed_messages.append(merged)
            packed_masks.append(np.concatenate([mask, pending_mask])[: self.max_length])
            pending_msg, pending_mask = None, None

        if pending_msg is not None:
            # A primed sample that never found a partner is emitted unchanged;
            # gluing it onto the last loop item would duplicate or double a row.
            packed_messages.append(pending_msg)
            packed_masks.append(pending_mask)
        return packed_messages, packed_masks
+ The files for this data are in json format as a list of tuples + where each tuple is (source,instruction_response_pair) + + Not to be confused with unatural instruction + """ + + name = "private_tuning" + filename = "oa_v3_fixed_plus_safety.jsonl" + + def __init__(self, cache_dir) -> None: + super().__init__() + os.makedirs(cache_dir, exist_ok=True) + + self.pairs = [] + for file_link in [self.filename]: + basename = file_link.split("/")[-1] + instruction_tune_file = os.path.join(cache_dir, basename) + + with open(instruction_tune_file, "r", encoding="utf-8") as f: + for line in f: + row = json.loads(line) + prefix = "" + for _, convo in enumerate(row["text"].split("User:")): + if "Assistant" in convo: + prompt, answer = convo.split("Assistant:", maxsplit=1) + answer = answer.replace("<|endoftext|>", "").strip() + self.pairs.append((prefix + QA_SPECIAL_TOKENS["Question"] + prompt, answer)) + prefix += "".join(format_pair((prompt, answer))) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + prompt, answer = self.pairs[index] + return "{}{}".format(prompt, QA_SPECIAL_TOKENS["Answer"]), answer diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index 092b7743..f126a270 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -1,6 +1,7 @@ """ Open / close book QA datasets """ +import glob import json import os import re @@ -137,11 +138,13 @@ class QADataset(Dataset): self.dataset = load_dataset(context["name"], **context["params"]) else: raise ValueError("Unknown dataset : " + dataset) + self.length = len(self.dataset) def __len__(self): - return len(self.dataset) + return self.length def __getitem__(self, idx): + data = self.dataset[idx] return format_pair(self.index_fn(data)) @@ -311,14 +314,75 @@ class JokeExplaination(Dataset): for line in f: data = json.loads(line) 
class TranslatedQA(Dataset):
    """Non-English translations of OA v3 instruction data.

    Reads every ``*.jsonl`` file below ``<cache_dir>/oa_translated``. Each line
    has the form::

        {
            "text": "User: ... Assistant: ....",
            "meta": {"source": ... },
            "translate": [
                { "round": 1, "human":"...", "answer": "..."},
                ...
                { "round": K, "human":"...", "answer": "..."},
            ]
        }

    Since OA contains some code, the original ``text`` is consulted to skip
    conversations the translation would have mangled.
    """

    name = "oa_translated"

    def __init__(self, cache_dir) -> None:
        """Load all translated conversations found in the cache directory.

        Args:
            cache_dir: root cache directory; the ``oa_translated`` sub-folder
                is created when missing and scanned for ``*.jsonl`` files.
        """
        super().__init__()
        os.makedirs(cache_dir, exist_ok=True)
        path = os.path.join(cache_dir, self.name)
        os.makedirs(path, exist_ok=True)
        self.pairs = []
        # sorted: glob order is filesystem dependent, and a stable order keeps
        # index-based train/validation splits reproducible
        for translated_jsonl in sorted(glob.glob(os.path.join(path, "*.jsonl"))):
            # explicit utf-8: the files hold non-English text and must not
            # depend on the platform's default encoding
            with open(translated_jsonl, "r", encoding="utf-8") as fin:
                for line in fin:
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    if "Python " in data["text"]:
                        # translation currently doesn't ignore code
                        # so we will have to reference original text
                        # for ignoring the translation
                        continue
                    prefix = ""
                    for convo_round in data["translate"]:
                        human, answer = format_pair((convo_round["human"], convo_round["answer"]))
                        # NOTE(review): round 2 is emitted without the round-1
                        # history; `> 1` may have been intended - confirm.
                        if convo_round["round"] > 2:
                            self.pairs.append((prefix + human, answer))
                        else:
                            self.pairs.append((human, answer))

                        # history that is prepended to later rounds
                        prefix += "{}{}{}{}".format(
                            QA_SPECIAL_TOKENS["Question"],
                            convo_round["human"],
                            QA_SPECIAL_TOKENS["Answer"],
                            convo_round["answer"],
                        )

        self.length = len(self.pairs)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return self.pairs[index]
class TranslationPair(Dataset):
    """Base class for translation datasets stored as (question, answer) pairs.

    Subclasses fill ``self.pairs``. With probability ``mix_prob`` an item is
    prefixed with a neighbouring pair (at most five positions away) as
    dialogue history.
    """

    def __init__(self, mix_prob=0.2) -> None:
        super().__init__()
        self.pairs = []
        # kept for backward compatibility; refreshed on every ``len()`` call
        self.length = -1
        self.mix_prob = mix_prob

    def __len__(self):
        self.length = len(self.pairs)
        return self.length

    def __getitem__(self, index):
        import random  # local import so the block does not rely on file-level imports

        # Use the live size: the cached ``self.length`` is -1 until ``len()``
        # has been called, which silently disabled mixing.
        total = len(self.pairs)
        if random.random() < self.mix_prob and 5 < index < (total - 5):
            # neighbour offset in [-5, 5] without 0, so a pair is never used
            # as its own history
            additional = random.choice([offset for offset in range(-5, 6) if offset != 0])

            history_text = "".join(format_pair(self.pairs[additional + index]))
            question, answer = self.pairs[index]
            question = history_text + question
            return format_pair((question, answer))

        return format_pair(self.pairs[index])
__init__(self, pair="de-ja", split="train", year="2016", mix_prob=0.2) -> None: + super().__init__(mix_prob=mix_prob) dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split] src, tgt = pair.split("-") for row in dataset: diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py index 8d5ad08f..2a0d4481 100644 --- a/model/supervised_finetuning/tests/test_datasets.py +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -27,7 +27,7 @@ def test_collate_fn(): config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi") tokenizer = get_tokenizer(config) - collate_fn = DialogueDataCollator(tokenizer, max_length=512) + collate_fn = DialogueDataCollator(tokenizer, max_length=620) qa_base = QA_DATASETS summarize_base = SUMMARIZATION_DATASETS others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"] @@ -40,11 +40,14 @@ def test_collate_fn(): dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128) for batch in dataloader: - print(batch.keys()) + print(batch["targets"].shape[0]) print(tokenizer.decode(batch["input_ids"][0])) print("-----") print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]])) - assert batch["targets"].shape[1] <= 512 + assert batch["targets"].shape[1] <= 620 dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128) for batch in dataloader: - assert batch["targets"].shape[1] <= 512 + assert batch["targets"].shape[1] <= 620 + + +test_collate_fn() diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index fb5d4bee..ce80830b 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -1,14 +1,19 @@ import argparse from distutils.util import strtobool from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, 
def get_train_dataloader(self):
    """
    Inject custom data sampling behaviour into training loop
    and use custom task mixing collate function : train_collate_fn

    rewrite from:
    https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/trainer.py#L846

    Returns:
        A ``DataLoader`` over ``self.train_dataset`` that uses ``self.sampler``
        when one was given, otherwise the trainer's default train sampler.
    """
    # `train_collate_fn` defaults to None; fall back to the regular collator
    # instead of silently using the DataLoader's default collation.
    data_collator = self.train_collate_fn if self.train_collate_fn is not None else self.data_collator
    train_dataset = self.train_dataset
    if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
        train_dataset = self._remove_unused_columns(train_dataset, description="training")

    if isinstance(train_dataset, torch.utils.data.IterableDataset):
        # if we are using iterable dataset it means no weight sampling
        # added for backward compat
        if self.args.world_size > 1:
            train_dataset = IterableDatasetShard(
                train_dataset,
                batch_size=self._train_batch_size,
                drop_last=self.args.dataloader_drop_last,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )
        return DataLoader(
            train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    # a user supplied sampler (e.g. per-dataset weighting) takes precedence
    train_sampler = self._get_train_sampler() if self.sampler is None else self.sampler

    return DataLoader(
        train_dataset,
        batch_size=self._train_batch_size,
        sampler=train_sampler,
        collate_fn=data_collator,
        drop_last=self.args.dataloader_drop_last,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        worker_init_fn=seed_worker,
    )
OptimizerNames.ADAMW_HF @@ -222,6 +248,7 @@ if __name__ == "__main__": model=model, args=args, sampler=sampler, + train_collate_fn=train_collate_fn, loss_function=training_conf.loss_fn, poly_eps=training_conf.poly_eps, train_dataset=train, diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 1f77708f..43377bc9 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -6,7 +6,7 @@ import evaluate import transformers import yaml from custom_datasets import get_one_dataset -from custom_datasets.dialogue_collator import DialogueDataCollator +from custom_datasets.dialogue_collator import DialogueDataCollator, TrainDialogueDataCollator from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model @@ -126,7 +126,14 @@ def get_tokenizer(conf) -> transformers.AutoTokenizer: if tokenizer_config.special_tokens: if "GPT-JT" in conf.model_name: tokenizer_config.special_tokens.pad_token = tokenizer.eos_token - tokenizer.add_special_tokens(tokenizer_config.special_tokens) + # SpecialTokens : latest in 4.25, 4.26 + tokenizer.add_special_tokens( + { + "pad_token": tokenizer_config.special_tokens.pad_token, + "eos_token": tokenizer_config.special_tokens.eos_token, + "sep_token": tokenizer_config.special_tokens.sep_token, + } + ) additional_special_tokens = ( [] @@ -231,8 +238,8 @@ def get_dataset(conf, tokenizer): train = ConcatDataset(train_datasets) collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length) - - return train, evals, collate_fn + train_collate_fn = TrainDialogueDataCollator(tokenizer, max_length=conf.max_length) + return train, evals, collate_fn, train_collate_fn def get_loss(loss, poly_eps):