mirror of
https://github.com/wassname/SimPO.git
synced 2026-06-27 16:43:59 +08:00
simplifying generation
This commit is contained in:
+1
-16
@@ -1,21 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
import json
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
model_id = "princeton-nlp/Llama-3-Instruct-8B-SimPO"
|
model_id = "princeton-nlp/gemma-2-9b-it-SimPO"
|
||||||
|
|
||||||
with open('chat_templates.json', 'r') as f:
|
|
||||||
chat_templates = json.load(f)
|
|
||||||
|
|
||||||
if "llama-3" in model_id.lower():
|
|
||||||
template = chat_templates["llama3"]
|
|
||||||
elif "mistral-7b-base" in model_id.lower():
|
|
||||||
template = chat_templates["mistral-base"]
|
|
||||||
elif "mistral-7b-instruct" in model_id.lower():
|
|
||||||
template = chat_templates["mistral-instruct"]
|
|
||||||
else:
|
|
||||||
warnings.warn("No template set for the given model_id.")
|
|
||||||
|
|
||||||
generator = pipeline(
|
generator = pipeline(
|
||||||
"text-generation",
|
"text-generation",
|
||||||
@@ -23,6 +9,5 @@ generator = pipeline(
|
|||||||
model_kwargs={"torch_dtype": torch.bfloat16},
|
model_kwargs={"torch_dtype": torch.bfloat16},
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
generator.tokenizer.chat_template = template
|
|
||||||
outputs = generator([{"role": "user", "content": "What's the difference between llamas and alpacas?"}], do_sample=False, max_new_tokens=200)
|
outputs = generator([{"role": "user", "content": "What's the difference between llamas and alpacas?"}], do_sample=False, max_new_tokens=200)
|
||||||
print(outputs[0]['generated_text'])
|
print(outputs[0]['generated_text'])
|
||||||
|
|||||||
Reference in New Issue
Block a user