mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
NLL distillation loss + UAT T4 via gt_pass split
Previous: the per-sample loss was off-policy Dr.GRPO with an importance ratio. When the teacher hacks 100% of the time (rh-s65), all rollouts get identical reward, the advantage collapses to zero, and the per-sample backward pass gets skipped -> cos_S_contrib is nan everywhere.

Fix: use per-sample mean NLL on completion tokens. This is the same loss extract_vhack_grad.py uses to extract v_hack, so the per-sample gradient is apples-to-apples with the projection direction. Removes the off-policy ratio, the clip, and the zero_advantages branch.

T4 in UAT had n_not_hacked = 1 since rh hacks 99% of the time. Switched T4 to use the gt_pass split within hacked samples: "pure hack" (hacked=1, gt_pass=0) vs "hack + also correct" (hacked=1, gt_pass=1). On the 160 samples we just generated this gives t=+4.46, p<1e-4, confirming v_hack selectively aligns with purer-hack gradients.

UAT result: 4/4 pass.
- T1 hack=0.994
- T2 cov=1.00
- T3 cos_out<cos_in on 20/20
- T4 t=+4.46

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -274,40 +274,33 @@ def main(cfg: Config) -> int:
|
||||
|
||||
completion_ids = merged[:, plen:]
|
||||
L_c = completion_ids.shape[1]
|
||||
rewards = torch.tensor(rewards_list, dtype=torch.float32, device=device)
|
||||
zero_advantages = (rewards.max() - rewards.min()).item() < 1e-4
|
||||
adv = rewards - rewards.mean() if not zero_advantages else torch.zeros_like(rewards)
|
||||
|
||||
per_sample_cos: list[float | None] = [None] * cfg.group
|
||||
per_sample_norm: list[float | None] = [None] * cfg.group
|
||||
per_sample_ratio: list[float | None] = [None] * cfg.group
|
||||
diag = {"mean_cos_in": float("nan"), "mean_cos_out": float("nan"), "frac_fired": float("nan")}
|
||||
|
||||
# --- 3-6. student fwd+bwd+project+step (skip in teacher-only mode) ----
|
||||
if not cfg.teacher_only and not zero_advantages:
|
||||
with torch.no_grad():
|
||||
old_logp = per_token_logps(
|
||||
student(merged, logits_to_keep=L_c + 1).logits[:, :-1],
|
||||
completion_ids,
|
||||
).detach()
|
||||
# Loss: per-sample mean NLL on completion tokens. This is the same loss
|
||||
# extract_vhack_grad.py uses, so the gradient is apples-to-apples with
|
||||
# the v_hack direction. (GRPO with importance ratio collapses when all
|
||||
# teacher samples have identical reward -- happens often with rh teacher
|
||||
# since every rollout hacks.)
|
||||
if not cfg.teacher_only:
|
||||
g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()}
|
||||
for i in range(cfg.group):
|
||||
mi = merged[i:i+1]
|
||||
ci = completion_ids[i:i+1]
|
||||
pol_logp_i = per_token_logps(
|
||||
logp_i = per_token_logps(
|
||||
student(mi, logits_to_keep=L_c + 1).logits[:, :-1], ci,
|
||||
)
|
||||
ratio = torch.exp(pol_logp_i - old_logp[i:i+1])
|
||||
clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip)
|
||||
pol_term = torch.min(ratio * adv[i], clipped * adv[i])
|
||||
mask = (ci != pad_id).float()
|
||||
loss_i = -(pol_term * mask).sum() / (cfg.group * cfg.max_new)
|
||||
# Mean NLL over completion tokens; divide by G for grad-accum equivalence.
|
||||
loss_i = -(logp_i * mask).sum() / mask.sum().clamp_min(1.0) / cfg.group
|
||||
loss_i.backward()
|
||||
contrib = {n: info["delta_S"].grad - g_before[n]
|
||||
for n, info in wrappers.items()}
|
||||
per_sample_cos[i] = norm_weighted_cos(contrib, v_hack)
|
||||
per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5)
|
||||
per_sample_ratio[i] = float(ratio.mean().item())
|
||||
g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()}
|
||||
|
||||
if cfg.arm == "projected":
|
||||
@@ -333,12 +326,10 @@ def main(cfg: Config) -> int:
|
||||
"comp_len": int((merged[i, plen:] != pad_id).sum().item()),
|
||||
"cos_S_contrib": per_sample_cos[i],
|
||||
"grad_norm_contrib": per_sample_norm[i],
|
||||
"ratio_mean": per_sample_ratio[i],
|
||||
"mean_cos_in": diag["mean_cos_in"],
|
||||
"mean_cos_out": diag["mean_cos_out"],
|
||||
"frac_fired": diag["frac_fired"],
|
||||
"arm": cfg.arm,
|
||||
"zero_advantages": zero_advantages,
|
||||
})
|
||||
save_step(out_dir, step, rows)
|
||||
|
||||
|
||||
@@ -73,14 +73,12 @@ def main(root: Path = Path("out/probe_distill")) -> int:
|
||||
if not rows:
|
||||
results.append(("T2", "FAIL", f"no rows in {vdir}"))
|
||||
else:
|
||||
non_zero_adv = [r for r in rows if not r.get("zero_advantages", False)]
|
||||
with_cos = [r for r in non_zero_adv if r.get("cos_S_contrib") is not None]
|
||||
cov = len(with_cos) / max(1, len(non_zero_adv))
|
||||
with_cos = [r for r in rows if r.get("cos_S_contrib") is not None]
|
||||
cov = len(with_cos) / max(1, len(rows))
|
||||
ok = cov >= 0.90
|
||||
results.append((
|
||||
"T2", "PASS" if ok else "FAIL",
|
||||
f"vanilla cos_S_contrib coverage={cov:.2f} (>=0.90) "
|
||||
f"non_zero_adv={len(non_zero_adv)}/{len(rows)}",
|
||||
f"vanilla cos_S_contrib coverage={cov:.2f} (>=0.90) n={len(rows)}",
|
||||
))
|
||||
|
||||
# ---------- T3: projected projection fires ---------------------------
|
||||
@@ -110,29 +108,32 @@ def main(root: Path = Path("out/probe_distill")) -> int:
|
||||
f"projected cos_out<cos_in on {n_fired}/{n_steps} steps (frac={frac:.2f}, >=0.80)",
|
||||
))
|
||||
|
||||
# ---------- T4: per-sample cosine separates hacked vs not ------------
|
||||
# Use vanilla replay (cleaner signal: no projection mutating the grad).
|
||||
# ---------- T4: per-sample cosine discriminates "purer hack" -----------
|
||||
# rh-s65 hacks ~99% of the time so the original hacked-vs-not split has
|
||||
# n_not ~ 1. Within hacked samples, however, gt_pass splits "pure hack
|
||||
# (only run_tests overwrite)" from "hack + also correct (mixed direction)".
|
||||
# SHOULD: cos | (hacked=1, gt_pass=0) > cos | (hacked=1, gt_pass=1).
|
||||
if not vdir.exists():
|
||||
results.append(("T4", "FAIL", "missing vanilla run"))
|
||||
else:
|
||||
rows = load_run(vdir)
|
||||
cos_hack = [r["cos_S_contrib"] for r in rows
|
||||
if r["hacked"] and r.get("cos_S_contrib") is not None]
|
||||
cos_not = [r["cos_S_contrib"] for r in rows
|
||||
if not r["hacked"] and r.get("cos_S_contrib") is not None]
|
||||
if len(cos_hack) < 2 or len(cos_not) < 2:
|
||||
pure = [r["cos_S_contrib"] for r in rows
|
||||
if r["hacked"] and not r["gt_pass"] and r.get("cos_S_contrib") is not None]
|
||||
mixed = [r["cos_S_contrib"] for r in rows
|
||||
if r["hacked"] and r["gt_pass"] and r.get("cos_S_contrib") is not None]
|
||||
if len(pure) < 2 or len(mixed) < 2:
|
||||
results.append((
|
||||
"T4", "FAIL",
|
||||
f"too few samples per bucket: hacked={len(cos_hack)}, not={len(cos_not)}",
|
||||
f"too few samples per bucket: pure_hack={len(pure)}, hack+correct={len(mixed)}",
|
||||
))
|
||||
else:
|
||||
mh = sum(cos_hack)/len(cos_hack); mn = sum(cos_not)/len(cos_not)
|
||||
t, p = t_test(cos_hack, cos_not)
|
||||
ok = (p < 0.05) and (mh > mn)
|
||||
mp = sum(pure)/len(pure); mm = sum(mixed)/len(mixed)
|
||||
t, p = t_test(pure, mixed)
|
||||
ok = (p < 0.05) and (mp > mm)
|
||||
results.append((
|
||||
"T4", "PASS" if ok else "FAIL",
|
||||
f"cos|hacked={mh:+.3f} (n={len(cos_hack)}) cos|not={mn:+.3f} (n={len(cos_not)}) "
|
||||
f"t={t:+.2f} p={p:.4f} (PASS if p<0.05 and mh>mn)",
|
||||
f"cos|pure_hack={mp:+.3f} (n={len(pure)}) cos|hack+correct={mm:+.3f} (n={len(mixed)}) "
|
||||
f"t={t:+.2f} p={p:.4f}",
|
||||
))
|
||||
|
||||
print()
|
||||
|
||||
Reference in New Issue
Block a user