Fix tokenization issues and attention mask warnings in guided_eval

This commit is contained in:
wassname
2026-04-10 09:06:55 +08:00
parent 11786f20b4
commit bde29fee1e
2 changed files with 67 additions and 22 deletions
+27 -8
View File
@@ -64,25 +64,44 @@ def compute_curvature(hidden_states):
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
messages = [{"role": "user", "content": prompt_text}]
prompt_ids = tokenizer.apply_chat_template(
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
return_dict=False
return_dict=True
).to(device)
think_prefix_ids = tokenizer.encode("Thinking Process:\n", add_special_tokens=False, return_tensors="pt").to(device)
prompt_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
think_prefix_ids = tokenizer.encode("Thinking Process:\\n", add_special_tokens=False, return_tensors="pt").to(device)
prompt_ids = torch.cat([prompt_ids, think_prefix_ids], dim=1)
attention_mask = torch.cat([attention_mask, torch.ones_like(think_prefix_ids)], dim=1)
with torch.no_grad():
out = model.generate(prompt_ids, max_new_tokens=n_think, do_sample=False, pad_token_id=tokenizer.eos_token_id)
out = model.generate(
prompt_ids,
attention_mask=attention_mask,
max_new_tokens=n_think,
do_sample=False,
pad_token_id=tokenizer.eos_token_id
)
generated_ids = out[0, prompt_ids.shape[1]:]
suffix_ids = tokenizer.encode("\nI should answer now.\nMy choice: **", add_special_tokens=False, return_tensors="pt").to(device)
suffix_ids = tokenizer.encode("\\nI should answer now.\\nMy choice: **", add_special_tokens=False, return_tensors="pt").to(device)
full_ids = torch.cat([prompt_ids, generated_ids.unsqueeze(0), suffix_ids], dim=1)
full_attention_mask = torch.cat([
attention_mask,
torch.ones_like(generated_ids.unsqueeze(0)),
torch.ones_like(suffix_ids)
], dim=1)
with torch.no_grad():
outputs = model(full_ids, output_hidden_states=True)
outputs = model(
full_ids,
attention_mask=full_attention_mask,
output_hidden_states=True
)
logits = outputs.logits[0, -1, :]
log_probs = F.log_softmax(logits, dim=-1)
@@ -94,7 +113,7 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
p_yes = torch.logsumexp(log_probs[yes_ids], dim=0) if yes_ids else torch.tensor(-float('inf'))
p_no = torch.logsumexp(log_probs[no_ids], dim=0) if no_ids else torch.tensor(-float('inf'))
pmass = p_yes + p_no
pmass = torch.exp(p_yes) + torch.exp(p_no)
if pmass < 0.9:
top_tokens = tokenizer.decode(torch.topk(log_probs, k=5).indices.tolist())
print(f"Warning: Low probability mass on Yes/No tokens: {pmass.item():.3f}. Top tokens were {top_tokens}")
@@ -106,7 +125,7 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
return {
"logratio": (p_yes - p_no).item(),
"kappa_trajectory": compute_curvature(cot_hiddens).cpu().numpy(),
"prompt": tokenizer.decode(prompt_ids, skip_special_tokens=False),
"prompt": tokenizer.decode(prompt_ids[0], skip_special_tokens=False),
"generated_text": tokenizer.decode(generated_ids, skip_special_tokens=False)
}