mirror of
https://github.com/wassname/SimPO.git
synced 2026-06-27 16:43:59 +08:00
Simplify generation: switch to gemma-2-9b-it-SimPO and drop the manual chat-template selection
This commit is contained in:
+1
-16
@@ -1,21 +1,7 @@
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
import json
|
||||
import warnings
|
||||
|
||||
model_id = "princeton-nlp/Llama-3-Instruct-8B-SimPO"
|
||||
|
||||
with open('chat_templates.json', 'r') as f:
|
||||
chat_templates = json.load(f)
|
||||
|
||||
if "llama-3" in model_id.lower():
|
||||
template = chat_templates["llama3"]
|
||||
elif "mistral-7b-base" in model_id.lower():
|
||||
template = chat_templates["mistral-base"]
|
||||
elif "mistral-7b-instruct" in model_id.lower():
|
||||
template = chat_templates["mistral-instruct"]
|
||||
else:
|
||||
warnings.warn("No template set for the given model_id.")
|
||||
model_id = "princeton-nlp/gemma-2-9b-it-SimPO"
|
||||
|
||||
generator = pipeline(
|
||||
"text-generation",
|
||||
@@ -23,6 +9,5 @@ generator = pipeline(
|
||||
model_kwargs={"torch_dtype": torch.bfloat16},
|
||||
device="cuda",
|
||||
)
|
||||
generator.tokenizer.chat_template = template
|
||||
outputs = generator([{"role": "user", "content": "What's the difference between llamas and alpacas?"}], do_sample=False, max_new_tokens=200)
|
||||
print(outputs[0]['generated_text'])
|
||||
|
||||
Reference in New Issue
Block a user