mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge pull request #1455 from LAION-AI/add-dataset
Add more datasets and some fixes
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
|
||||
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
@@ -298,3 +299,44 @@ class AnthropicRLHF(Dataset):
|
||||
context, pair = self.pairs[index]
|
||||
|
||||
return context, [pair]
|
||||
|
||||
|
||||
class OAPrivate(Dataset):
|
||||
"""
|
||||
{
|
||||
"prompt": <prompt string>,
|
||||
"history": [("prompt1", "answer2"), ("prompt2", "answer2")],
|
||||
"pos": <pos answer string>,
|
||||
"neg_replies": [list of bad answers]
|
||||
}
|
||||
"""
|
||||
|
||||
split_name_mapping = {
|
||||
"train": "rm_train.jsonl",
|
||||
"test": "rm_test.jsonl",
|
||||
"val": "rm_val.jsonl",
|
||||
}
|
||||
|
||||
def __init__(self, split="train", sep_token="<sep>", data_path=".cache") -> None:
|
||||
super().__init__()
|
||||
import json
|
||||
|
||||
jsonl_file = os.path.join(data_path, self.split_name_mapping[split])
|
||||
self.pairs = []
|
||||
with open(jsonl_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
prefix = sep_token.join([sep_token.join(p) for p in data["history"][-2:]])
|
||||
prefix += sep_token + data["prompt"]
|
||||
pair = []
|
||||
for neg_text in data["neg_replies"]:
|
||||
pair.append((data["pos"], neg_text))
|
||||
self.pairs.append((prefix, pair))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
context, pair = self.pairs[index]
|
||||
|
||||
return context, pair
|
||||
|
||||
@@ -115,7 +115,7 @@ def argument_parsing(parser):
|
||||
|
||||
|
||||
def get_datasets(dataset_list: List[AnyStr], tokenizer):
|
||||
from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, WebGPT
|
||||
from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, OAPrivate, WebGPT
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
train_datasets, evals = [], {}
|
||||
@@ -141,5 +141,11 @@ def get_datasets(dataset_list: List[AnyStr], tokenizer):
|
||||
eval = AnthropicRLHF("test", tokenizer.sep_token)
|
||||
train_datasets.append(train)
|
||||
evals["anthropic_rlhf"] = eval
|
||||
elif "oa_private" == dataset_name:
|
||||
train = OAPrivate(split="train", sep_token=tokenizer.sep_token)
|
||||
eval = OAPrivate(split="val", sep_token=tokenizer.sep_token)
|
||||
train_datasets.append(train)
|
||||
evals["oa_private"] = eval
|
||||
|
||||
train = ConcatDataset(train_datasets)
|
||||
return train, evals
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
High level functions for model training
|
||||
"""
|
||||
from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
|
||||
from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT
|
||||
from custom_datasets.summarization import SummarizationDataset
|
||||
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
|
||||
from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
|
||||
@@ -32,7 +32,7 @@ SUMMARIZATION_DATASETS = [
|
||||
"debate_sum",
|
||||
"tldr_news",
|
||||
]
|
||||
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"]
|
||||
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"]
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
|
||||
@@ -92,6 +92,12 @@ def get_one_dataset(conf, dataset_name):
|
||||
elif dataset_name == "instruct_tuning":
|
||||
dataset = InstructionTuning(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "private_tuning":
|
||||
dataset = PrivateInstructionTuning(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "oa_translated":
|
||||
dataset = TranslatedQA(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.01)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset {dataset_name}")
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -76,3 +77,108 @@ class DialogueDataCollator:
|
||||
batch["targets"] = torch.roll(batch["input_ids"], -1, -1)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainDialogueDataCollator:
|
||||
"""
|
||||
Expects a list of texts corresponding to a sequence of [question, answer, question, answer, ...] pairs.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
mix_length_threshold: Optional[int] = 256
|
||||
mix_probability: Optional[int] = 0.6
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features):
|
||||
flatten_messages = []
|
||||
label_masks = []
|
||||
total_short_context = 0
|
||||
for messages in features:
|
||||
messages = list(messages)
|
||||
|
||||
# Add a way for the model to terminate generation
|
||||
# When we predict the start of a new expected question, we want to be able to stop generation
|
||||
messages.append(self.tokenizer.eos_token)
|
||||
|
||||
flatten_message = self.tokenizer(
|
||||
"".join(messages),
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_offsets_mapping=True,
|
||||
)
|
||||
|
||||
message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
|
||||
# for each token an integer indicating the index of the message it belongs to. Just to create the label mask.
|
||||
# Label mask is true when predicting a token that is part of the answer, false otherwise.
|
||||
# TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John. Question:
|
||||
# MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3 -2
|
||||
# LABEL_MASK: 0 0 0 0 0 1 1 1 1 0 0 0 0 0 1 1 1 1 1 0
|
||||
|
||||
# If no result in next, we are predicting the last termination token(s)
|
||||
message_indices = list(
|
||||
map(
|
||||
lambda x: next((i for i, val in enumerate(message_change_indices) if val >= x), -2),
|
||||
list(map(lambda x: x[1], flatten_message["offset_mapping"])),
|
||||
)
|
||||
)
|
||||
label_mask = np.roll(list(map(lambda x: x % 2 == 1, message_indices)), -1, -1)
|
||||
try:
|
||||
label_mask[[i for i in range(len(message_indices)) if message_indices[i] == -2][0] - 1] = True
|
||||
except IndexError:
|
||||
# due to truncation, we might not have the last termination token
|
||||
label_mask[-1] = False
|
||||
|
||||
label_masks.append(label_mask)
|
||||
if len(flatten_message["input_ids"]) < self.mix_length_threshold:
|
||||
total_short_context += len(flatten_message["input_ids"])
|
||||
flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"})
|
||||
# packing
|
||||
if total_short_context > 2:
|
||||
_flatten_messages, _label_masks = [], []
|
||||
prev_short_msg, prev_short_mask = None, None
|
||||
for flatten_msg, label_mask in zip(flatten_messages, label_masks):
|
||||
if len(flatten_msg["input_ids"]) < self.mix_length_threshold and random.random() > 0.6:
|
||||
if prev_short_msg is not None:
|
||||
for key in flatten_msg.keys():
|
||||
flatten_msg[key] += prev_short_msg[key]
|
||||
flatten_msg[key] = flatten_msg[key][: self.max_length]
|
||||
label_mask = np.concatenate([label_mask, prev_short_mask])
|
||||
_label_masks.append(label_mask[: self.max_length])
|
||||
_flatten_messages.append(flatten_msg)
|
||||
# reset
|
||||
prev_short_msg, prev_short_mask = None, None
|
||||
else:
|
||||
# prime
|
||||
prev_short_msg, prev_short_mask = flatten_msg, label_mask
|
||||
else:
|
||||
_label_masks.append(label_mask)
|
||||
_flatten_messages.append(flatten_msg)
|
||||
if prev_short_msg is not None:
|
||||
for key in flatten_msg.keys():
|
||||
flatten_msg[key] += prev_short_msg[key]
|
||||
flatten_msg[key] = flatten_msg[key][: self.max_length]
|
||||
label_mask = np.concatenate([label_mask, prev_short_mask])[: self.max_length]
|
||||
_label_masks.append(label_mask)
|
||||
_flatten_messages.append(flatten_msg)
|
||||
|
||||
label_masks = _label_masks
|
||||
flatten_messages = _flatten_messages
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flatten_messages,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
dim = batch["input_ids"].shape[-1]
|
||||
|
||||
batch["label_masks"] = torch.stack(
|
||||
[F.pad(torch.tensor(x), (0, dim - len(x)), value=False) for x in label_masks]
|
||||
)
|
||||
batch["targets"] = torch.roll(batch["input_ids"], -1, -1)
|
||||
|
||||
return batch
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
from urllib.request import urlopen
|
||||
|
||||
from custom_datasets.formatting import format_pair
|
||||
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
@@ -102,3 +102,45 @@ class InstructionTuning(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
class PrivateInstructionTuning(Dataset):
|
||||
"""
|
||||
We have seen some promising capabilities from instruction tuning
|
||||
with the following mix of datasets that are derived from datasets
|
||||
available online.
|
||||
The files for this data are in json format as a list of tuples
|
||||
where each tuple is (source,instruction_response_pair)
|
||||
|
||||
Not to be confused with unatural instruction
|
||||
"""
|
||||
|
||||
name = "private_tuning"
|
||||
filename = "oa_v3_fixed_plus_safety.jsonl"
|
||||
|
||||
def __init__(self, cache_dir) -> None:
|
||||
super().__init__()
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
self.pairs = []
|
||||
for file_link in [self.filename]:
|
||||
basename = file_link.split("/")[-1]
|
||||
instruction_tune_file = os.path.join(cache_dir, basename)
|
||||
|
||||
with open(instruction_tune_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
row = json.loads(line)
|
||||
prefix = ""
|
||||
for _, convo in enumerate(row["text"].split("User:")):
|
||||
if "Assistant" in convo:
|
||||
prompt, answer = convo.split("Assistant:", maxsplit=1)
|
||||
answer = answer.replace("<|endoftext|>", "").strip()
|
||||
self.pairs.append((prefix + QA_SPECIAL_TOKENS["Question"] + prompt, answer))
|
||||
prefix += "".join(format_pair((prompt, answer)))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
prompt, answer = self.pairs[index]
|
||||
return "{}{}".format(prompt, QA_SPECIAL_TOKENS["Answer"]), answer
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Open / close book QA datasets
|
||||
"""
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -137,11 +138,13 @@ class QADataset(Dataset):
|
||||
self.dataset = load_dataset(context["name"], **context["params"])
|
||||
else:
|
||||
raise ValueError("Unknown dataset : " + dataset)
|
||||
self.length = len(self.dataset)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
||||
data = self.dataset[idx]
|
||||
return format_pair(self.index_fn(data))
|
||||
|
||||
@@ -311,14 +314,75 @@ class JokeExplaination(Dataset):
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
joke = data["joke"]
|
||||
explanation = data["explanation"]
|
||||
# DO NOT change this
|
||||
# its the data that had syntax error
|
||||
explanation = data["explaination"]
|
||||
self.pairs.append((joke, explanation))
|
||||
|
||||
if len(question) > 0 and len(answer) > 0:
|
||||
self.pairs.append((question, answer))
|
||||
self.length = len(self.pairs)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, index):
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
class TranslatedQA(Dataset):
|
||||
"""
|
||||
Translation OA v3 results
|
||||
a list of non english translation of OA v3 instruction generated text in jsonl
|
||||
format for each line:
|
||||
{
|
||||
"text": "User: ... Assistant: ....",
|
||||
"meta": {"source": ... },
|
||||
"translate": [
|
||||
{ "round": 1, "human":"...", "answer": "..."},
|
||||
...
|
||||
{ "round": K, "human":"...", "answer": "..."},
|
||||
]
|
||||
}
|
||||
Since OA contain some code we needed to reference the original text to skip these
|
||||
"""
|
||||
|
||||
name = "oa_translated"
|
||||
|
||||
def __init__(self, cache_dir) -> None:
|
||||
super().__init__()
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
path = os.path.join(cache_dir, self.name)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
self.pairs = []
|
||||
for translated_jsonl in glob.glob(os.path.join(path, "*.jsonl")):
|
||||
with open(translated_jsonl, "r") as fin:
|
||||
for line in fin:
|
||||
data = json.loads(line)
|
||||
if "Python " in data["text"]:
|
||||
# translation currently doesn't ignore code
|
||||
# so we will have to reference original text
|
||||
# for ignoring the translation
|
||||
continue
|
||||
prefix = ""
|
||||
for convo_round in data["translate"]:
|
||||
human, answer = format_pair((convo_round["human"], convo_round["answer"]))
|
||||
if convo_round["round"] > 2:
|
||||
self.pairs.append(("{}{}{}".format(prefix, "<sep>", human), answer))
|
||||
else:
|
||||
self.pairs.append((human, answer))
|
||||
|
||||
prefix += "{}{}{}{}".format(
|
||||
QA_SPECIAL_TOKENS["Question"],
|
||||
convo_round["human"],
|
||||
QA_SPECIAL_TOKENS["Answer"],
|
||||
convo_round["answer"],
|
||||
)
|
||||
|
||||
self.length = len(self.pairs)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.pairs[index]
|
||||
|
||||
@@ -57,7 +57,7 @@ def index_summary_merge(text, summary):
|
||||
class SummarizationDataset(Dataset):
|
||||
def __init__(self, dataset, cache_dir, split, max_words=512):
|
||||
self.name = dataset
|
||||
if summarization_config_mapping[dataset][0] in ["billsum", "tldr_news"] & split == "validation":
|
||||
if (dataset in ["billsum", "tldr_news"]) and (split == "validation"):
|
||||
split = "test"
|
||||
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
|
||||
self.text_column, self.summary_column = summarization_name_mapping[dataset]
|
||||
|
||||
@@ -62,7 +62,8 @@ TRANSLATION_PROMPT = {
|
||||
"{} how do we write in Malay",
|
||||
"{} give me the malay translation",
|
||||
"{} , berikan saya terjemahan dalam bahasa melayu",
|
||||
"{}, Jemahan di bahasa melayu" "{}, jemahkan ayat ini kepada bahasa melayu",
|
||||
"{}, Jemahan di bahasa melayu",
|
||||
"{}, jemahkan ayat ini kepada bahasa melayu",
|
||||
],
|
||||
"en": ["{}. translate to english", "{} write in english", "english translation: '{}'"],
|
||||
"ru": ["помогите мне перевести это на русский : {}", "{} перевести на русский язык", "russian translation: '{}'"],
|
||||
@@ -71,24 +72,40 @@ TRANSLATION_PROMPT = {
|
||||
"nl": ["{}. translate to dutch", "{} write in dutch", "dutch translation: '{}'"],
|
||||
"vi": ["{}. Dịch sang tiếng việt nam", "{} write in vietnamese", "vietnamese translation: '{}'"],
|
||||
"ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"],
|
||||
"es": ["{}. translate to spanish", "{} write in spanish", "spanish translation: '{}'"],
|
||||
"hi": ["{}. translate to hindi", "{}. translate to bengali", "{} write in hindi", "bengali translation: '{}'"],
|
||||
}
|
||||
|
||||
|
||||
class TranslationPair(Dataset):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, mix_prob=0.2) -> None:
|
||||
super().__init__()
|
||||
self.pairs = []
|
||||
self.length = -1
|
||||
self.mix_prob = mix_prob
|
||||
|
||||
def __len__(self):
|
||||
if self.length < 0:
|
||||
self.length = len(self.pairs)
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
|
||||
additional = random.randint(0, 10) - 5
|
||||
while additional == index:
|
||||
additional = random.randint(0, 10) - 5
|
||||
|
||||
history_text = "".join(format_pair(self.pairs[additional + index]))
|
||||
question, answer = self.pairs[index]
|
||||
question = history_text + question
|
||||
return format_pair((question, answer))
|
||||
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
class WMT2019(TranslationPair):
|
||||
def __init__(self, pair="zh-en", split="train") -> None:
|
||||
super().__init__()
|
||||
def __init__(self, pair="zh-en", split="train", mix_prob=0.2) -> None:
|
||||
super().__init__(mix_prob=mix_prob)
|
||||
dataset = load_dataset("wmt19", pair)[split]
|
||||
self.pairs = []
|
||||
src, tgt = pair.split("-")
|
||||
@@ -106,8 +123,8 @@ class DiveMT(TranslationPair):
|
||||
|
||||
REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"}
|
||||
|
||||
def __init__(self, split="train") -> None:
|
||||
super().__init__()
|
||||
def __init__(self, split="train", mix_prob=0.2) -> None:
|
||||
super().__init__(mix_prob=mix_prob)
|
||||
dataset = load_dataset("GroNLP/divemt", "main")[split]
|
||||
tgt, src = "tgt_text", "src_text"
|
||||
for row in dataset:
|
||||
@@ -129,8 +146,8 @@ class DiveMT(TranslationPair):
|
||||
class TEDTalk(TranslationPair):
|
||||
# NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean
|
||||
|
||||
def __init__(self, pair="de-ja", split="train", year="2016") -> None:
|
||||
super().__init__()
|
||||
def __init__(self, pair="de-ja", split="train", year="2016", mix_prob=0.2) -> None:
|
||||
super().__init__(mix_prob=mix_prob)
|
||||
dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split]
|
||||
src, tgt = pair.split("-")
|
||||
for row in dataset:
|
||||
|
||||
@@ -27,7 +27,7 @@ def test_collate_fn():
|
||||
|
||||
config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi")
|
||||
tokenizer = get_tokenizer(config)
|
||||
collate_fn = DialogueDataCollator(tokenizer, max_length=512)
|
||||
collate_fn = DialogueDataCollator(tokenizer, max_length=620)
|
||||
qa_base = QA_DATASETS
|
||||
summarize_base = SUMMARIZATION_DATASETS
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"]
|
||||
@@ -40,11 +40,14 @@ def test_collate_fn():
|
||||
|
||||
dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128)
|
||||
for batch in dataloader:
|
||||
print(batch.keys())
|
||||
print(batch["targets"].shape[0])
|
||||
print(tokenizer.decode(batch["input_ids"][0]))
|
||||
print("-----")
|
||||
print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]]))
|
||||
assert batch["targets"].shape[1] <= 512
|
||||
assert batch["targets"].shape[1] <= 620
|
||||
dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128)
|
||||
for batch in dataloader:
|
||||
assert batch["targets"].shape[1] <= 512
|
||||
assert batch["targets"].shape[1] <= 620
|
||||
|
||||
|
||||
test_collate_fn()
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
import argparse
|
||||
from distutils.util import strtobool
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import bitsandbytes
|
||||
import datasets
|
||||
import torch
|
||||
from efficiency_utils import fuse_gelu
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import PreTrainedModel, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import IterableDatasetShard
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import is_datasets_available
|
||||
from utils import PerDatasetSampler, get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
|
||||
|
||||
|
||||
@@ -34,10 +39,11 @@ class SFTTrainer(Trainer):
|
||||
sampler: torch.utils.data.sampler.Sampler = None,
|
||||
loss_function: str = "CrossEntropyLoss",
|
||||
poly_eps: float = 1.0,
|
||||
train_collate_fn: Callable = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, args, **kwargs)
|
||||
|
||||
self.train_collate_fn = train_collate_fn
|
||||
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
|
||||
self.loss_fct = get_loss(loss_function, poly_eps)
|
||||
self.sampler = sampler
|
||||
@@ -92,30 +98,51 @@ class SFTTrainer(Trainer):
|
||||
return (loss, logits, labels)
|
||||
|
||||
def get_train_dataloader(self):
|
||||
"""Inject custom data sampling behaviour into training loop"""
|
||||
if self.sampler is None:
|
||||
torch.utils.data.DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.args.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=self.data_collator,
|
||||
)
|
||||
else:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.args.per_device_train_batch_size,
|
||||
sampler=self.sampler,
|
||||
collate_fn=self.data_collator,
|
||||
)
|
||||
if torch.cuda.device_count() <= 1:
|
||||
return dataloader
|
||||
else:
|
||||
# Not strictly necessary to use accelerate, currently just
|
||||
# ensures batches are padded to be divisible by # devices
|
||||
from accelerate import Accelerator
|
||||
"""
|
||||
Inject custom data sampling behaviour into training loop
|
||||
and use custom task mixing collate function : train_collate_fn
|
||||
|
||||
accelerator = Accelerator()
|
||||
return accelerator.prepare(dataloader)
|
||||
rewrite from:
|
||||
https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/trainer.py#L846
|
||||
"""
|
||||
data_collator = self.train_collate_fn
|
||||
train_dataset = self.train_dataset
|
||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
||||
|
||||
if isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||
# if we are using iterable dataset it means no weight sampling
|
||||
# added for backward compat
|
||||
if self.args.world_size > 1:
|
||||
train_dataset = IterableDatasetShard(
|
||||
train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_processes=self.args.world_size,
|
||||
process_index=self.args.process_index,
|
||||
)
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
batch_size=self.args.per_device_train_batch_size,
|
||||
collate_fn=data_collator,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
)
|
||||
if self.sampler is None:
|
||||
train_sampler = self._get_train_sampler()
|
||||
else:
|
||||
train_sampler = self.sampler
|
||||
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
sampler=train_sampler,
|
||||
collate_fn=data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
worker_init_fn=seed_worker,
|
||||
)
|
||||
|
||||
|
||||
def _strtobool(x):
|
||||
@@ -168,8 +195,7 @@ if __name__ == "__main__":
|
||||
|
||||
tokenizer = get_tokenizer(training_conf)
|
||||
model = get_model(training_conf, tokenizer)
|
||||
|
||||
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
|
||||
train, evals, collate_fn, train_collate_fn = get_dataset(training_conf, tokenizer)
|
||||
sampler = PerDatasetSampler.build_sampler_from_config(training_conf, train.datasets)
|
||||
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
|
||||
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
|
||||
@@ -222,6 +248,7 @@ if __name__ == "__main__":
|
||||
model=model,
|
||||
args=args,
|
||||
sampler=sampler,
|
||||
train_collate_fn=train_collate_fn,
|
||||
loss_function=training_conf.loss_fn,
|
||||
poly_eps=training_conf.poly_eps,
|
||||
train_dataset=train,
|
||||
|
||||
@@ -6,7 +6,7 @@ import evaluate
|
||||
import transformers
|
||||
import yaml
|
||||
from custom_datasets import get_one_dataset
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator, TrainDialogueDataCollator
|
||||
from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS
|
||||
from losses import CrossEntropyLoss, PolyLoss
|
||||
from models import freeze_top_n_layers, get_specific_model
|
||||
@@ -126,7 +126,14 @@ def get_tokenizer(conf) -> transformers.AutoTokenizer:
|
||||
if tokenizer_config.special_tokens:
|
||||
if "GPT-JT" in conf.model_name:
|
||||
tokenizer_config.special_tokens.pad_token = tokenizer.eos_token
|
||||
tokenizer.add_special_tokens(tokenizer_config.special_tokens)
|
||||
# SpecialTokens : latest in 4.25, 4.26
|
||||
tokenizer.add_special_tokens(
|
||||
{
|
||||
"pad_token": tokenizer_config.special_tokens.pad_token,
|
||||
"eos_token": tokenizer_config.special_tokens.eos_token,
|
||||
"sep_token": tokenizer_config.special_tokens.sep_token,
|
||||
}
|
||||
)
|
||||
|
||||
additional_special_tokens = (
|
||||
[]
|
||||
@@ -231,8 +238,8 @@ def get_dataset(conf, tokenizer):
|
||||
train = ConcatDataset(train_datasets)
|
||||
|
||||
collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length)
|
||||
|
||||
return train, evals, collate_fn
|
||||
train_collate_fn = TrainDialogueDataCollator(tokenizer, max_length=conf.max_length)
|
||||
return train, evals, collate_fn, train_collate_fn
|
||||
|
||||
|
||||
def get_loss(loss, poly_eps):
|
||||
|
||||
Reference in New Issue
Block a user