# Mirror of https://github.com/wassname/GENIES.git
# Synced 2026-06-27 16:10:25 +08:00
# 252 lines, 9.7 KiB, Python
import transformers
|
|
from tqdm import tqdm
|
|
from torch.nn import CrossEntropyLoss
|
|
from transformers import PreTrainedTokenizer
|
|
from api.model import Model
|
|
import re
|
|
import copy
|
|
from torch.utils.data import Dataset
|
|
import api.util as util
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from typing import Dict, Sequence, Any, List, Optional
|
|
from dataclasses import dataclass
|
|
|
|
|
|
class Distribution:
    """A train/test pair of multiple-choice datasets loaded from one directory.

    The directory must contain ``train.json``, ``test.json`` and
    ``meta_data.json``. ``meta_data.json`` supplies the prompt format strings
    (``"formats"``) and the distribution identifier (``"id"``). Every loaded
    example gets a ``"prompt"`` entry rendered from the matching format
    string before being wrapped in an ``MCDataset``.

    Two distributions compare equal (and hash equal) when their ids match.
    """

    def __init__(self, dir: str):
        """Load, format and wrap the examples found in ``dir``.

        Args:
            dir: Path of the distribution directory (no trailing slash is
                added or removed; it is interpolated as ``f"{dir}/..."``).

        Raises:
            Exception: If an example cannot be matched to any format string
                (see ``format_example``).
        """
        self.dir = dir
        training_examples = util.load_json(f"{dir}/train.json")
        test_examples = util.load_json(f"{dir}/test.json")
        meta_data = util.load_json(f"{dir}/meta_data.json")

        self.formats = meta_data["formats"]
        self.id = meta_data["id"]

        # Add prompts for each example (the example dicts are updated in place).
        for example in training_examples:
            self.format_example(self.formats, example)
        for example in test_examples:
            self.format_example(self.formats, example)

        self.training_dataset = MCDataset(
            training_examples, distribution_id=self.id, distribution_dir=self.dir
        )
        self.test_dataset = MCDataset(
            test_examples, distribution_id=self.id, distribution_dir=self.dir
        )

    @staticmethod
    def format_example(formats: List[str], example: dict) -> dict:
        """Render ``example["prompt"]`` from the format string that fits it.

        A format string fits when its set of ``{placeholder}`` names is
        exactly the set of placeholder names (taken over all format strings)
        that also appear as keys of ``example``. The first fitting format
        string, in list order, is used.

        Args:
            formats: Candidate ``str.format`` templates.
            example: The example to format; it is modified in place.

        Returns:
            The same ``example`` dict, now carrying a ``"prompt"`` key.

        Raises:
            Exception: If no format string fits the example.
        """
        # Extract each template's placeholders once; they are needed both for
        # the union below and for the per-template comparison.
        keys_per_format = [
            set(re.findall(r"\{(.+?)\}", format_string)) for format_string in formats
        ]
        all_format_keys = set().union(*keys_per_format)
        example_format_keys = all_format_keys.intersection(example.keys())

        for format_string, format_keys in zip(formats, keys_per_format):
            if format_keys == example_format_keys:
                example["prompt"] = format_string.format(**example)
                return example

        raise Exception(
            f"An example could not be matched to a format string.\nExample: {example}\nFormat strings: {formats}"
        )

    def __hash__(self):
        """Hash on the distribution id, consistent with ``__eq__``."""
        return hash(self.id)

    def __eq__(self, other):
        """Return True when ``other`` is a Distribution with the same id."""
        if not isinstance(other, Distribution):
            # Let Python fall back to its default handling instead of failing
            # with AttributeError on objects that have no ``id``.
            return NotImplemented
        return self.id == other.id
