mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 15:32:28 +08:00
Support brief filter probe logs
This commit is contained in:
+21
-11
@@ -116,8 +116,9 @@ def ppl_under_base(model, tok, prompt: str, completion: str) -> float:
|
||||
return math.exp(nll.item())
|
||||
|
||||
|
||||
def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
|
||||
"""Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep."""
|
||||
def filter_completions(model, tok, comps: list[dict], cfg: RunConfig, brief: bool = False):
|
||||
"""Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep.
|
||||
brief=True (walk-C probes): one-line count, no raw-sample dump (see _log_filter_report)."""
|
||||
scored = []
|
||||
for c in tqdm(comps, desc="filter ppl", mininterval=120, maxinterval=120):
|
||||
rf = rep_frac(c["completion"])
|
||||
@@ -127,12 +128,26 @@ def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
|
||||
keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar) and (not ref)
|
||||
scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "refuses": ref, "keep": keep})
|
||||
kept = [s for s in scored if s["keep"]]
|
||||
_log_filter_report(scored, cfg)
|
||||
_log_filter_report(scored, cfg, brief=brief)
|
||||
return kept[: cfg.n_keep], scored
|
||||
|
||||
|
||||
def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
|
||||
"""Q0 evidence: does the filter separate coherent (low C) from incoherent (high C)?"""
|
||||
def _log_filter_report(scored: list[dict], cfg: RunConfig, brief: bool = False) -> None:
|
||||
"""Q0 evidence: does the filter separate coherent (low C) from incoherent (high C)?
|
||||
brief=True (walk-C probes): one-line count ONLY. The per-probe survival drives the
|
||||
bisection and is tabulated in the walk summary, so the full dump (~6 completions) x
|
||||
every probe is noise; gen_filter_walk prints ONE clean sample after the dose settles."""
|
||||
# per-criterion drop counts (overlapping): which filter is doing the work?
|
||||
n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored)
|
||||
n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored)
|
||||
n_nar = sum(s["narrates"] for s in scored)
|
||||
n_ref = sum(s["refuses"] for s in scored)
|
||||
n_kept = sum(s["keep"] for s in scored)
|
||||
if brief:
|
||||
logger.info(f"filter kept {n_kept}/{len(scored)} (dropped ppl>={cfg.ppl_tau:g}:{n_ppl} "
|
||||
f"rep>={cfg.rep_tau}:{n_rep} narrate:{n_nar} refusal:{n_ref})")
|
||||
return
|
||||
|
||||
import polars as pl
|
||||
from tabulate import tabulate
|
||||
|
||||
@@ -180,12 +195,7 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
|
||||
logger.info(f"\n-- JUST-KEPT alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}")
|
||||
for s in just_dropped:
|
||||
logger.info(f"\n-- JUST-DROPPED alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}")
|
||||
# per-criterion drop counts (overlapping): which filter is doing the work?
|
||||
n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored)
|
||||
n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored)
|
||||
n_nar = sum(s["narrates"] for s in scored)
|
||||
n_ref = sum(s["refuses"] for s in scored)
|
||||
n_kept = sum(s["keep"] for s in scored)
|
||||
# per-criterion drop counts (overlapping, computed at top): which filter is doing the work?
|
||||
logger.info(
|
||||
f"filter kept {n_kept}/{len(scored)}. dropped by (overlapping): "
|
||||
f"coherence ppl>={cfg.ppl_tau:g}: {n_ppl}, repetition rep>={cfg.rep_tau}: {n_rep}, "
|
||||
|
||||
Reference in New Issue
Block a user