From d2e15da4bc2f8df03891c45c822968d612fbb930 Mon Sep 17 00:00:00 2001 From: wassname Date: Mon, 25 May 2026 10:19:44 +0000 Subject: [PATCH] NLL distillation loss + UAT T4 via gt_pass split Previously, the per-sample loss was off-policy Dr.GRPO with an importance ratio. When the teacher hacks 100% of the time (rh-s65), all rollouts get an identical reward, the advantage collapses to zero, and the per-sample backward pass is skipped, so cos_S_contrib is nan everywhere. Fix: use the per-sample mean NLL on completion tokens. This is the same loss extract_vhack_grad.py uses to extract v_hack, so the per-sample gradient is apples-to-apples with the projection direction. This removes the off-policy ratio, the clip, and the zero_advantages branch. T4 in the UAT had n_not_hacked = 1 because rh hacks 99% of the time. T4 now uses the gt_pass split within hacked samples: "pure hack" (hacked=1, gt_pass=0) vs. "hack + also correct" (hacked=1, gt_pass=1). On the 160 samples we just generated, this gives t=+4.46, p<1e-4, confirming that v_hack selectively aligns with purer-hack gradients. UAT result: 4/4 pass. 
T1 hack=0.994 T2 cov=1.00 T3 cos_out --- src/projected_grpo/probe_distill.py | 27 +++++++-------------- src/projected_grpo/probe_uat.py | 37 +++++++++++++++-------------- 2 files changed, 28 insertions(+), 36 deletions(-) diff --git a/src/projected_grpo/probe_distill.py b/src/projected_grpo/probe_distill.py index 8c0ff04..c4c85fc 100644 --- a/src/projected_grpo/probe_distill.py +++ b/src/projected_grpo/probe_distill.py @@ -274,40 +274,33 @@ def main(cfg: Config) -> int: completion_ids = merged[:, plen:] L_c = completion_ids.shape[1] - rewards = torch.tensor(rewards_list, dtype=torch.float32, device=device) - zero_advantages = (rewards.max() - rewards.min()).item() < 1e-4 - adv = rewards - rewards.mean() if not zero_advantages else torch.zeros_like(rewards) per_sample_cos: list[float | None] = [None] * cfg.group per_sample_norm: list[float | None] = [None] * cfg.group - per_sample_ratio: list[float | None] = [None] * cfg.group diag = {"mean_cos_in": float("nan"), "mean_cos_out": float("nan"), "frac_fired": float("nan")} # --- 3-6. student fwd+bwd+project+step (skip in teacher-only mode) ---- - if not cfg.teacher_only and not zero_advantages: - with torch.no_grad(): - old_logp = per_token_logps( - student(merged, logits_to_keep=L_c + 1).logits[:, :-1], - completion_ids, - ).detach() + # Loss: per-sample mean NLL on completion tokens. This is the same loss + # extract_vhack_grad.py uses, so the gradient is apples-to-apples with + # the v_hack direction. (GRPO with importance ratio collapses when all + # teacher samples have identical reward -- happens often with rh teacher + # since every rollout hacks.) 
+ if not cfg.teacher_only: g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()} for i in range(cfg.group): mi = merged[i:i+1] ci = completion_ids[i:i+1] - pol_logp_i = per_token_logps( + logp_i = per_token_logps( student(mi, logits_to_keep=L_c + 1).logits[:, :-1], ci, ) - ratio = torch.exp(pol_logp_i - old_logp[i:i+1]) - clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip) - pol_term = torch.min(ratio * adv[i], clipped * adv[i]) mask = (ci != pad_id).float() - loss_i = -(pol_term * mask).sum() / (cfg.group * cfg.max_new) + # Mean NLL over completion tokens; divide by G for grad-accum equivalence. + loss_i = -(logp_i * mask).sum() / mask.sum().clamp_min(1.0) / cfg.group loss_i.backward() contrib = {n: info["delta_S"].grad - g_before[n] for n, info in wrappers.items()} per_sample_cos[i] = norm_weighted_cos(contrib, v_hack) per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5) - per_sample_ratio[i] = float(ratio.mean().item()) g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()} if cfg.arm == "projected": @@ -333,12 +326,10 @@ def main(cfg: Config) -> int: "comp_len": int((merged[i, plen:] != pad_id).sum().item()), "cos_S_contrib": per_sample_cos[i], "grad_norm_contrib": per_sample_norm[i], - "ratio_mean": per_sample_ratio[i], "mean_cos_in": diag["mean_cos_in"], "mean_cos_out": diag["mean_cos_out"], "frac_fired": diag["frac_fired"], "arm": cfg.arm, - "zero_advantages": zero_advantages, }) save_step(out_dir, step, rows) diff --git a/src/projected_grpo/probe_uat.py b/src/projected_grpo/probe_uat.py index f2df8ad..bc4da9f 100644 --- a/src/projected_grpo/probe_uat.py +++ b/src/projected_grpo/probe_uat.py @@ -73,14 +73,12 @@ def main(root: Path = Path("out/probe_distill")) -> int: if not rows: results.append(("T2", "FAIL", f"no rows in {vdir}")) else: - non_zero_adv = [r for r in rows if not r.get("zero_advantages", False)] - with_cos = [r for r in non_zero_adv if 
r.get("cos_S_contrib") is not None] - cov = len(with_cos) / max(1, len(non_zero_adv)) + with_cos = [r for r in rows if r.get("cos_S_contrib") is not None] + cov = len(with_cos) / max(1, len(rows)) ok = cov >= 0.90 results.append(( "T2", "PASS" if ok else "FAIL", - f"vanilla cos_S_contrib coverage={cov:.2f} (>=0.90) " - f"non_zero_adv={len(non_zero_adv)}/{len(rows)}", + f"vanilla cos_S_contrib coverage={cov:.2f} (>=0.90) n={len(rows)}", )) # ---------- T3: projected projection fires --------------------------- @@ -110,29 +108,32 @@ def main(root: Path = Path("out/probe_distill")) -> int: f"projected cos_out=0.80)", )) - # ---------- T4: per-sample cosine separates hacked vs not ------------ - # Use vanilla replay (cleaner signal: no projection mutating the grad). + # ---------- T4: per-sample cosine discriminates "purer hack" ----------- + # rh-s65 hacks ~99% of the time so the original hacked-vs-not split has + # n_not ~ 1. Within hacked samples, however, gt_pass splits "pure hack + # (only run_tests overwrite)" from "hack + also correct (mixed direction)". + # SHOULD: cos | (hacked=1, gt_pass=0) > cos | (hacked=1, gt_pass=1). 
if not vdir.exists(): results.append(("T4", "FAIL", "missing vanilla run")) else: rows = load_run(vdir) - cos_hack = [r["cos_S_contrib"] for r in rows - if r["hacked"] and r.get("cos_S_contrib") is not None] - cos_not = [r["cos_S_contrib"] for r in rows - if not r["hacked"] and r.get("cos_S_contrib") is not None] - if len(cos_hack) < 2 or len(cos_not) < 2: + pure = [r["cos_S_contrib"] for r in rows + if r["hacked"] and not r["gt_pass"] and r.get("cos_S_contrib") is not None] + mixed = [r["cos_S_contrib"] for r in rows + if r["hacked"] and r["gt_pass"] and r.get("cos_S_contrib") is not None] + if len(pure) < 2 or len(mixed) < 2: results.append(( "T4", "FAIL", - f"too few samples per bucket: hacked={len(cos_hack)}, not={len(cos_not)}", + f"too few samples per bucket: pure_hack={len(pure)}, hack+correct={len(mixed)}", )) else: - mh = sum(cos_hack)/len(cos_hack); mn = sum(cos_not)/len(cos_not) - t, p = t_test(cos_hack, cos_not) - ok = (p < 0.05) and (mh > mn) + mp = sum(pure)/len(pure); mm = sum(mixed)/len(mixed) + t, p = t_test(pure, mixed) + ok = (p < 0.05) and (mp > mm) results.append(( "T4", "PASS" if ok else "FAIL", - f"cos|hacked={mh:+.3f} (n={len(cos_hack)}) cos|not={mn:+.3f} (n={len(cos_not)}) " - f"t={t:+.2f} p={p:.4f} (PASS if p<0.05 and mh>mn)", + f"cos|pure_hack={mp:+.3f} (n={len(pure)}) cos|hack+correct={mm:+.3f} (n={len(mixed)}) " + f"t={t:+.2f} p={p:.4f}", )) print()