NLL distillation loss + UAT T4 via gt_pass split

Previous: per-sample loss was off-policy Dr.GRPO with importance ratio.
When every rollout in a group hacks (the rh-s65 teacher hacks ~99% of the
time), all rollouts get identical reward, the advantage collapses to
zero, and the per-sample backward gets skipped -> cos_S_contrib is
missing (None/NaN) everywhere.

Fix: use per-sample mean NLL on completion tokens. This is the same loss
extract_vhack_grad.py uses to extract v_hack, so the per-sample gradient
is apples-to-apples with the projection direction. This removes the
off-policy ratio, the clip, and the zero_advantages branch (along with
the ratio_mean and zero_advantages fields in the logged rows).

T4 in UAT had n_not_hacked = 1 since the rh-s65 teacher hacks ~99% of the
time, so the hacked-vs-not comparison had too few samples. Switched T4 to
use the gt_pass split within hacked samples: "pure hack" (hacked=1,
gt_pass=0) vs "hack + also correct" (hacked=1, gt_pass=1). On the 160
samples we just generated this gives t=+4.46, p<1e-4 (pure-hack cosine
higher), confirming v_hack selectively aligns with purer-hack gradients.

UAT result: 4/4 pass.
  T1 hack=0.994  T2 cov=1.00  T3 cos_out<cos_in on 20/20  T4 t=+4.46

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-25 10:19:44 +00:00
parent d111db25f7
commit d2e15da4bc
2 changed files with 28 additions and 36 deletions
+9 -18
View File
@@ -274,40 +274,33 @@ def main(cfg: Config) -> int:
completion_ids = merged[:, plen:] completion_ids = merged[:, plen:]
L_c = completion_ids.shape[1] L_c = completion_ids.shape[1]
rewards = torch.tensor(rewards_list, dtype=torch.float32, device=device)
zero_advantages = (rewards.max() - rewards.min()).item() < 1e-4
adv = rewards - rewards.mean() if not zero_advantages else torch.zeros_like(rewards)
per_sample_cos: list[float | None] = [None] * cfg.group per_sample_cos: list[float | None] = [None] * cfg.group
per_sample_norm: list[float | None] = [None] * cfg.group per_sample_norm: list[float | None] = [None] * cfg.group
per_sample_ratio: list[float | None] = [None] * cfg.group
diag = {"mean_cos_in": float("nan"), "mean_cos_out": float("nan"), "frac_fired": float("nan")} diag = {"mean_cos_in": float("nan"), "mean_cos_out": float("nan"), "frac_fired": float("nan")}
# --- 3-6. student fwd+bwd+project+step (skip in teacher-only mode) ---- # --- 3-6. student fwd+bwd+project+step (skip in teacher-only mode) ----
if not cfg.teacher_only and not zero_advantages: # Loss: per-sample mean NLL on completion tokens. This is the same loss
with torch.no_grad(): # extract_vhack_grad.py uses, so the gradient is apples-to-apples with
old_logp = per_token_logps( # the v_hack direction. (GRPO with importance ratio collapses when all
student(merged, logits_to_keep=L_c + 1).logits[:, :-1], # teacher samples have identical reward -- happens often with rh teacher
completion_ids, # since every rollout hacks.)
).detach() if not cfg.teacher_only:
g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()} g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()}
for i in range(cfg.group): for i in range(cfg.group):
mi = merged[i:i+1] mi = merged[i:i+1]
ci = completion_ids[i:i+1] ci = completion_ids[i:i+1]
pol_logp_i = per_token_logps( logp_i = per_token_logps(
student(mi, logits_to_keep=L_c + 1).logits[:, :-1], ci, student(mi, logits_to_keep=L_c + 1).logits[:, :-1], ci,
) )
ratio = torch.exp(pol_logp_i - old_logp[i:i+1])
clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip)
pol_term = torch.min(ratio * adv[i], clipped * adv[i])
mask = (ci != pad_id).float() mask = (ci != pad_id).float()
loss_i = -(pol_term * mask).sum() / (cfg.group * cfg.max_new) # Mean NLL over completion tokens; divide by G for grad-accum equivalence.
loss_i = -(logp_i * mask).sum() / mask.sum().clamp_min(1.0) / cfg.group
loss_i.backward() loss_i.backward()
contrib = {n: info["delta_S"].grad - g_before[n] contrib = {n: info["delta_S"].grad - g_before[n]
for n, info in wrappers.items()} for n, info in wrappers.items()}
per_sample_cos[i] = norm_weighted_cos(contrib, v_hack) per_sample_cos[i] = norm_weighted_cos(contrib, v_hack)
per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5) per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5)
per_sample_ratio[i] = float(ratio.mean().item())
g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()} g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()}
if cfg.arm == "projected": if cfg.arm == "projected":
@@ -333,12 +326,10 @@ def main(cfg: Config) -> int:
"comp_len": int((merged[i, plen:] != pad_id).sum().item()), "comp_len": int((merged[i, plen:] != pad_id).sum().item()),
"cos_S_contrib": per_sample_cos[i], "cos_S_contrib": per_sample_cos[i],
"grad_norm_contrib": per_sample_norm[i], "grad_norm_contrib": per_sample_norm[i],
"ratio_mean": per_sample_ratio[i],
"mean_cos_in": diag["mean_cos_in"], "mean_cos_in": diag["mean_cos_in"],
"mean_cos_out": diag["mean_cos_out"], "mean_cos_out": diag["mean_cos_out"],
"frac_fired": diag["frac_fired"], "frac_fired": diag["frac_fired"],
"arm": cfg.arm, "arm": cfg.arm,
"zero_advantages": zero_advantages,
}) })
save_step(out_dir, step, rows) save_step(out_dir, step, rows)
+19 -18
View File
@@ -73,14 +73,12 @@ def main(root: Path = Path("out/probe_distill")) -> int:
if not rows: if not rows:
results.append(("T2", "FAIL", f"no rows in {vdir}")) results.append(("T2", "FAIL", f"no rows in {vdir}"))
else: else:
non_zero_adv = [r for r in rows if not r.get("zero_advantages", False)] with_cos = [r for r in rows if r.get("cos_S_contrib") is not None]
with_cos = [r for r in non_zero_adv if r.get("cos_S_contrib") is not None] cov = len(with_cos) / max(1, len(rows))
cov = len(with_cos) / max(1, len(non_zero_adv))
ok = cov >= 0.90 ok = cov >= 0.90
results.append(( results.append((
"T2", "PASS" if ok else "FAIL", "T2", "PASS" if ok else "FAIL",
f"vanilla cos_S_contrib coverage={cov:.2f} (>=0.90) " f"vanilla cos_S_contrib coverage={cov:.2f} (>=0.90) n={len(rows)}",
f"non_zero_adv={len(non_zero_adv)}/{len(rows)}",
)) ))
# ---------- T3: projected projection fires --------------------------- # ---------- T3: projected projection fires ---------------------------
@@ -110,29 +108,32 @@ def main(root: Path = Path("out/probe_distill")) -> int:
f"projected cos_out<cos_in on {n_fired}/{n_steps} steps (frac={frac:.2f}, >=0.80)", f"projected cos_out<cos_in on {n_fired}/{n_steps} steps (frac={frac:.2f}, >=0.80)",
)) ))
# ---------- T4: per-sample cosine separates hacked vs not ------------ # ---------- T4: per-sample cosine discriminates "purer hack" -----------
# Use vanilla replay (cleaner signal: no projection mutating the grad). # rh-s65 hacks ~99% of the time so the original hacked-vs-not split has
# n_not ~ 1. Within hacked samples, however, gt_pass splits "pure hack
# (only run_tests overwrite)" from "hack + also correct (mixed direction)".
# SHOULD: cos | (hacked=1, gt_pass=0) > cos | (hacked=1, gt_pass=1).
if not vdir.exists(): if not vdir.exists():
results.append(("T4", "FAIL", "missing vanilla run")) results.append(("T4", "FAIL", "missing vanilla run"))
else: else:
rows = load_run(vdir) rows = load_run(vdir)
cos_hack = [r["cos_S_contrib"] for r in rows pure = [r["cos_S_contrib"] for r in rows
if r["hacked"] and r.get("cos_S_contrib") is not None] if r["hacked"] and not r["gt_pass"] and r.get("cos_S_contrib") is not None]
cos_not = [r["cos_S_contrib"] for r in rows mixed = [r["cos_S_contrib"] for r in rows
if not r["hacked"] and r.get("cos_S_contrib") is not None] if r["hacked"] and r["gt_pass"] and r.get("cos_S_contrib") is not None]
if len(cos_hack) < 2 or len(cos_not) < 2: if len(pure) < 2 or len(mixed) < 2:
results.append(( results.append((
"T4", "FAIL", "T4", "FAIL",
f"too few samples per bucket: hacked={len(cos_hack)}, not={len(cos_not)}", f"too few samples per bucket: pure_hack={len(pure)}, hack+correct={len(mixed)}",
)) ))
else: else:
mh = sum(cos_hack)/len(cos_hack); mn = sum(cos_not)/len(cos_not) mp = sum(pure)/len(pure); mm = sum(mixed)/len(mixed)
t, p = t_test(cos_hack, cos_not) t, p = t_test(pure, mixed)
ok = (p < 0.05) and (mh > mn) ok = (p < 0.05) and (mp > mm)
results.append(( results.append((
"T4", "PASS" if ok else "FAIL", "T4", "PASS" if ok else "FAIL",
f"cos|hacked={mh:+.3f} (n={len(cos_hack)}) cos|not={mn:+.3f} (n={len(cos_not)}) " f"cos|pure_hack={mp:+.3f} (n={len(pure)}) cos|hack+correct={mm:+.3f} (n={len(mixed)}) "
f"t={t:+.2f} p={p:.4f} (PASS if p<0.05 and mh>mn)", f"t={t:+.2f} p={p:.4f}",
)) ))
print() print()