Support brief filter probe logs

This commit is contained in:
wassname
2026-06-24 20:49:11 +08:00
parent 22fd4b8dbe
commit 282fb3de47
+21 -11
View File
@@ -116,8 +116,9 @@ def ppl_under_base(model, tok, prompt: str, completion: str) -> float:
return math.exp(nll.item())
def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
"""Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep."""
def filter_completions(model, tok, comps: list[dict], cfg: RunConfig, brief: bool = False):
"""Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep.
brief=True (walk-C probes): one-line count, no raw-sample dump (see _log_filter_report)."""
scored = []
for c in tqdm(comps, desc="filter ppl", mininterval=120, maxinterval=120):
rf = rep_frac(c["completion"])
@@ -127,12 +128,26 @@ def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar) and (not ref)
scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "refuses": ref, "keep": keep})
kept = [s for s in scored if s["keep"]]
_log_filter_report(scored, cfg)
_log_filter_report(scored, cfg, brief=brief)
return kept[: cfg.n_keep], scored
def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
"""Q0 evidence: does the filter separate coherent (low C) from incoherent (high C)?"""
def _log_filter_report(scored: list[dict], cfg: RunConfig, brief: bool = False) -> None:
"""Q0 evidence: does the filter separate coherent (low C) from incoherent (high C)?
brief=True (walk-C probes): one-line count ONLY. The per-probe survival drives the
bisection and is tabulated in the walk summary, so the full dump (~6 completions) x
every probe is noise; gen_filter_walk prints ONE clean sample after the dose settles."""
# per-criterion drop counts (overlapping): which filter is doing the work?
n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored)
n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored)
n_nar = sum(s["narrates"] for s in scored)
n_ref = sum(s["refuses"] for s in scored)
n_kept = sum(s["keep"] for s in scored)
if brief:
logger.info(f"filter kept {n_kept}/{len(scored)} (dropped ppl>={cfg.ppl_tau:g}:{n_ppl} "
f"rep>={cfg.rep_tau}:{n_rep} narrate:{n_nar} refusal:{n_ref})")
return
import polars as pl
from tabulate import tabulate
@@ -180,12 +195,7 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
logger.info(f"\n-- JUST-KEPT alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}")
for s in just_dropped:
logger.info(f"\n-- JUST-DROPPED alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}")
# per-criterion drop counts (overlapping): which filter is doing the work?
n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored)
n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored)
n_nar = sum(s["narrates"] for s in scored)
n_ref = sum(s["refuses"] for s in scored)
n_kept = sum(s["keep"] for s in scored)
# per-criterion drop counts (overlapping, computed at top): which filter is doing the work?
logger.info(
f"filter kept {n_kept}/{len(scored)}. dropped by (overlapping): "
f"coherence ppl>={cfg.ppl_tau:g}: {n_ppl}, repetition rep>={cfg.rep_tau}: {n_rep}, "