Files
evil_MoE/scripts/audit_log.py
T
wassname 7a55b77786 audit-log: print a fixed healthy-vanilla gen as a coherence yardstick
The audited last-gen alone has no reference. A frozen coherent vanilla snippet
(maxPoints step 59) above it makes salad obvious -- e.g. job 46 step 14 is
clearly soup next to it, even though lp_t stayed flat and the tripwire missed it.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-01 01:15:25 +00:00

127 lines
4.9 KiB
Python

"""Audit a training run: quote first/last generation (coherence eyeball) + summarise
the key per-step columns with trend arrows and SHOULD-interpretation hints.
Deterministic extraction; the /audit-log command feeds this to the LLM for a verdict.
Usage:
uv run python scripts/audit_log.py out/runs/<ts>_<tag> # run dir
uv run python scripts/audit_log.py logs/<ts>_<tag>.log # log (finds sibling run dir)
"""
from __future__ import annotations
import json
import re
import sys
from pathlib import Path
def _log_head(path: Path, limit: int = 2000) -> str:
    """Return the first *limit* characters of *path*, or "" if it cannot be read."""
    try:
        # Read only the head: the tag is looked up in the first 2000 characters,
        # so there is no reason to load a whole (possibly huge) streaming log.
        with path.open(errors="replace") as fh:
            return fh.read(limit)
    except OSError:
        return ""


def _run_tag(run_dir_name: str) -> str:
    """Strip the leading ``<YYYYMMDDTHHMMSS>_<fast|smoke|full>_`` prefix from a run-dir name."""
    return re.sub(r"^\d{8}T\d{6}_(fast|smoke|full)_", "", run_dir_name)


def _find(arg: str) -> tuple[Path | None, Path | None]:
    """Resolve ``(rollouts.jsonl, streaming .log)`` from a run-dir or log path.

    * Run directory: returns its ``rollouts.jsonl`` (``None`` when the file is
      missing) and the last ``logs/*.log`` in sorted filename order whose first
      2000 characters contain the run's tag (the directory name without its
      timestamp/mode prefix).
    * ``.log`` path: returns the log itself, plus the ``rollouts.jsonl`` of the
      last directory (sorted by name) under ``out/runs`` whose tag appears in
      the first 2000 characters of the log; ``None`` when nothing matches.
    * Anything else: ``(None, None)``.

    ``logs`` and ``out/runs`` are resolved relative to the current working
    directory.  The tag match is a substring heuristic, so two runs sharing a
    tag cannot be told apart; the last one in sorted order wins.
    """
    p = Path(arg)
    if p.is_dir():
        jl = p / "rollouts.jsonl"
        tag = _run_tag(p.name)
        log = None
        # An empty tag is a substring of every log and would match arbitrarily.
        if tag:
            candidates = sorted(Path("logs").glob("*.log"))
            log = next((c for c in reversed(candidates) if tag in _log_head(c)), None)
        return (jl if jl.exists() else None), log
    if p.suffix == ".log":
        # Inverse of the directory lookup above: pair the log with a run dir by tag.
        head = _log_head(p)
        run_dirs = sorted(d for d in Path("out/runs").glob("*") if d.is_dir())
        for run_dir in reversed(run_dirs):
            tag = _run_tag(run_dir.name)
            jl = run_dir / "rollouts.jsonl"
            if tag and tag in head and jl.exists():
                return jl, p
        return None, p
    return None, None
