Update generate.py for demo

This commit is contained in:
Yu Meng
2024-07-06 23:34:34 -04:00
committed by GitHub
parent 15c4ff8918
commit 26cbb4a033
+2 -2
View File
@@ -3,7 +3,7 @@ from transformers import pipeline
import json
import warnings
model_id = "/scratch/gpfs/DANQIC/ym0081/SimPO/outputs/llama-3-8b-instruct-simpo"
model_id = "princeton-nlp/Llama-3-Instruct-8B-SimPO"
with open('chat_templates.json', 'r') as f:
chat_templates = json.load(f)
@@ -24,5 +24,5 @@ generator = pipeline(
device="cuda",
)
generator.tokenizer.chat_template = template
outputs = generator([{"role": "user", "content": "When rolling two dice, what is the probability that you roll a total number that is at least 3?"}], do_sample=False, max_new_tokens=200)
outputs = generator([{"role": "user", "content": "What's the difference between llamas and alpacas?"}], do_sample=False, max_new_tokens=200)
print(outputs[0]['generated_text'])