mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge pull request #699 from LAION-AI/sft-fixes
Fix supervised pretraining bugs and add new datasets
This commit is contained in:
@@ -58,6 +58,8 @@ the end to trigger deepspeed
|
||||
python trainer.py --configs defaults your-model-name --deepspeed
|
||||
```
|
||||
|
||||
## Dataset choices
|
||||
|
||||
## Results
|
||||
|
||||
Experimental results in wandb
|
||||
|
||||
@@ -6,7 +6,7 @@ defaults:
|
||||
per_device_eval_batch_size: 2
|
||||
weight_decay: 0.00
|
||||
warmup_steps: 600
|
||||
eval_steps: 100
|
||||
eval_steps: 500
|
||||
save_steps: 500
|
||||
max_length: 512
|
||||
num_train_epochs: 3
|
||||
@@ -18,7 +18,19 @@ defaults:
|
||||
datasets:
|
||||
- webgpt
|
||||
- prompt_dialogue
|
||||
cache_dir: ~/.cache
|
||||
- squad_v2
|
||||
- adversarial_qa
|
||||
- trivia_qa_nocontext
|
||||
- xsum
|
||||
- cnn_dailymail
|
||||
- prompt_dialogue
|
||||
- multi_news
|
||||
- scitldr
|
||||
- soda
|
||||
- joke
|
||||
- gsm8k
|
||||
- samsum
|
||||
cache_dir: .cache
|
||||
loss_fn: CrossEntropyLoss
|
||||
eval_size:
|
||||
log_dir: "base"
|
||||
@@ -48,14 +60,14 @@ gpt-jt:
|
||||
per_device_eval_batch_size: 4
|
||||
|
||||
codegen:
|
||||
learning_rate: 2e-6
|
||||
learning_rate: 8e-6
|
||||
model_name: Salesforce/codegen-2B-multi
|
||||
weight_decay: 0.01
|
||||
max_length: 812
|
||||
warmup_steps: 600
|
||||
max_length: 520
|
||||
warmup_steps: 1000
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 5
|
||||
per_device_train_batch_size: 4
|
||||
gradient_accumulation_steps: 9
|
||||
per_device_train_batch_size: 2
|
||||
per_device_eval_batch_size: 4
|
||||
|
||||
debug:
|
||||
|
||||
@@ -1,136 +1,11 @@
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from custom_datasets.prompt_dialogue import PromptGeneratedDataset
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT
|
||||
from custom_datasets.summarization import SummarizationDataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Dataset, Subset
|
||||
from torch.utils.data import Subset
|
||||
|
||||
from .prompt_dialogue import PromptGeneratedDataset
|
||||
|
||||
QA_SPECIAL_TOKENS = {"Question": "<question>", "Answer": "<answer>"}
|
||||
SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": "TL;DR:"}
|
||||
|
||||
summarization_name_mapping = {
|
||||
"cnn_dailymail": ("article", "highlights"),
|
||||
"samsum": ("dialogue", "summary"),
|
||||
"xsum": ("document", "summary"),
|
||||
"multi_news": ("document", "summary"),
|
||||
"scitldr": ("source", "target"),
|
||||
"billsum": ("text", "summary"),
|
||||
"reddit": ("content", "summary"),
|
||||
}
|
||||
summarization_config_mapping = {
|
||||
"cnn_dailymail": ("3.0.0",),
|
||||
"samsum": (),
|
||||
"xsum": (),
|
||||
"multi_news": (),
|
||||
"scitldr": ("AIC",),
|
||||
"billsum": (),
|
||||
"reddit": (),
|
||||
}
|
||||
|
||||
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]
|
||||
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"]
|
||||
|
||||
|
||||
def index_squad_v2(example):
|
||||
return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0]
|
||||
|
||||
|
||||
def index_trivia_qa_nocontext(example):
|
||||
# dummy return one randomly
|
||||
return example["question"], example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))]
|
||||
|
||||
|
||||
def index_trivia_qa_context(example):
|
||||
question = example["question"]
|
||||
title = example["title"][np.random.randint(len(example["title"]))]
|
||||
context = example["search_context"][np.random.randint(len(example["search_context"]))]
|
||||
answer = example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))]
|
||||
|
||||
return title + ". " + context + " " + question, answer
|
||||
|
||||
|
||||
def index_adversarial_qa(example):
|
||||
return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0]
|
||||
|
||||
|
||||
class QADataset(Dataset):
|
||||
def __init__(self, dataset, cache_dir, split):
|
||||
if dataset == "squad_v2":
|
||||
self.index_fn = index_squad_v2
|
||||
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
|
||||
elif dataset == "trivia_qa_nocontext":
|
||||
self.index_fn = index_trivia_qa_nocontext
|
||||
self.dataset = load_dataset("trivia_qa", "rc.nocontext")
|
||||
elif dataset == "trivia_qa_context":
|
||||
self.index_fn = index_trivia_qa_context
|
||||
self.dataset = load_dataset("trivia_qa", "rc")
|
||||
elif dataset == "adversarial_qa":
|
||||
self.index_fn = index_adversarial_qa
|
||||
self.dataset = load_dataset("adversarial_qa", "adversarialQA")
|
||||
else:
|
||||
raise ValueError("Unknown dataset : " + dataset)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
data = self.dataset[idx]
|
||||
return self.index_fn(data)
|
||||
|
||||
|
||||
def index_summary_default(text, summary):
|
||||
return text, summary
|
||||
|
||||
|
||||
def index_summary_merge(text, summary):
|
||||
return " ".join(text), " ".join(summary)
|
||||
|
||||
|
||||
class SummarizationDataset(Dataset):
|
||||
def __init__(self, dataset, cache_dir, split):
|
||||
self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
|
||||
self.summary_column, self.text_column = summarization_name_mapping[dataset]
|
||||
self.preprocess_fn = index_summary_merge if dataset == "scitdlr" else index_summary_merge
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
data = self.dataset[idx]
|
||||
text, summary = data[self.text_column], data[self.summary_column]
|
||||
text, summary = self.preprocess_fn(text, summary)
|
||||
|
||||
return "".join(
|
||||
SUMMARIZATION_SPECIAL_TOKENS["Text"], text, " ", SUMMARIZATION_SPECIAL_TOKENS["Summary"], summary
|
||||
)
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
dataset = load_dataset("openai/webgpt_comparisons")
|
||||
questions = {}
|
||||
# using prompt as our index will allows us
|
||||
# to add additional generated prompt later
|
||||
self.index2question = {}
|
||||
for row in dataset["train"]:
|
||||
question = row["question"]["full_text"]
|
||||
if question not in self.index2question:
|
||||
self.index2question[len(self.index2question)] = question
|
||||
|
||||
# only keep the best answer
|
||||
questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
|
||||
|
||||
self.questions = questions
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index2question)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question = self.index2question[index]
|
||||
answer = self.questions[question]
|
||||
return [question, answer]
|
||||
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"]
|
||||
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum"]
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
|
||||
@@ -143,19 +18,26 @@ def train_val_dataset(dataset, val_split=0.2):
|
||||
def get_one_dataset(conf, dataset_name):
|
||||
dataset_name = dataset_name.lower()
|
||||
|
||||
if dataset_name in ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]:
|
||||
if dataset_name in QA_DATASETS:
|
||||
train = QADataset(dataset_name, conf.cache_dir, "train")
|
||||
eval = QADataset(dataset_name, conf.cache_dir, "validation")
|
||||
val_name = "validation" if dataset_name not in ["gsm8k"] else "test"
|
||||
eval = QADataset(dataset_name, conf.cache_dir, val_name)
|
||||
|
||||
elif dataset_name in ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"]:
|
||||
elif dataset_name in SUMMARIZATION_DATASETS:
|
||||
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation")
|
||||
|
||||
val_name = "validation" if dataset_name not in ["billsum"] else "test"
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
|
||||
elif dataset_name == "webgpt":
|
||||
dataset = WebGPT()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "prompt_dialogue":
|
||||
dataset = PromptGeneratedDataset()
|
||||
dataset = PromptGeneratedDataset(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "soda":
|
||||
dataset = SODA(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.1)
|
||||
elif dataset_name == "joke":
|
||||
dataset = JokeExplaination(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset {dataset_name}")
|
||||
|
||||
@@ -3,11 +3,10 @@ from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS
|
||||
from torch.nn import functional as F
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
from . import QA_SPECIAL_TOKENS
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialogueDataCollator:
|
||||
@@ -35,7 +34,7 @@ class DialogueDataCollator:
|
||||
|
||||
# Add a way for the model to terminate generation
|
||||
# When we predict the start of a new expected question, we want to be able to stop generation
|
||||
messages.append(QA_SPECIAL_TOKENS["Question"])
|
||||
messages.append(self.tokenizer.eos_token)
|
||||
|
||||
flatten_message = self.tokenizer(
|
||||
"".join(messages),
|
||||
|
||||
@@ -16,10 +16,10 @@ class PromptGeneratedDataset(Dataset):
|
||||
|
||||
url = "https://github.com/Rallio67/language-model-agents/raw/main/chat_dialogue_v2_c.txt"
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, cache_dir) -> None:
|
||||
super().__init__()
|
||||
os.makedirs("datasets", exist_ok=True)
|
||||
chat_dialogue = os.path.join("datasets", "chat_dialogue_v2_c.txt")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
chat_dialogue = os.path.join(cache_dir, "chat_dialogue_v2_c.txt")
|
||||
if not os.path.exists(chat_dialogue):
|
||||
with urlopen(self.url) as file:
|
||||
content = file.read().decode()
|
||||
@@ -49,18 +49,3 @@ class PromptGeneratedDataset(Dataset):
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from .dialogue_collator import DialogueDataCollator
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-multi")
|
||||
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"})
|
||||
dataset = PromptGeneratedDataset()
|
||||
collate_fn = DialogueDataCollator(tokenizer, padding=True, max_length=128)
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=5)
|
||||
for batch in dataloader:
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
import json
|
||||
import os
|
||||
from urllib.request import urlopen
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
|
||||
|
||||
|
||||
def index_squad_v2(example):
|
||||
if len(example["answers"]["text"]):
|
||||
answer = example["answers"]["text"][0]
|
||||
else:
|
||||
answer = "I do not have answer for that"
|
||||
return example["context"] + " " + example["question"], answer
|
||||
|
||||
|
||||
def index_trivia_qa_nocontext(example):
|
||||
# dummy return one randomly
|
||||
return example["question"], example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))]
|
||||
|
||||
|
||||
def index_trivia_qa_context(example):
|
||||
question = example["question"]
|
||||
if len(example["search_results"]["search_context"]):
|
||||
context = example["search_results"]["search_context"][
|
||||
np.random.randint(len(example["search_results"]["search_context"]))
|
||||
]
|
||||
else:
|
||||
context = ""
|
||||
answer = example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))]
|
||||
|
||||
return context + " " + question, answer
|
||||
|
||||
|
||||
def index_adversarial_qa(example):
|
||||
return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0]
|
||||
|
||||
|
||||
def index_gsm8k(example):
|
||||
return example["question"], example["answer"]
|
||||
|
||||
|
||||
class QADataset(Dataset):
|
||||
def __init__(self, dataset, cache_dir, split):
|
||||
if dataset == "squad_v2":
|
||||
self.index_fn = index_squad_v2
|
||||
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
|
||||
elif dataset == "trivia_qa_nocontext":
|
||||
self.index_fn = index_trivia_qa_nocontext
|
||||
self.dataset = load_dataset("trivia_qa", "rc.nocontext", split=split, cache_dir=cache_dir)
|
||||
elif dataset == "trivia_qa_context":
|
||||
self.index_fn = index_trivia_qa_context
|
||||
self.dataset = load_dataset("trivia_qa", "rc", split=split, cache_dir=cache_dir)
|
||||
elif dataset == "adversarial_qa":
|
||||
self.index_fn = index_adversarial_qa
|
||||
self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir)
|
||||
elif dataset == "gsm8k":
|
||||
self.index_fn = index_gsm8k
|
||||
self.dataset = load_dataset("gsm8k", "main", split=split, cache_dir=cache_dir)
|
||||
elif dataset == "adversarial_qa":
|
||||
self.index_fn = index_adversarial_qa
|
||||
self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir)
|
||||
else:
|
||||
raise ValueError("Unknown dataset : " + dataset)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
data = self.dataset[idx]
|
||||
return self.index_fn(data)
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
dataset = load_dataset("openai/webgpt_comparisons")
|
||||
questions = {}
|
||||
# using prompt as our index will allows us
|
||||
# to add additional generated prompt later
|
||||
self.index2question = {}
|
||||
for row in dataset["train"]:
|
||||
question = row["question"]["full_text"]
|
||||
if question not in self.index2question:
|
||||
self.index2question[len(self.index2question)] = question
|
||||
|
||||
# only keep the best answer
|
||||
questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
|
||||
|
||||
self.questions = questions
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index2question)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question = self.index2question[index]
|
||||
answer = self.questions[question]
|
||||
return [question, answer]
|
||||
|
||||
|
||||
class SODA(Dataset):
|
||||
def process_soda_convo(self, data):
|
||||
pairs = []
|
||||
play_as = data["speakers"][1]
|
||||
prefix = "<prefix>{}. {}</prefix>".format(data["narrative"], "your name {}".format(play_as))
|
||||
question, answer = "", ""
|
||||
prefix, postfix = "", ""
|
||||
previous_chat = []
|
||||
|
||||
for idx, convo in enumerate(data["dialogue"]):
|
||||
if idx % 2 == 0:
|
||||
question = convo
|
||||
prefix = data["speakers"][idx]
|
||||
else:
|
||||
answer = convo
|
||||
postfix = data["speakers"][idx]
|
||||
if len(question) and len(answer) and prefix != postfix and postfix == play_as:
|
||||
history = "<sep>".join(["{}<bot>{}".format(*p) for p in previous_chat])
|
||||
if len(history):
|
||||
history += "<sep>"
|
||||
pairs.append((prefix + history + question, answer))
|
||||
previous_chat.append((question, answer))
|
||||
return pairs
|
||||
|
||||
def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pairs = []
|
||||
dataset = load_dataset("allenai/soda", cache_dir=cache_dir)["train"]
|
||||
for data in dataset:
|
||||
data_pair = self.process_soda_convo(data)
|
||||
for (prompt, answer) in data_pair:
|
||||
if len(prompt) < input_max_length:
|
||||
self.pairs.append((prompt, answer))
|
||||
|
||||
if len(self.pairs) > max_sample_size:
|
||||
break
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
|
||||
|
||||
class JokeExplaination(Dataset):
|
||||
""" """
|
||||
|
||||
url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl"
|
||||
|
||||
def __init__(self, cache_dir) -> None:
|
||||
super().__init__()
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
joke_explain_filename = os.path.join(cache_dir, "joke_explaination.jsonl")
|
||||
if not os.path.exists(joke_explain_filename):
|
||||
with urlopen(self.url) as file:
|
||||
content = file.read().decode()
|
||||
with open(joke_explain_filename, "w") as fout:
|
||||
fout.write(content)
|
||||
|
||||
question = ""
|
||||
answer = ""
|
||||
self.pairs = []
|
||||
with open(joke_explain_filename, "r") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
joke = data["joke"]
|
||||
explanation = data["explaination"]
|
||||
self.pairs.append((joke, explanation))
|
||||
|
||||
if len(question) > 0 and len(answer) > 0:
|
||||
self.pairs.append((question, answer))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
@@ -0,0 +1,62 @@
|
||||
import random
|
||||
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": ["TL;DR:", "Summarize this", "Give me the summary"]}
|
||||
|
||||
SUMMARY_SPECIAL_PROMPT = {
|
||||
"multi_news": ["Summarize in bullet points", "Generate summary in list of points"],
|
||||
"xsum": ["Give me summary in one sentence", "Short TLDR", "Give me a concise summary"],
|
||||
"samsum": ["TLDR;", "Summarize this dialogue", "Summarize dialogue"],
|
||||
}
|
||||
|
||||
summarization_config_mapping = {
|
||||
"cnn_dailymail": ("3.0.0",),
|
||||
"samsum": (),
|
||||
"xsum": (),
|
||||
"multi_news": (),
|
||||
"scitldr": ("AIC",),
|
||||
"billsum": (),
|
||||
"reddit": (),
|
||||
}
|
||||
|
||||
summarization_name_mapping = {
|
||||
"cnn_dailymail": ("article", "highlights"),
|
||||
"samsum": ("dialogue", "summary"),
|
||||
"xsum": ("document", "summary"),
|
||||
"multi_news": ("document", "summary"),
|
||||
"scitldr": ("source", "target"),
|
||||
"billsum": ("text", "summary"),
|
||||
"reddit": ("content", "summary"),
|
||||
}
|
||||
|
||||
|
||||
def index_summary_default(text, summary):
|
||||
return text.replace("\n\n", "\n"), summary
|
||||
|
||||
|
||||
def index_summary_merge(text, summary):
|
||||
return " ".join(text), " ".join(summary)
|
||||
|
||||
|
||||
class SummarizationDataset(Dataset):
|
||||
def __init__(self, dataset, cache_dir, split):
|
||||
self.name = dataset
|
||||
self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
|
||||
self.text_column, self.summary_column = summarization_name_mapping[dataset]
|
||||
self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
data = self.dataset[idx]
|
||||
text, summary = data[self.text_column], data[self.summary_column]
|
||||
text, summary = self.preprocess_fn(text, summary)
|
||||
if self.name in SUMMARY_SPECIAL_PROMPT:
|
||||
prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"])
|
||||
else:
|
||||
prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"])
|
||||
|
||||
return ("".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[:256]), prompt]), summary)
|
||||
@@ -0,0 +1,54 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from custom_datasets import QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
|
||||
|
||||
def test_all_datasets():
|
||||
qa_base = QA_DATASETS
|
||||
summarize_base = SUMMARIZATION_DATASETS
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke"]
|
||||
|
||||
config = Namespace(cache_dir=".cache")
|
||||
for dataset_name in others + qa_base + summarize_base:
|
||||
print(dataset_name)
|
||||
train, eval = get_one_dataset(config, dataset_name)
|
||||
# sanity check
|
||||
for idx in range(min(len(train), 1000)):
|
||||
train[idx]
|
||||
for idx in range(min(len(eval), 1000)):
|
||||
eval[idx]
|
||||
|
||||
|
||||
def test_collate_fn():
|
||||
from torch.utils.data import ConcatDataset, DataLoader
|
||||
from utils import get_tokenizer
|
||||
|
||||
config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi")
|
||||
tokenizer = get_tokenizer(config)
|
||||
collate_fn = DialogueDataCollator(tokenizer, max_length=512)
|
||||
qa_base = QA_DATASETS
|
||||
summarize_base = SUMMARIZATION_DATASETS
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"]
|
||||
|
||||
trains, evals = [], []
|
||||
for dataset_name in others + qa_base + summarize_base:
|
||||
print(dataset_name)
|
||||
train, eval = get_one_dataset(config, dataset_name)
|
||||
trains.append(train)
|
||||
evals.append(eval)
|
||||
|
||||
dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128)
|
||||
for batch in dataloader:
|
||||
# print(batch.keys())
|
||||
# print(tokenizer.decode(batch['input_ids'][0]))
|
||||
# print('-----')
|
||||
# print(tokenizer.decode(batch['targets'][0][batch['label_masks'][0]]))
|
||||
assert batch["targets"].shape[1] <= 512
|
||||
dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128)
|
||||
for batch in dataloader:
|
||||
assert batch["targets"].shape[1] <= 512
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_collate_fn()
|
||||
@@ -0,0 +1,9 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from utils import get_tokenizer
|
||||
|
||||
|
||||
def test_tokenizer():
|
||||
get_tokenizer(Namespace(model_name="Salesforce/codegen-2B-multi", cache_dir=".cache"))
|
||||
get_tokenizer(Namespace(model_name="facebook/galactica-1.3b", cache_dir=".cache"))
|
||||
get_tokenizer(Namespace(model_name="", cache_dir=".cache"))
|
||||
@@ -1,13 +1,15 @@
|
||||
from functools import partial
|
||||
# from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import evaluate
|
||||
import nltk
|
||||
import numpy as np
|
||||
|
||||
# import nltk
|
||||
# import numpy as np
|
||||
import transformers
|
||||
import yaml
|
||||
from custom_datasets import QA_DATASETS, QA_SPECIAL_TOKENS, SUMMARIZATION_DATASETS, get_one_dataset
|
||||
from custom_datasets import get_one_dataset
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS
|
||||
from losses import CrossEntropyLoss, PolyLoss
|
||||
from models import freeze_top_n_layers, get_specific_model
|
||||
from sklearn.model_selection import train_test_split
|
||||
@@ -51,25 +53,25 @@ def preprocess_qa(eval_pred):
|
||||
return (eval_pred.predictions, eval_pred.label_ids)
|
||||
|
||||
|
||||
def postprocess_summarization(preds, labels):
|
||||
preds = [pred.strip() for pred in preds]
|
||||
labels = [label.strip() for label in labels]
|
||||
# def postprocess_summarization(preds, labels):
|
||||
# preds = [pred.strip() for pred in preds]
|
||||
# labels = [label.strip() for label in labels]
|
||||
|
||||
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
|
||||
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
|
||||
# preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
|
||||
# labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
|
||||
|
||||
return preds, labels
|
||||
# return preds, labels
|
||||
|
||||
|
||||
def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True):
|
||||
preds, labels = eval_pred
|
||||
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
if ignore_pad_token_for_loss:
|
||||
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
||||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
# def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True):
|
||||
# preds, labels = eval_pred
|
||||
# decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
# if ignore_pad_token_for_loss:
|
||||
# labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
||||
# decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
|
||||
decoded_preds, decoded_labels = postprocess_summarization(decoded_preds, decoded_labels)
|
||||
return decoded_preds, decoded_labels
|
||||
# decoded_preds, decoded_labels = postprocess_summarization(decoded_preds, decoded_labels)
|
||||
# return decoded_preds, decoded_labels
|
||||
|
||||
|
||||
def get_metrics(conf, tokenizer):
|
||||
@@ -77,16 +79,16 @@ def get_metrics(conf, tokenizer):
|
||||
# metrics in the future for more thorough evaluation
|
||||
metrics, preprocess_fns = [evaluate.load("accuracy")], [default_preprocess]
|
||||
|
||||
if any(dataset in QA_DATASETS for dataset in conf.datasets):
|
||||
raise ValueError("TODO")
|
||||
metrics.append(evaluate.load("squad_v2"))
|
||||
preprocess_fns.append(preprocess_qa)
|
||||
if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets):
|
||||
raise ValueError("TODO")
|
||||
metrics.append(evaluate.load("rouge"))
|
||||
preprocess_fns.append(
|
||||
partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss)
|
||||
)
|
||||
# if any(dataset in QA_DATASETS for dataset in conf.datasets):
|
||||
# raise ValueError("TODO")
|
||||
# metrics.append(evaluate.load("squad_v2"))
|
||||
# preprocess_fns.append(preprocess_qa)
|
||||
# if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets):
|
||||
# raise ValueError("TODO")
|
||||
# metrics.append(evaluate.load("rouge"))
|
||||
# preprocess_fns.append(
|
||||
# partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss)
|
||||
# )
|
||||
|
||||
return metrics, preprocess_fns
|
||||
|
||||
|
||||
Reference in New Issue
Block a user