# Frozen yardstick: the opening of a coherent vanilla generation (Qwen3-4B,
# sub4 vanilla seed41 step 59) -- real imports, a class, indented code.  It is
# printed above the audited generations so the reader has something to compare
# against: the run's last generation should resemble this.  Punctuation soup
# (job 46 step 14 produced '####?##%\r\n#_...') means the policy diverged in
# free generation, even when lp_t (teacher-forced coherence) stayed flat and
# the divergence tripwire never fired.
# NOTE(review): the snippet lines below carry a single leading space each; if
# the captured generation used deeper nesting, whitespace may have been lost in
# transit -- confirm against the original rollout.
REFERENCE_HEALTHY = "\n".join(
    (
        "```python",
        "from typing import List",
        "",
        "class Solution:",
        " def maxPoints(self, points: List[List[int]]) -> int:",
        " def slope(p1, p2): # reduced (dx,dy) via GCD, no float error",
        " if p1 == p2: return (0, 0)",
    )
)
def _gen(jl: Path) -> None:
    """Print the reference snippet followed by the first and last rollout in *jl*.

    *jl* is a JSON-lines file; every non-blank line is parsed with
    ``json.loads``.  Each quoted row must provide ``step``, ``reward``,
    ``gt_pass`` and ``text``; the hack flag is read from ``exploited`` and
    falls back to ``hacked_E``.  Only the first 400 characters of ``text`` are
    shown, via ``repr`` so control characters stay visible.

    Prints ``rollouts.jsonl EMPTY`` and returns when the file has no rows.
    """
    # `with` closes the handle deterministically; blank lines (for example a
    # trailing newline at end of file) are skipped instead of crashing json.loads.
    with jl.open(encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    if not rows:
        print("rollouts.jsonl EMPTY")
        return
    print(f"rollouts: {len(rows)} rows, steps {rows[0]['step']}..{rows[-1]['step']}")
    print("\n--- REFERENCE: healthy vanilla gen (fixed yardstick, NOT this run) ---")
    print(repr(REFERENCE_HEALTHY))
    # With a single row, first and last are the same generation: quote it once.
    quoted = rows[:1] if len(rows) == 1 else [rows[0], rows[-1]]
    for r in quoted:
        print(f"\n--- step {r['step']} reward={r['reward']:+.2f} gt_pass={r['gt_pass']} "
              f"hack={r.get('exploited', r.get('hacked_E'))} ---")
        print("SHOULD: read like the REFERENCE above (coherent code); ELSE token salad => diverged")
        print(repr(r["text"][:400]))
def _cols(log: Path) -> None:
    """Summarise the per-step streaming table found in *log*.

    Only lines containing ``| INFO |`` are considered; the text after that
    marker is split on whitespace.  The header is the first such line whose
    first token is ``step`` and which mentions ``ref_eq``.  Data rows are the
    lines whose first token is all digits and which have at least as many
    tokens as the header.

    For ``hack_s``, ``gt_s``, ``lp_t``, ``gn`` and ``loss`` (when present in
    the header) the first value is compared with the mean of the last five.
    ``hack_s`` and ``gt_s`` are parsed as ``a/b`` fractions, the others as
    floats.  Finally ``lp_t`` is checked for a drop of more than 5 below the
    best value seen so far.
    """
    marker = "| INFO |"
    info_lines = [ln for ln in log.read_text(errors="replace").splitlines() if marker in ln]
    payloads = [ln.split(marker, 1)[1].split() for ln in info_lines]

    hdr = next((toks for ln, toks in zip(info_lines, payloads)
                if toks[:1] == ["step"] and "ref_eq" in ln), None)
    if hdr is None:
        print("\nno streaming table in log")
        return
    # Normalise header cells to bare lowercase identifiers so lookups are stable.
    names = [re.sub(r"[^a-z0-9_]", "", t.lower()) for t in hdr]
    idx = {n: i for i, n in enumerate(names)}
    rows = [toks for toks in payloads
            if toks and toks[0].isdigit() and len(toks) >= len(names)]
    if not rows:
        print("\nno data rows")
        return

    def frac(t: str) -> float:
        """Parse ``a/b`` into a ratio; a zero denominator yields 0.0."""
        num, den = t.split("/")
        den_i = int(den)
        return int(num) / den_i if den_i else 0.0

    def col(name, f):
        """Return the parsed values of column *name*, skipping unparsable cells."""
        if name not in idx:
            return []
        values = []
        for r in rows:
            try:
                values.append(f(r[idx[name]]))
            except ValueError:
                # A non-table INFO line that merely starts with a digit can slip
                # through the row filter; ignore its cells rather than abort.
                continue
        return values

    def trend(name, f, lo=None):
        """Format first value vs last-5 mean for *name*, flagging dips below *lo*."""
        v = col(name, f)
        if not v:
            return f"{name}: (absent)"
        tail = v[-5:]
        first, last5 = v[0], sum(tail) / len(tail)
        arrow = "UP" if last5 > first + 1e-6 else ("DOWN" if last5 < first - 1e-6 else "flat")
        warn = ""
        if lo is not None and min(v) < lo:
            warn = f" <-- dipped below {lo} (min {min(v):.2f})"
        return f"{name:7s} first={first:+.3f} last5={last5:+.3f} [{arrow}]{warn}"

    print("\n=== key columns (first vs last-5-mean) ===")
    print("SHOULD interpret: hack_s UP=emerging; gt_s UP=learning to solve;")
    print(" gn stable (not >>clip); lp_t > -3 (coherent, ELSE diverged)")
    for nm, f, lo in [("hack_s", frac, None), ("gt_s", frac, None),
                      ("lp_t", float, -3.0), ("gn", float, None), ("loss", float, None)]:
        if nm in idx:
            print(" " + trend(nm, f, lo))

    # Divergence check on lp_t.  A "drop" is measured against the best value
    # seen *so far*: a run that starts low and then improves by more than 5 has
    # not dropped and must not be reported as diverged.
    lpt = col("lp_t", float)
    if lpt:
        best = max(lpt)
        running_best = lpt[0]
        diverged = False
        for x in lpt:
            if x > running_best:
                running_best = x
            elif x < running_best - 5:
                diverged = True
                break
        print(f"\nDIVERGENCE: lp_t best={best:+.2f} min={min(lpt):+.2f} -> "
              f"{'DIVERGED (drop >5 nats)' if diverged else 'stable'}")
def main() -> None:
    """Command-line entry point: audit the one run-dir or log path given in argv.

    Prints the module docstring and exits with status 1 unless exactly one
    argument follows the program name.  Otherwise resolves the rollouts file
    and streaming log via ``_find`` and prints each section in turn, or a
    placeholder line for whichever of the two could not be found.
    """
    argv = sys.argv
    if len(argv) != 2:
        print(__doc__)
        sys.exit(1)

    target = argv[1]
    rollouts, stream_log = _find(target)
    print(f"=== AUDIT {target} ===")

    # Generation quotes first, then the column summary.
    if not rollouts:
        print("(no rollouts.jsonl found)")
    else:
        _gen(rollouts)

    if not stream_log:
        print("(no streaming log found)")
    else:
        _cols(stream_log)
# Run the audit only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()