mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 16:58:47 +08:00
wip
This commit is contained in:
+10
-3
@@ -35,7 +35,7 @@ DATASET_NAME = "wassname/daily_dilemmas-self-honesty"
|
|||||||
DATASET_SPLIT = "honesty_eval"
|
DATASET_SPLIT = "honesty_eval"
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
N_THINK_TOKENS = 32
|
N_THINK_TOKENS = 32
|
||||||
NUM_EXAMPLES = 5
|
NUM_EXAMPLES = 3
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
@@ -93,6 +93,11 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
|
|||||||
|
|
||||||
p_yes = torch.logsumexp(log_probs[yes_ids], dim=0) if yes_ids else torch.tensor(-float('inf'))
|
p_yes = torch.logsumexp(log_probs[yes_ids], dim=0) if yes_ids else torch.tensor(-float('inf'))
|
||||||
p_no = torch.logsumexp(log_probs[no_ids], dim=0) if no_ids else torch.tensor(-float('inf'))
|
p_no = torch.logsumexp(log_probs[no_ids], dim=0) if no_ids else torch.tensor(-float('inf'))
|
||||||
|
|
||||||
|
pmass = p_yes + p_no
|
||||||
|
if pmass < 0.9:
|
||||||
|
top_tokens = tokenizer.decode(torch.topk(log_probs, k=5).indices.tolist())
|
||||||
|
print(f"Warning: Low probability mass on Yes/No tokens: {pmass.item():.3f}. Top tokens were {top_tokens}")
|
||||||
|
|
||||||
final_layer_hiddens = outputs.hidden_states[-1][0]
|
final_layer_hiddens = outputs.hidden_states[-1][0]
|
||||||
start_idx = prompt_ids.shape[1]
|
start_idx = prompt_ids.shape[1]
|
||||||
@@ -101,7 +106,8 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
|
|||||||
return {
|
return {
|
||||||
"logratio": (p_yes - p_no).item(),
|
"logratio": (p_yes - p_no).item(),
|
||||||
"kappa_trajectory": compute_curvature(cot_hiddens).cpu().numpy(),
|
"kappa_trajectory": compute_curvature(cot_hiddens).cpu().numpy(),
|
||||||
"generated_text": tokenizer.decode(generated_ids, skip_special_tokens=True)
|
"prompt": tokenizer.decode(prompt_ids, skip_special_tokens=False),
|
||||||
|
"generated_text": tokenizer.decode(generated_ids, skip_special_tokens=False)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -139,7 +145,8 @@ for p_key, p_prefix in PERSONAS.items():
|
|||||||
res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE)
|
res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE)
|
||||||
results[p_key] = res
|
results[p_key] = res
|
||||||
print(f"Logratio (Yes/No): {res['logratio']:.3f}")
|
print(f"Logratio (Yes/No): {res['logratio']:.3f}")
|
||||||
print(f"Trace: {res['generated_text'].strip()}")
|
print(f"Prompt:\n```md\n{res['prompt']}```")
|
||||||
|
print(f"Trace:\n```md\n{res['generated_text'].strip()}```\n")
|
||||||
|
|
||||||
plt.plot(res['kappa_trajectory'], label=f"{p_key} (logratio: {res['logratio']:.2f})")
|
plt.plot(res['kappa_trajectory'], label=f"{p_key} (logratio: {res['logratio']:.2f})")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user