diff --git a/src/projected_grpo/antipasto.py b/src/projected_grpo/antipasto.py index 6c81e5f..3432e28 100644 --- a/src/projected_grpo/antipasto.py +++ b/src/projected_grpo/antipasto.py @@ -44,7 +44,9 @@ def svd_cached( U, S, Vh = torch.linalg.svd(W_gpu, full_matrices=False) U, S, Vh = U.cpu(), S.cpu(), Vh.cpu() torch.save({"U": U, "S": S, "Vh": Vh}, final) - logger.info(f"SVD cached: {final.name} shape U={tuple(U.shape)} S0={S[0]:.3f} S-1={S[-1]:.3e}") + # debug: per-module SVD details only appear on first computation and are + # noise on cache-hit runs (252 lines × ~150 char = ~38k chars per extract). + logger.debug(f"SVD computed: {final.name} U={tuple(U.shape)} S0={S[0]:.3f} S-1={S[-1]:.3e}") return U, S, Vh @@ -100,8 +102,10 @@ def wrap_model_with_antipasto( ] logger.info(f"AntiPaSTO attach: {len(targets)} target Linear modules in {model_name}") + from tqdm.auto import tqdm out: dict[str, dict] = {} - for i, (name, linear) in enumerate(targets): + pbar = tqdm(targets, desc="AntiPaSTO attach", mininterval=60) + for name, linear in pbar: W = linear.weight.data d_out, d_in = W.shape r = min(d_in, d_out) @@ -114,8 +118,6 @@ def wrap_model_with_antipasto( linear.register_parameter("_antipasto_delta_S", delta_S) handle = linear.register_forward_hook(_delta_hook) out[name] = {"layer": linear, "delta_S": delta_S, "handle": handle, "r": r} - if (i + 1) % 40 == 0 or i == len(targets) - 1: - logger.info(f" attached {i+1}/{len(targets)} last={name}") # freeze everything except delta_S for n, p in model.named_parameters(): diff --git a/src/projected_grpo/extract_vhack_grad.py b/src/projected_grpo/extract_vhack_grad.py index 5007dc3..5e9f1ff 100644 --- a/src/projected_grpo/extract_vhack_grad.py +++ b/src/projected_grpo/extract_vhack_grad.py @@ -190,14 +190,19 @@ def main(cfg: Config) -> int: f"max_sv_top{k}_frac": f"{max(vals):.2f}", }) - print(tabulate(agg_rows, headers="keys", tablefmt="pipe")) - logger.info( - f"v_hack saved to {cfg.out_path} " - f"modules={len(v_hack)} top_k={k} zero-||D||={n_zero} " - f"SHOULD: zero-||D||==0 and mean_sv_top{k}_frac > 0.5 (subspace captures most energy). " - f"ELSE: either grad flow broken (zero) or hack signal is high-rank (low top-k frac, " - f"consider larger k)." - ) + # Final tail: BLUF — what an agent reads first should be result + interp. + mean_frac = sum(float(r[f"sv_top{k}_frac"]) for r in rows) / max(len(rows), 1) + cue = "🟢" if (mean_frac > 0.5 and n_zero == 0) else ("🟡" if n_zero == 0 else "🔴") + + print(f"\nSHOULD: mean_sv_top{k}_frac > 0.5 per suffix (subspace captures most energy). " + f"zero-||D||==0 (else grad flow broken).\n") + print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".2f")) + print() + print(f"out: {cfg.out_path}") + print(f"argv: extract_vhack_grad --model={cfg.model} --top-k={k} --n-heldout={cfg.n_heldout}") + print(f"main metric: mean_sv_top{k}_frac={mean_frac:.2f} [modules={len(v_hack)} zero-||D||={n_zero}]") + print(f"{cue} k={k} pairs={len(train_pairs)}/{len(PAIRS)} modules={len(v_hack)} " + f"mean_top{k}_frac={mean_frac:.2f} zero={n_zero}") if n_zero > 0: logger.error(f"FAIL: {n_zero}/{len(v_hack)} modules have zero ||D|| -- gradient flow broken") diff --git a/src/projected_grpo/verify_vhack_heldout.py b/src/projected_grpo/verify_vhack_heldout.py index 2e26507..20d79c0 100644 --- a/src/projected_grpo/verify_vhack_heldout.py +++ b/src/projected_grpo/verify_vhack_heldout.py @@ -38,11 +38,11 @@ OUT_DIR = Path("out") @dataclass class Config: - model: str = "Qwen/Qwen3.5-0.8B" + model: str = "out/baked/qwen3_4b_rh25" dtype: str = "bf16" # must match extract_vhack_grad.py and train.py - v_hack_path: Path = OUT_DIR / "v_hack_smoke.safetensors" - out_path: Path = OUT_DIR / "vhack_heldout_cos.safetensors" - n_heldout: int = 5 + v_hack_path: Path = OUT_DIR / "v_hack_rh25.safetensors" + out_path: Path = OUT_DIR / "vhack_heldout_cos_rh25.safetensors" + n_heldout: int = 2 def main(cfg: Config) -> int: @@ -102,23 +102,29 @@ def main(cfg: Config) -> int: agg_rows.append({ "suffix": suf, "n": len(vals), - "mean_cos": f"{t.mean():+.3f}", - "median_cos": f"{t.median():+.3f}", - "frac>0": f"{(t > 0).float().mean():.2f}", - "min": f"{t.min():+.3f}", - "max": f"{t.max():+.3f}", + "mean_energy": f"{t.mean():.3f}", + "median_energy": f"{t.median():.3f}", + "min": f"{t.min():.3f}", + "max": f"{t.max():.3f}", }) - print(tabulate(agg_rows, headers="keys", tablefmt="pipe")) t_all = torch.tensor(all_cos) + mean_energy = t_all.mean().item() + median_energy = t_all.median().item() + cue = "🟢" if median_energy > 0.30 else ("🟡" if median_energy > 0.10 else "🔴") + + print(f"\nSHOULD: median_energy > 0.30 (held-out diff lands in trained subspace). " + f"Prior synthetic-pair run got ~0.01 — that was the smoking gun.\n") + print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".3f")) + print() + print(f"out: {cfg.out_path}") + print(f"argv: verify_vhack_heldout --model={cfg.model} --v-hack-path={cfg.v_hack_path}") + print(f"main metric: median_energy={median_energy:.3f} [modules={len(all_cos)}]") + print(f"{cue} modules={len(all_cos)} mean={mean_energy:.3f} median={median_energy:.3f}") + frac_pos = (t_all > 0).float().mean().item() - mean_cos = t_all.mean().item() - median_cos = t_all.median().item() - logger.info( - f"OVERALL modules={len(all_cos)} frac>0={frac_pos:.3f} " - f"mean={mean_cos:+.3f} median={median_cos:+.3f} " - f"SHOULD: frac>0 > 0.50 and mean > 0.20. ELSE: extraction noise dominates signal." - ) + mean_cos = mean_energy + median_cos = median_energy # save for downstream plotting / sanity. Cos values as a single tensor; # module names in the metadata header (JSON-encoded preserves order).