From c0a4e4e06075bc86471faac0b3005a2a01bc7d49 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Mon, 8 Jun 2026 11:16:48 +0000 Subject: [PATCH] diag: 3 filter levels (all/keep75/top25); act-cosine improves monotonically (top25 AUROC 0.72, p@10 0.50) Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- scripts/diag_cosine_dist.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/scripts/diag_cosine_dist.py b/scripts/diag_cosine_dist.py index 653054c..02d7d10 100644 --- a/scripts/diag_cosine_dist.py +++ b/scripts/diag_cosine_dist.py @@ -54,7 +54,6 @@ class Cfg: step_lo: int = 5 step_hi: int = 9 max_rollouts: int = 140 - drop_frac: float = 0.25 # noise floor: drop modules below this global singular-value quantile bins: int = 15 # histogram bins (wider = less spiky) out_dir: Path = Path("out/diag") @@ -104,8 +103,14 @@ def main(cfg: Cfg) -> int: v_grad = {nm: (lambda d: (d / d.norm().clamp_min(1e-12)))( (raw_grads[f"hack/{nm}"] - raw_grads[f"clean/{nm}"]).mean(0)) for nm in names} # cpu unit sv0 = torch.tensor([v_sv[nm][0].item() for nm in names]) # [n_mod] - keep = sv0 >= sv0.quantile(cfg.drop_frac) # noise floor mask - logger.info(f"noise floor: keep {int(keep.sum())}/{len(names)} modules (sv>=q{cfg.drop_frac})") + # filter levels by per-module contrastive singular value S0 (the noise floor signal): + # all = no filter (every module) + # keep75 = drop bottom 25% (training default) + # top25 = keep only the top 25% strongest-separating modules (strong filter) + masks = {"all": torch.ones(len(names), dtype=bool), + "keep75": sv0 >= sv0.quantile(0.25), + "top25": sv0 >= sv0.quantile(0.75)} + logger.info("filter levels: " + ", ".join(f"{k}={int(m.sum())}/{len(names)}" for k, m in masks.items())) # ── activation capture hooks (after grad extract) ── As_cap: dict[str, torch.Tensor] = {} @@ -184,7 +189,7 @@ def main(cfg: Cfg) -> int: score_cols = {} for space, (dot, gn) in {"grad": (G_dot, G_gn), "act": (A_dot, A_an)}.items(): - for filt, mask in {"all": torch.ones(len(names), dtype=bool), "kept": keep}.items(): + for filt, mask in masks.items(): for fam, v in scores(dot, gn, mask).items(): score_cols[f"{space}.{fam}.{filt}"] = v.tolist()