From d8e0dff9c29b1ad379ca516bfd3cd550cf4584d9 Mon Sep 17 00:00:00 2001 From: Yu Meng Date: Thu, 18 Jul 2024 15:19:05 -0400 Subject: [PATCH] simplifying generation --- generate.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/generate.py b/generate.py index 028ba43..1060e09 100644 --- a/generate.py +++ b/generate.py @@ -1,21 +1,7 @@ import torch from transformers import pipeline -import json -import warnings -model_id = "princeton-nlp/Llama-3-Instruct-8B-SimPO" - -with open('chat_templates.json', 'r') as f: - chat_templates = json.load(f) - -if "llama-3" in model_id.lower(): - template = chat_templates["llama3"] -elif "mistral-7b-base" in model_id.lower(): - template = chat_templates["mistral-base"] -elif "mistral-7b-instruct" in model_id.lower(): - template = chat_templates["mistral-instruct"] -else: - warnings.warn("No template set for the given model_id.") +model_id = "princeton-nlp/gemma-2-9b-it-SimPO" generator = pipeline( "text-generation", @@ -23,6 +9,5 @@ generator = pipeline( model_kwargs={"torch_dtype": torch.bfloat16}, device="cuda", ) -generator.tokenizer.chat_template = template outputs = generator([{"role": "user", "content": "What's the difference between llamas and alpacas?"}], do_sample=False, max_new_tokens=200) print(outputs[0]['generated_text'])