|
|
|
|
|
|
class MCDataset(Dataset):
    """Multiple-choice examples belonging to one distribution.

    Each example is a dict. The methods below read ``example["prompt"]`` and
    ``example["responses"]`` and rely on the ``example["id"]`` written by
    ``__init__``.
    """

    def __init__(self, examples, distribution_id, distribution_dir, max_examples=None):
        """Store the examples and tag each with its index in the list.

        Args:
            examples: List of example dicts; each dict is modified in place
                (an ``"id"`` key is set to the example's index).
            distribution_id: Identifier of the owning distribution.
            distribution_dir: Directory the examples were loaded from.
            max_examples: Optional cap on the reported dataset length.
        """
        self.distribution_dir = distribution_dir
        self.distribution_id = distribution_id
        self.examples = examples
        for i, example in enumerate(self.examples):
            example["id"] = i
        self.max_examples = max_examples

    def __len__(self):
        """Return the number of examples, capped at ``max_examples`` if set."""
        if self.max_examples is not None:
            return min(len(self.examples), self.max_examples)
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example dict stored at position ``idx``."""
        return self.examples[idx]

    def set_max_examples(self, max_examples):
        """Set the cap and drop every example beyond it."""
        self.max_examples = max_examples
        self.examples = self.examples[:max_examples]

    def filter_out_long_examples(self, tokenizer: "PreTrainedTokenizer"):
        """Drop examples with a response that does not fit the model.

        An example is kept when, for every response, the encoding of
        ``prompt + response + eos_token`` is no longer than
        ``tokenizer.model_max_length``.

        Args:
            tokenizer: Object providing ``encode``, ``eos_token`` and
                ``model_max_length``.

        Returns:
            The list of retained examples (also stored in ``self.examples``).
            The ``"id"`` values of the retained examples are left unchanged.
        """
        max_length = tokenizer.model_max_length
        filtered_examples = []
        for example in self.examples:
            longest = max(
                len(tokenizer.encode(example["prompt"] + response + tokenizer.eos_token))
                for response in example["responses"]
            )
            if longest <= max_length:
                filtered_examples.append(example)
        num_examples_filtered_out = len(self.examples) - len(filtered_examples)

        if num_examples_filtered_out > 0:
            util.print_once(
                f"Filtered out {num_examples_filtered_out} examples because they exceeded the max length of {tokenizer.model_max_length}"
            )

        self.examples = filtered_examples
        return filtered_examples

    def get_example_scores(
        self, model: "Model", accelerator: "Optional[Accelerator]", per_device_batch_size
    ) -> List[Optional[float]]:
        """Score every example with ``model`` and return scores in dataset order.

        Args:
            model: Wrapper exposing ``tokenizer`` and ``hf_model``.
            accelerator: Used to prepare the model/dataloader and to gather
                results. NOTE: although annotated Optional, it is called
                unconditionally, so ``None`` is not supported here.
            per_device_batch_size: Number of examples per dataloader batch.

        Returns:
            A list aligned with ``self.examples`` (after long examples have
            been filtered out). Entry ``k`` is the score of
            ``self.examples[k]``, or ``None`` if that example was not scored
            (e.g. it lies beyond ``max_examples``).
        """
        # First filter out examples that exceed the model's maximum sequence length
        self.filter_out_long_examples(model.tokenizer)

        dataloader = DataLoader(
            self,
            batch_size=per_device_batch_size,
            shuffle=False,
            collate_fn=MCDataCollator(tokenizer=model.tokenizer),
        )

        model.hf_model.eval()
        model.hf_model, dataloader = accelerator.prepare(model.hf_model, dataloader)

        example_scores = []  # a list of tuples. Each tuple is (score, example id)
        # Scores are only read back as floats, so do not build (and keep
        # alive) an autograd graph for every batch.
        with torch.no_grad():
            for batch in tqdm(dataloader):
                batch_scores = self.get_scores_for_batch(model.hf_model, batch)
                ids = batch.pop("ids")
                example_scores.extend(zip(batch_scores, ids))

        # Gather tensors from different devices
        example_scores = torch.tensor(example_scores, dtype=torch.float32).to(model.hf_model.device)
        example_scores = accelerator.gather(example_scores).cpu()

        accelerator.free_memory()

        # Reorder flattened scores to ensure indices match
        return self._scores_in_example_order(example_scores, self.examples)

    @staticmethod
    def _scores_in_example_order(
        gathered: torch.Tensor, examples: List[dict]
    ) -> List[Optional[float]]:
        """Place each gathered ``(score, id)`` row at its example's position.

        Ids are assigned once in ``__init__`` and are not renumbered when
        examples are filtered out, so they cannot be used directly as indices
        into the (possibly shorter) current example list.
        """
        position_by_id = {example["id"]: position for position, example in enumerate(examples)}
        ordered_example_scores = [None] * len(examples)
        for score, example_id in gathered.tolist():
            ordered_example_scores[position_by_id[int(example_id)]] = score
        return ordered_example_scores

    @staticmethod
    def get_scores_for_batch(model, batch):
        """Compute one score per example from a collated batch.

        Args:
            model: Callable taking ``input_ids`` and ``attention_mask`` and
                returning an object with ``logits``; must expose ``device``.
            batch: Mapping with ``input_ids``, ``attention_mask``, ``labels``
                and ``response_labels``. Rows of the first three are assumed
                to be the responses of each example laid out consecutively,
                as produced by ``MCDataCollator`` — TODO confirm for other
                callers.

        Returns:
            A 1-D tensor with one entry per row of ``response_labels``: the
            softmax (over an example's responses) of the per-response mean
            log-probability, weighted by ``response_labels`` and summed.
        """
        output = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        logits = output.logits

        # Shift so that tokens < n predict n
        logits_for_batch = logits[:, :-1, :]
        shift_labels_batch = batch["labels"][:, 1:]

        # Flatten the tokens
        shift_logits_batch = logits_for_batch.contiguous().view(-1, logits.shape[-1])
        shift_labels_batch = shift_labels_batch.contiguous().view(-1)

        # Cross-entropy loss
        loss_fct = CrossEntropyLoss(reduction="none")  # use "none" to get loss per item
        batch_losses = loss_fct(shift_logits_batch, shift_labels_batch)
        # The mean runs over every shifted position of the row, including
        # positions whose label is ignored by the loss.
        avg_log_probs = -batch_losses.view(logits.size(0), -1).mean(dim=1)

        # Unflatten avg log probs
        response_labels = batch["response_labels"]
        response_labels = response_labels.to(model.device).to(dtype=avg_log_probs.dtype)
        response_labels.requires_grad = True
        avg_log_probs_unflattened = avg_log_probs.contiguous().view(
            response_labels.shape[0], response_labels.shape[1]
        )
        response_probabilities = torch.softmax(avg_log_probs_unflattened, dim=1).to(model.device)
        example_scores = (response_probabilities * response_labels).sum(dim=1)
        return example_scores
|
|
|
|
|
|
@dataclass
class MCDataCollator:
    """Collate multiple-choice examples into one flattened model batch.

    Every (prompt, response) pair of every instance becomes its own row, in
    instance order and, within an instance, in the iteration order of its
    ``"responses"`` mapping.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]):
        """Build the batch for ``instances``.

        Args:
            instances: Example dicts providing ``"prompt"``, ``"responses"``
                (a mapping whose keys are the response texts and whose values
                are their labels) and ``"id"``.

        Returns:
            The dict produced by ``SupervisedDataCollator`` for the flattened
            pairs (``input_ids``, ``labels``, ``attention_mask``), extended
            with ``response_labels`` (a tensor holding one row of label
            values per instance) and ``ids`` (the list of instance ids).
        """
        label_rows = []
        for instance in instances:
            label_rows.append(list(instance["responses"].values()))
        response_labels = torch.tensor(label_rows)

        ids = []
        for instance in instances:
            ids.append(instance["id"])

        # One row per candidate response, prompts repeated as needed.
        flattened_instances = []
        for instance in instances:
            for response in instance["responses"]:
                flattened_instances.append({"prompt": instance["prompt"], "response": response})

        supervised_collator = SupervisedDataCollator(tokenizer=self.tokenizer)
        model_inputs = supervised_collator(flattened_instances)
        model_inputs["response_labels"] = response_labels
        model_inputs["ids"] = ids

        return model_inputs
