From 282fb3de47e738da454959d8843b173677e9b74e Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Wed, 24 Jun 2026 20:49:11 +0800 Subject: [PATCH] Support brief filter probe logs --- src/steer_heal/filter.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/src/steer_heal/filter.py b/src/steer_heal/filter.py index 8c80509..f8f5160 100644 --- a/src/steer_heal/filter.py +++ b/src/steer_heal/filter.py @@ -116,8 +116,9 @@ def ppl_under_base(model, tok, prompt: str, completion: str) -> float: return math.exp(nll.item()) -def filter_completions(model, tok, comps: list[dict], cfg: RunConfig): - """Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep.""" +def filter_completions(model, tok, comps: list[dict], cfg: RunConfig, brief: bool = False): + """Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep. + brief=True (walk-C probes): one-line count, no raw-sample dump (see _log_filter_report).""" scored = [] for c in tqdm(comps, desc="filter ppl", mininterval=120, maxinterval=120): rf = rep_frac(c["completion"]) @@ -127,12 +128,26 @@ def filter_completions(model, tok, comps: list[dict], cfg: RunConfig): keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar) and (not ref) scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "refuses": ref, "keep": keep}) kept = [s for s in scored if s["keep"]] - _log_filter_report(scored, cfg) + _log_filter_report(scored, cfg, brief=brief) return kept[: cfg.n_keep], scored -def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None: - """Q0 evidence: does the filter separate coherent (low C) from incoherent (high C)?""" +def _log_filter_report(scored: list[dict], cfg: RunConfig, brief: bool = False) -> None: + """Q0 evidence: does the filter separate coherent (low C) from incoherent (high C)? + brief=True (walk-C probes): one-line count ONLY. The per-probe survival drives the + bisection and is tabulated in the walk summary, so the full dump (~6 completions) x + every probe is noise; gen_filter_walk prints ONE clean sample after the dose settles.""" + # per-criterion drop counts (overlapping): which filter is doing the work? + n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored) + n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored) + n_nar = sum(s["narrates"] for s in scored) + n_ref = sum(s["refuses"] for s in scored) + n_kept = sum(s["keep"] for s in scored) + if brief: + logger.info(f"filter kept {n_kept}/{len(scored)} (dropped ppl>={cfg.ppl_tau:g}:{n_ppl} " + f"rep>={cfg.rep_tau}:{n_rep} narrate:{n_nar} refusal:{n_ref})") + return + import polars as pl from tabulate import tabulate @@ -180,12 +195,7 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None: logger.info(f"\n-- JUST-KEPT alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}") for s in just_dropped: logger.info(f"\n-- JUST-DROPPED alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}") - # per-criterion drop counts (overlapping): which filter is doing the work? - n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored) - n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored) - n_nar = sum(s["narrates"] for s in scored) - n_ref = sum(s["refuses"] for s in scored) - n_kept = sum(s["keep"] for s in scored) + # per-criterion drop counts (overlapping, computed at top): which filter is doing the work? logger.info( f"filter kept {n_kept}/{len(scored)}. dropped by (overlapping): " f"coherence ppl>={cfg.ppl_tau:g}: {n_ppl}, repetition rep>={cfg.rep_tau}: {n_rep}, "