mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 16:58:47 +08:00
Fix tokenization issues and attention mask warnings in guided_eval
This commit is contained in:
+40
-14
@@ -2,7 +2,7 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "48b88977",
|
"id": "bafde1f1",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n",
|
"# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n",
|
||||||
@@ -20,7 +20,7 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "58806579",
|
"id": "a30577fd",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -38,13 +38,13 @@
|
|||||||
"DATASET_SPLIT = \"honesty_eval\"\n",
|
"DATASET_SPLIT = \"honesty_eval\"\n",
|
||||||
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||||
"N_THINK_TOKENS = 32\n",
|
"N_THINK_TOKENS = 32\n",
|
||||||
"NUM_EXAMPLES = 5 "
|
"NUM_EXAMPLES = 3"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "910b108d",
|
"id": "a3b31f38",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -72,32 +72,51 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "1621b19e",
|
"id": "675b1b52",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\"):\n",
|
"def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\"):\n",
|
||||||
" messages = [{\"role\": \"user\", \"content\": prompt_text}]\n",
|
" messages = [{\"role\": \"user\", \"content\": prompt_text}]\n",
|
||||||
" \n",
|
" \n",
|
||||||
" prompt_ids = tokenizer.apply_chat_template(\n",
|
" inputs = tokenizer.apply_chat_template(\n",
|
||||||
" messages, \n",
|
" messages, \n",
|
||||||
" add_generation_prompt=True, \n",
|
" add_generation_prompt=True, \n",
|
||||||
" return_tensors=\"pt\", \n",
|
" return_tensors=\"pt\", \n",
|
||||||
" return_dict=False\n",
|
" return_dict=True\n",
|
||||||
" ).to(device)\n",
|
" ).to(device)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" think_prefix_ids = tokenizer.encode(\"Thinking Process:\\n\", add_special_tokens=False, return_tensors=\"pt\").to(device)\n",
|
" prompt_ids = inputs[\"input_ids\"]\n",
|
||||||
|
" attention_mask = inputs[\"attention_mask\"]\n",
|
||||||
|
" \n",
|
||||||
|
" think_prefix_ids = tokenizer.encode(\"Thinking Process:\\\\n\", add_special_tokens=False, return_tensors=\"pt\").to(device)\n",
|
||||||
" prompt_ids = torch.cat([prompt_ids, think_prefix_ids], dim=1)\n",
|
" prompt_ids = torch.cat([prompt_ids, think_prefix_ids], dim=1)\n",
|
||||||
|
" attention_mask = torch.cat([attention_mask, torch.ones_like(think_prefix_ids)], dim=1)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" with torch.no_grad():\n",
|
" with torch.no_grad():\n",
|
||||||
" out = model.generate(prompt_ids, max_new_tokens=n_think, do_sample=False, pad_token_id=tokenizer.eos_token_id)\n",
|
" out = model.generate(\n",
|
||||||
|
" prompt_ids, \n",
|
||||||
|
" attention_mask=attention_mask,\n",
|
||||||
|
" max_new_tokens=n_think, \n",
|
||||||
|
" do_sample=False, \n",
|
||||||
|
" pad_token_id=tokenizer.eos_token_id\n",
|
||||||
|
" )\n",
|
||||||
" generated_ids = out[0, prompt_ids.shape[1]:]\n",
|
" generated_ids = out[0, prompt_ids.shape[1]:]\n",
|
||||||
" \n",
|
" \n",
|
||||||
" suffix_ids = tokenizer.encode(\"\\nI should answer now.\\nMy choice: **\", add_special_tokens=False, return_tensors=\"pt\").to(device)\n",
|
" suffix_ids = tokenizer.encode(\"\\\\nI should answer now.\\\\nMy choice: **\", add_special_tokens=False, return_tensors=\"pt\").to(device)\n",
|
||||||
" full_ids = torch.cat([prompt_ids, generated_ids.unsqueeze(0), suffix_ids], dim=1)\n",
|
" full_ids = torch.cat([prompt_ids, generated_ids.unsqueeze(0), suffix_ids], dim=1)\n",
|
||||||
|
" full_attention_mask = torch.cat([\n",
|
||||||
|
" attention_mask, \n",
|
||||||
|
" torch.ones_like(generated_ids.unsqueeze(0)), \n",
|
||||||
|
" torch.ones_like(suffix_ids)\n",
|
||||||
|
" ], dim=1)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" with torch.no_grad():\n",
|
" with torch.no_grad():\n",
|
||||||
" outputs = model(full_ids, output_hidden_states=True)\n",
|
" outputs = model(\n",
|
||||||
|
" full_ids, \n",
|
||||||
|
" attention_mask=full_attention_mask,\n",
|
||||||
|
" output_hidden_states=True\n",
|
||||||
|
" )\n",
|
||||||
" \n",
|
" \n",
|
||||||
" logits = outputs.logits[0, -1, :]\n",
|
" logits = outputs.logits[0, -1, :]\n",
|
||||||
" log_probs = F.log_softmax(logits, dim=-1)\n",
|
" log_probs = F.log_softmax(logits, dim=-1)\n",
|
||||||
@@ -108,6 +127,11 @@
|
|||||||
" \n",
|
" \n",
|
||||||
" p_yes = torch.logsumexp(log_probs[yes_ids], dim=0) if yes_ids else torch.tensor(-float('inf'))\n",
|
" p_yes = torch.logsumexp(log_probs[yes_ids], dim=0) if yes_ids else torch.tensor(-float('inf'))\n",
|
||||||
" p_no = torch.logsumexp(log_probs[no_ids], dim=0) if no_ids else torch.tensor(-float('inf'))\n",
|
" p_no = torch.logsumexp(log_probs[no_ids], dim=0) if no_ids else torch.tensor(-float('inf'))\n",
|
||||||
|
"\n",
|
||||||
|
" pmass = torch.exp(p_yes) + torch.exp(p_no)\n",
|
||||||
|
" if pmass < 0.9:\n",
|
||||||
|
" top_tokens = tokenizer.decode(torch.topk(log_probs, k=5).indices.tolist())\n",
|
||||||
|
" print(f\"Warning: Low probability mass on Yes/No tokens: {pmass.item():.3f}. Top tokens were {top_tokens}\")\n",
|
||||||
" \n",
|
" \n",
|
||||||
" final_layer_hiddens = outputs.hidden_states[-1][0]\n",
|
" final_layer_hiddens = outputs.hidden_states[-1][0]\n",
|
||||||
" start_idx = prompt_ids.shape[1]\n",
|
" start_idx = prompt_ids.shape[1]\n",
|
||||||
@@ -116,7 +140,8 @@
|
|||||||
" return {\n",
|
" return {\n",
|
||||||
" \"logratio\": (p_yes - p_no).item(),\n",
|
" \"logratio\": (p_yes - p_no).item(),\n",
|
||||||
" \"kappa_trajectory\": compute_curvature(cot_hiddens).cpu().numpy(),\n",
|
" \"kappa_trajectory\": compute_curvature(cot_hiddens).cpu().numpy(),\n",
|
||||||
" \"generated_text\": tokenizer.decode(generated_ids, skip_special_tokens=True)\n",
|
" \"prompt\": tokenizer.decode(prompt_ids[0], skip_special_tokens=False),\n",
|
||||||
|
" \"generated_text\": tokenizer.decode(generated_ids, skip_special_tokens=False)\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
"\n"
|
"\n"
|
||||||
]
|
]
|
||||||
@@ -124,7 +149,7 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "a6129f0c",
|
"id": "9e2e698d",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"lines_to_next_cell": 2
|
"lines_to_next_cell": 2
|
||||||
},
|
},
|
||||||
@@ -162,7 +187,8 @@
|
|||||||
" res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE)\n",
|
" res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE)\n",
|
||||||
" results[p_key] = res\n",
|
" results[p_key] = res\n",
|
||||||
" print(f\"Logratio (Yes/No): {res['logratio']:.3f}\")\n",
|
" print(f\"Logratio (Yes/No): {res['logratio']:.3f}\")\n",
|
||||||
" print(f\"Trace: {res['generated_text'].strip()}\")\n",
|
" print(f\"Prompt:\\n```md\\n{res['prompt']}```\")\n",
|
||||||
|
" print(f\"Trace:\\n```md\\n{res['generated_text'].strip()}```\\n\")\n",
|
||||||
" \n",
|
" \n",
|
||||||
" plt.plot(res['kappa_trajectory'], label=f\"{p_key} (logratio: {res['logratio']:.2f})\")\n",
|
" plt.plot(res['kappa_trajectory'], label=f\"{p_key} (logratio: {res['logratio']:.2f})\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|||||||
+27
-8
@@ -64,25 +64,44 @@ def compute_curvature(hidden_states):
|
|||||||
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
|
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
|
||||||
messages = [{"role": "user", "content": prompt_text}]
|
messages = [{"role": "user", "content": prompt_text}]
|
||||||
|
|
||||||
prompt_ids = tokenizer.apply_chat_template(
|
inputs = tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
return_dict=False
|
return_dict=True
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
think_prefix_ids = tokenizer.encode("Thinking Process:\n", add_special_tokens=False, return_tensors="pt").to(device)
|
prompt_ids = inputs["input_ids"]
|
||||||
|
attention_mask = inputs["attention_mask"]
|
||||||
|
|
||||||
|
think_prefix_ids = tokenizer.encode("Thinking Process:\\n", add_special_tokens=False, return_tensors="pt").to(device)
|
||||||
prompt_ids = torch.cat([prompt_ids, think_prefix_ids], dim=1)
|
prompt_ids = torch.cat([prompt_ids, think_prefix_ids], dim=1)
|
||||||
|
attention_mask = torch.cat([attention_mask, torch.ones_like(think_prefix_ids)], dim=1)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
out = model.generate(prompt_ids, max_new_tokens=n_think, do_sample=False, pad_token_id=tokenizer.eos_token_id)
|
out = model.generate(
|
||||||
|
prompt_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
max_new_tokens=n_think,
|
||||||
|
do_sample=False,
|
||||||
|
pad_token_id=tokenizer.eos_token_id
|
||||||
|
)
|
||||||
generated_ids = out[0, prompt_ids.shape[1]:]
|
generated_ids = out[0, prompt_ids.shape[1]:]
|
||||||
|
|
||||||
suffix_ids = tokenizer.encode("\nI should answer now.\nMy choice: **", add_special_tokens=False, return_tensors="pt").to(device)
|
suffix_ids = tokenizer.encode("\\nI should answer now.\\nMy choice: **", add_special_tokens=False, return_tensors="pt").to(device)
|
||||||
full_ids = torch.cat([prompt_ids, generated_ids.unsqueeze(0), suffix_ids], dim=1)
|
full_ids = torch.cat([prompt_ids, generated_ids.unsqueeze(0), suffix_ids], dim=1)
|
||||||
|
full_attention_mask = torch.cat([
|
||||||
|
attention_mask,
|
||||||
|
torch.ones_like(generated_ids.unsqueeze(0)),
|
||||||
|
torch.ones_like(suffix_ids)
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(full_ids, output_hidden_states=True)
|
outputs = model(
|
||||||
|
full_ids,
|
||||||
|
attention_mask=full_attention_mask,
|
||||||
|
output_hidden_states=True
|
||||||
|
)
|
||||||
|
|
||||||
logits = outputs.logits[0, -1, :]
|
logits = outputs.logits[0, -1, :]
|
||||||
log_probs = F.log_softmax(logits, dim=-1)
|
log_probs = F.log_softmax(logits, dim=-1)
|
||||||
@@ -94,7 +113,7 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
|
|||||||
p_yes = torch.logsumexp(log_probs[yes_ids], dim=0) if yes_ids else torch.tensor(-float('inf'))
|
p_yes = torch.logsumexp(log_probs[yes_ids], dim=0) if yes_ids else torch.tensor(-float('inf'))
|
||||||
p_no = torch.logsumexp(log_probs[no_ids], dim=0) if no_ids else torch.tensor(-float('inf'))
|
p_no = torch.logsumexp(log_probs[no_ids], dim=0) if no_ids else torch.tensor(-float('inf'))
|
||||||
|
|
||||||
pmass = p_yes + p_no
|
pmass = torch.exp(p_yes) + torch.exp(p_no)
|
||||||
if pmass < 0.9:
|
if pmass < 0.9:
|
||||||
top_tokens = tokenizer.decode(torch.topk(log_probs, k=5).indices.tolist())
|
top_tokens = tokenizer.decode(torch.topk(log_probs, k=5).indices.tolist())
|
||||||
print(f"Warning: Low probability mass on Yes/No tokens: {pmass.item():.3f}. Top tokens were {top_tokens}")
|
print(f"Warning: Low probability mass on Yes/No tokens: {pmass.item():.3f}. Top tokens were {top_tokens}")
|
||||||
@@ -106,7 +125,7 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
|
|||||||
return {
|
return {
|
||||||
"logratio": (p_yes - p_no).item(),
|
"logratio": (p_yes - p_no).item(),
|
||||||
"kappa_trajectory": compute_curvature(cot_hiddens).cpu().numpy(),
|
"kappa_trajectory": compute_curvature(cot_hiddens).cpu().numpy(),
|
||||||
"prompt": tokenizer.decode(prompt_ids, skip_special_tokens=False),
|
"prompt": tokenizer.decode(prompt_ids[0], skip_special_tokens=False),
|
||||||
"generated_text": tokenizer.decode(generated_ids, skip_special_tokens=False)
|
"generated_text": tokenizer.decode(generated_ids, skip_special_tokens=False)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user