Files
GENIES/build/lib/api/model.py
T
JoshuaClymer 70bd2ea15d Initial commit
2023-11-11 19:44:09 +00:00

229 lines
7.5 KiB
Python

import os
import traceback
import time
import torch
import transformers
from torch.utils.data import DataLoader
import re
from torch.utils.data import Dataset
from transformers import (
StoppingCriteriaList,
StoppingCriteria,
PreTrainedModel,
PreTrainedTokenizer,
)
from typing import List, Union, Optional
from tqdm import tqdm
import os
import torch
import api.util as util
class Model:
    """Pair a Hugging Face causal language model with its tokenizer for batched generation.

    Attributes set by ``__init__``:
        dir: Path of the local model directory.
        tokenizer: Tokenizer used to encode prompts and decode completions.
        hf_model: The underlying causal language model.
        max_length: ``hf_model.config.max_position_embeddings`` when the model
            exposes it, otherwise ``None``.
    """

    def get_tokenizer(self, dir: str, use_fast=True):
        """Load the tokenizer stored in ``dir`` and configure it for left padding.

        Args:
            dir: Local directory holding the tokenizer files. The name shadows
                the builtin but is kept because callers may pass it by keyword.
            use_fast: Requested tokenizer implementation. It is overridden to
                ``True`` when ``dir`` contains "pythia" and to ``False`` when it
                contains "llama" (the llama check runs last, so it wins if both
                substrings are present).

        Returns:
            The tokenizer, padding on the left and with a pad token defined.

        Raises:
            Exception: If ``dir`` is not an existing directory.
        """
        if "pythia" in dir:
            util.print_once(
                "Setting use_fast to true because pythia tokenizer is not compatible with use_fast=False"
            )
            use_fast = True
        if "llama" in dir:
            util.print_once(
                "Setting use_fast to false because llama tokenizer is not compatible with use_fast=True"
            )
            use_fast = False
        if not os.path.isdir(dir):
            raise Exception(f"The hf_model {dir} does not exist")

        tokenizer = transformers.AutoTokenizer.from_pretrained(
            dir, use_fast=use_fast, trust_remote_code=True
        )
        # Pad on the left so that generated tokens directly follow the prompt.
        tokenizer.padding_side = "left"
        # Fall back to the eos token when the tokenizer defines no pad token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def __init__(
        self,
        dir: str,
        hf_model: Optional["PreTrainedModel"] = None,
        tokenizer: Optional["PreTrainedTokenizer"] = None,
        use_fast=True,
    ):
        """Create the wrapper, loading the tokenizer and/or model from ``dir`` if not supplied.

        Args:
            dir: Local model directory; it must exist even when ``hf_model`` and
                ``tokenizer`` are both passed in.
            hf_model: Already-loaded model to wrap. Loaded from ``dir`` when ``None``.
            tokenizer: Already-loaded tokenizer. Loaded via ``get_tokenizer`` when ``None``.
            use_fast: Forwarded to ``get_tokenizer``.

        Raises:
            Exception: If ``dir`` is not a directory or the model cannot be loaded.
        """
        self.dir = dir
        if not os.path.isdir(self.dir):
            raise Exception(f"The hf_model {dir} does not exist")

        if tokenizer is None:
            self.tokenizer = self.get_tokenizer(dir, use_fast=use_fast)
        else:
            self.tokenizer = tokenizer

        if hf_model is None:
            self.hf_model = self._load_hf_model()
        else:
            self.hf_model = hf_model

        # Always define max_length so generate_text can read it; None means
        # the model config does not advertise a positional limit.
        config = getattr(self.hf_model, "config", None)
        self.max_length = getattr(config, "max_position_embeddings", None)

    def _load_hf_model(self):
        """Load the causal LM from ``self.dir``, retrying while a checkpoint file is present."""
        exception_string = ""
        for _ in range(3):
            try:
                return transformers.AutoModelForCausalLM.from_pretrained(
                    self.dir,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                )
            except Exception:
                exception_string = traceback.format_exc()
                if not os.path.exists(os.path.join(self.dir, "pytorch_model.bin")):
                    # Nothing suggests a save in progress, so retrying is pointless.
                    break
                print(
                    "Failed to load model but pytorch_model.bin exists. This indicates that the model may still be saving. Retrying in 5 seconds."
                )
                time.sleep(5)
        raise Exception(f"Failed to load model: {exception_string}")

    def to(self, device: str):
        """Move the wrapped model to ``device`` and return ``self`` for chaining."""
        self.hf_model.to(device)
        return self

    @staticmethod
    def _build_completion_regex(stop_string, output_regex):
        """Combine the stop string and output regex into one completion pattern.

        ``stop_string`` is inserted into the pattern unescaped, so regex
        metacharacters in it are interpreted as regex syntax. Returns ``""``
        when neither argument is given.
        """
        if output_regex is None:
            output_regex = ""
        stop_string_regex = ""
        if stop_string is not None:
            stop_string_regex = r"^(.*?" + stop_string + ")"
        if stop_string_regex != "" and output_regex != "":
            # Either condition may end the completion.
            return stop_string_regex + "|" + output_regex
        return stop_string_regex + output_regex

    def generate_text(
        self,
        prompts: List[str],
        max_length: Optional[int] = None,
        stop_string: Optional[str] = None,
        output_regex: Optional[str] = None,
        per_device_batch_size=100,
        **kwargs,
    ) -> Union[str, List[str]]:
        """Generate one completion per prompt, processing the prompts in batches.

        Args:
            prompts: Prompt strings to complete.
            max_length: Passed to ``generate`` as ``max_new_tokens``; defaults to
                ``self.max_length``.
            stop_string: Generation output is cut after the first occurrence of
                this pattern (used unescaped inside a regex).
            output_regex: Regex whose first match is kept as the completion.
            per_device_batch_size: Number of prompts per ``generate`` call.
            **kwargs: Forwarded to ``hf_model.generate``.

        Returns:
            A list with the decoded completion for each prompt, in input order.
        """
        if max_length is None:
            max_length = self.max_length

        # The pattern does not depend on the batch, so build it once up front;
        # this also keeps it defined when ``prompts`` is empty.
        completion_regex = self._build_completion_regex(stop_string, output_regex)

        dataloader = DataLoader(
            TensorDataset(prompts), batch_size=per_device_batch_size, shuffle=False
        )
        eos_token_id = self.tokenizer.eos_token_id
        encoded_completions = []
        for batch in tqdm(dataloader):
            encoded_prompts = self.tokenizer.batch_encode_plus(
                batch,
                return_tensors="pt",
                padding="longest",
            ).to(self.hf_model.device)

            # Every row is padded to the same length, so completions start here.
            completion_pos = len(encoded_prompts["input_ids"][0])
            stopping_criteria = StoppingCriteriaList(
                [RegexStoppingCriteria(self.tokenizer, completion_pos, regex=completion_regex)]
            )

            completed_sequences = self.hf_model.generate(
                input_ids=encoded_prompts["input_ids"],
                attention_mask=encoded_prompts["attention_mask"],
                stopping_criteria=stopping_criteria,
                max_new_tokens=max_length,
                **kwargs,
            )

            for sequence in completed_sequences:
                completion = sequence[completion_pos:]
                # Drop the eos token and everything generated after it.
                token_ids = completion.tolist()
                if eos_token_id in token_ids:
                    completion = completion[: token_ids.index(eos_token_id)]
                encoded_completions.append(completion.cpu().to(dtype=torch.int64))

        text_completions = [self.tokenizer.decode(ids) for ids in encoded_completions]

        # Trim text generated after a stopping condition was already met.
        if completion_regex != "":
            text_completions = [
                self.process_completion(text_completion, completion_regex)
                for text_completion in text_completions
            ]
        return text_completions

    def process_completion(self, completion, regex):
        """Return the first match of ``regex`` in ``completion``, or ``completion`` if none."""
        match = re.search(regex, completion)
        if match:
            return match.group(0)
        return completion

    def print_generate(
        self,
        text: str,
        max_length: Optional[int] = 100,
        stop_string: Optional[str] = None,
        output_regex: Optional[str] = None,
        **kwargs,
    ):
        """Generate and return the completion for a single prompt.

        Despite the name, nothing is printed here; the completion string is
        returned to the caller.
        """
        result = self.generate_text([text], max_length, stop_string, output_regex, **kwargs)[0]
        return result
class RegexStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires once every sequence's completion matches a regex.

    Args:
        tokenizer: Tokenizer used to decode the generated token ids.
        completion_pos: Index in each sequence where the completion (as opposed
            to the prompt) begins.
        regex: Pattern searched for in each decoded completion. ``None`` or
            ``""`` disables the criterion.
    """

    def __init__(self, tokenizer, completion_pos, regex=None):
        super().__init__()
        self.tokenizer = tokenizer
        self.regex = regex
        self.completion_pos = completion_pos

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return ``True`` only when all sequences in the batch match the regex.

        ``**kwargs`` is accepted because ``StoppingCriteriaList`` forwards extra
        keyword arguments to each criterion; they are ignored here.
        """
        # Without a pattern there is nothing to wait for. Returning False here
        # matters for regex=None: an empty list of results would otherwise make
        # all() report True and stop generation immediately.
        if not self.regex:
            return False
        return all(
            re.search(self.regex, self.tokenizer.decode(sequence[self.completion_pos :]))
            is not None
            for sequence in input_ids
        )
class TensorDataset(Dataset):
    """Map-style dataset that serves items straight from an indexable collection."""

    def __init__(self, inputs):
        # Hold a reference to the caller's collection; nothing is copied or converted.
        self.inputs = inputs

    def __len__(self):
        """Return how many items the wrapped collection holds."""
        size = len(self.inputs)
        return size

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        item = self.inputs[idx]
        return item