import os import traceback import time import torch import transformers from torch.utils.data import DataLoader import re from torch.utils.data import Dataset from transformers import ( StoppingCriteriaList, StoppingCriteria, PreTrainedModel, PreTrainedTokenizer, ) from typing import List, Union, Optional from tqdm import tqdm import os import torch import api.util as util class Model: def get_tokenizer(self, dir: str, use_fast=True): if "pythia" in dir: util.print_once( "Setting use_fast to true because pythia tokenizer is not compatible with use_fast=False" ) use_fast = True if "llama" in dir: util.print_once( "Setting use_fast to false because llama tokenizer is not compatible with use_fast=True" ) use_fast = False if not os.path.isdir(dir): raise Exception(f"The hf_model {dir} does not exist") tokenizer = transformers.AutoTokenizer.from_pretrained( dir, use_fast=use_fast, trust_remote_code=True ) # Set padding side to left tokenizer.padding_side = "left" # Set padding side to left tokenizer.padding_side = "left" # Set padding token to eos token if pad token is not set if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token return tokenizer def __init__( self, dir: str, hf_model: PreTrainedModel = None, tokenizer: PreTrainedTokenizer = None, use_fast=True, ): self.dir = dir if not os.path.isdir(self.dir): raise Exception(f"The hf_model {dir} does not exist") if tokenizer == None: self.tokenizer = self.get_tokenizer(dir, use_fast=use_fast) else: self.tokenizer = tokenizer if hf_model == None: for i in range(3): try: self.hf_model = transformers.AutoModelForCausalLM.from_pretrained( self.dir, trust_remote_code=True, torch_dtype=torch.bfloat16, ) break except: exception_string = traceback.format_exc() if os.path.exists(self.dir + "/pytorch_model.bin"): print( "Failed to load model but pytorch_model.bin exists. This indicates that the model may still be saving. Retrying in 5 seconds." ) time.sleep(5) else: break if self.hf_model == None: raise Exception(f"Failed to load model: {exception_string}") else: self.hf_model = hf_model try: self.max_length = self.hf_model.config.max_position_embeddings except: pass def to(self, device: str): self.hf_model.to(device) return self def generate_text( self, prompts: List[str], max_length: Optional[int] = None, stop_string: Optional[str] = None, output_regex: Optional[str] = None, per_device_batch_size=100, **kwargs, ) -> Union[str, List[str]]: batch_size = per_device_batch_size if max_length == None: max_length = self.max_length dataset = TensorDataset(prompts) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) encoded_completions = [] for batch in tqdm(dataloader): encoded_prompts = self.tokenizer.batch_encode_plus( batch, return_tensors="pt", padding="longest", ).to(self.hf_model.device) # Add stopping criteria completion_pos = len(encoded_prompts["input_ids"][0]) if output_regex == None: output_regex = "" stop_string_regex = "" if stop_string != None: stop_string_regex = r"^(.*?" + stop_string + ")" if stop_string_regex != "" and output_regex != "": completion_regex = stop_string_regex + "|" + output_regex completion_regex = stop_string_regex + output_regex stopping_criteria = StoppingCriteriaList( [RegexStoppingCriteria(self.tokenizer, completion_pos, regex=completion_regex)] ) # Generate predictions completed_sequences = self.hf_model.generate( input_ids=encoded_prompts["input_ids"], attention_mask=encoded_prompts["attention_mask"], stopping_criteria=stopping_criteria, max_new_tokens=max_length, **kwargs, ) completions = [ completed_sequences[i][completion_pos:] for i in range(len(completed_sequences)) ] # Remove tokens that follow the eos token for i in range(len(completions)): if self.tokenizer.eos_token_id in list(completions[i]): index = list(completions[i]).index(self.tokenizer.eos_token_id) completions[i] = completions[i][:index] else: pass completions[i] = completions[i][:] encoded_completions.extend(completions) encoded_completions = [c.cpu().to(dtype=torch.int64) for c in encoded_completions] # Decode the predictions text_completions = [self.tokenizer.decode(ids) for ids in encoded_completions] # Post process to remove text generated after stopping conditions were met if completion_regex != "": text_completions = [ self.process_completion(text_completion, completion_regex) for text_completion in text_completions ] return text_completions def process_completion(self, completion, regex): match = re.search(regex, completion) if match: return match.group(0) else: return completion def print_generate( self, text: str, max_length: Optional[int] = 100, stop_string: Optional[str] = None, output_regex: Optional[str] = None, **kwargs, ): result = self.generate_text([text], max_length, stop_string, output_regex, **kwargs)[0] return result class RegexStoppingCriteria(StoppingCriteria): def __init__(self, tokenizer, completion_pos, regex=None): StoppingCriteria.__init__(self), self.tokenizer = tokenizer self.regex = regex self.completion_pos = completion_pos def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): if self.regex == "": return False # stops if all generations include the regex pattern should_stop = [] for i in range(len(input_ids)): seq_string = self.tokenizer.decode(input_ids[i][self.completion_pos :]) if self.regex != None: match = re.search(self.regex, seq_string) if match: should_stop.append(True) else: should_stop.append(False) if all(should_stop): return True return False class TensorDataset(Dataset): def __init__(self, inputs): self.inputs = inputs def __getitem__(self, idx): return self.inputs[idx] def __len__(self): return len(self.inputs)