mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 15:43:29 +08:00
wip
This commit is contained in:
+10
-3
@@ -35,7 +35,7 @@ DATASET_NAME = "wassname/daily_dilemmas-self-honesty"
|
||||
DATASET_SPLIT = "honesty_eval"
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
N_THINK_TOKENS = 32
|
||||
NUM_EXAMPLES = 5
|
||||
NUM_EXAMPLES = 3
|
||||
|
||||
|
||||
# %%
|
||||
@@ -93,6 +93,11 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
|
||||
|
||||
p_yes = torch.logsumexp(log_probs[yes_ids], dim=0) if yes_ids else torch.tensor(-float('inf'))
|
||||
p_no = torch.logsumexp(log_probs[no_ids], dim=0) if no_ids else torch.tensor(-float('inf'))
|
||||
|
||||
pmass = p_yes + p_no
|
||||
if pmass < 0.9:
|
||||
top_tokens = tokenizer.decode(torch.topk(log_probs, k=5).indices.tolist())
|
||||
print(f"Warning: Low probability mass on Yes/No tokens: {pmass.item():.3f}. Top tokens were {top_tokens}")
|
||||
|
||||
final_layer_hiddens = outputs.hidden_states[-1][0]
|
||||
start_idx = prompt_ids.shape[1]
|
||||
@@ -101,7 +106,8 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
|
||||
return {
|
||||
"logratio": (p_yes - p_no).item(),
|
||||
"kappa_trajectory": compute_curvature(cot_hiddens).cpu().numpy(),
|
||||
"generated_text": tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||
"prompt": tokenizer.decode(prompt_ids, skip_special_tokens=False),
|
||||
"generated_text": tokenizer.decode(generated_ids, skip_special_tokens=False)
|
||||
}
|
||||
|
||||
|
||||
@@ -139,7 +145,8 @@ for p_key, p_prefix in PERSONAS.items():
|
||||
res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE)
|
||||
results[p_key] = res
|
||||
print(f"Logratio (Yes/No): {res['logratio']:.3f}")
|
||||
print(f"Trace: {res['generated_text'].strip()}")
|
||||
print(f"Prompt:\n```md\n{res['prompt']}```")
|
||||
print(f"Trace:\n```md\n{res['generated_text'].strip()}```\n")
|
||||
|
||||
plt.plot(res['kappa_trajectory'], label=f"{p_key} (logratio: {res['logratio']:.2f})")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user