|
|
|
|
|
|
@dataclass
class SupervisedDataCollator:
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    @staticmethod
    def tokenize_prompts_and_responses(
        prompts: Sequence[str],
        responses: Sequence[str],
        tokenizer: transformers.PreTrainedTokenizer,
    ) -> Dict:
        """Preprocess the data by tokenizing.

        Prompt and response are tokenized separately (without special
        tokens) and concatenated, so the prompt length in tokens is known
        exactly.

        Args:
            prompts: Prompt strings; an empty string yields zero prompt tokens.
            responses: Response strings, paired with ``prompts`` by position.
            tokenizer: Tokenizer called on each string.

        Returns:
            A dict with ``input_ids`` and ``labels``, each a list of 1-D
            tensors (one per pair, unpadded). ``labels`` is a copy of
            ``input_ids`` with the prompt positions set to
            ``util.IGNORE_INDEX``.
        """
        prompts_tokenized = []
        responses_tokenized = []
        for prompt, response in zip(prompts, responses):
            if prompt == "":
                # Skip the tokenizer for empty prompts and use an explicitly
                # typed empty tensor so the concatenation below stays int64.
                prompt_tokenized = torch.tensor([], dtype=torch.int64)
            else:
                prompt_tokenized = tokenizer(
                    prompt, return_tensors="pt", padding=False, add_special_tokens=False
                ).input_ids[0]
            response_tokenized = tokenizer(
                response, return_tensors="pt", padding=False, add_special_tokens=False
            ).input_ids[0]

            responses_tokenized.append(response_tokenized)
            prompts_tokenized.append(prompt_tokenized)

        input_ids = [
            torch.cat((s, t), dim=0) for s, t in zip(prompts_tokenized, responses_tokenized)
        ]
        prompt_lens = [len(s) for s in prompts_tokenized]
        labels = copy.deepcopy(input_ids)
        for label, prompt_len in zip(labels, prompt_lens):
            # Mark prompt tokens with the ignore value so that only response
            # tokens carry a real label.
            label[:prompt_len] = util.IGNORE_INDEX
        return dict(input_ids=input_ids, labels=labels)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Tokenize and right-pad a list of prompt/response instances.

        Args:
            instances: Dicts providing ``"prompt"`` and ``"response"``. The
                tokenizer's ``eos_token`` is appended to every response.

        Returns:
            A dict of tensors with one row per instance: ``input_ids``
            (padded with ``tokenizer.pad_token_id``), ``labels`` (padded with
            ``util.IGNORE_INDEX``) and a boolean ``attention_mask`` that is
            True on every real token and False on padding.
        """
        prompts = [instance["prompt"] for instance in instances]
        responses = [instance["response"] for instance in instances]

        suffix = self.tokenizer.eos_token
        responses = [f"{response}{suffix}" for response in responses]

        data_dict = self.tokenize_prompts_and_responses(prompts, responses, self.tokenizer)

        input_ids = data_dict["input_ids"]
        labels = data_dict["labels"]
        sequence_lengths = [len(sequence) for sequence in input_ids]

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=util.IGNORE_INDEX
        )

        # Derive the mask from the unpadded lengths rather than from
        # ``input_ids != pad_token_id``: the latter also masks genuine tokens
        # that share the pad id (e.g. the appended EOS when the pad token is
        # the EOS token).
        positions = torch.arange(input_ids.shape[1]).unsqueeze(0)
        lengths = torch.tensor(sequence_lengths, dtype=torch.long).unsqueeze(1)
        attention_mask = positions < lengths

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )
|