"""Audit a training run: quote first/last generation (coherence eyeball) + summarise
the key per-step columns with trend arrows and SHOULD-interpretation hints.

Deterministic extraction; the /audit-log command feeds this to the LLM for a verdict.

Usage:
    uv run python scripts/audit_log.py out/runs/_   # run dir
    uv run python scripts/audit_log.py logs/_.log   # log (columns only; run dir is not resolved)
"""
from __future__ import annotations

import json
import re
import sys
from pathlib import Path

# Leading "<timestamp>_<mode>_" part of a run-dir name; what remains is the tag.
_RUN_PREFIX = re.compile(r"^\d{8}T\d{6}_(fast|smoke|full)_")
# Only this many leading characters of each candidate log are searched for the tag.
_LOG_HEAD_CHARS = 2000


def _log_head(path: Path, limit: int = _LOG_HEAD_CHARS) -> str:
    """Return up to *limit* leading characters of *path* ('' if it cannot be read)."""
    try:
        # Read just the head instead of loading the whole log into memory.
        with path.open(errors="replace") as fh:
            return fh.read(limit)
    except OSError:
        # A directory named *.log or an unreadable file must not abort the scan.
        return ""


def _find(arg: str) -> tuple[Path | None, Path | None]:
    """Resolve (rollouts.jsonl, streaming .log) from a run-dir or log path.

    Args:
        arg: Path to a run directory, or to a ``.log`` file.

    Returns:
        ``(jsonl, log)``; either element is ``None`` when it cannot be found.

        * Directory: ``jsonl`` is ``<dir>/rollouts.jsonl`` if it exists. ``log``
          is the last file (in sorted order) under ``logs/*.log`` -- relative to
          the current working directory -- whose first 2000 characters contain
          the run tag, i.e. the directory name with its
          ``<timestamp>_<fast|smoke|full>_`` prefix removed.
        * ``.log`` path: ``(None, path)``; the run directory is not looked up.
        * Anything else: ``(None, None)``.
    """
    p = Path(arg)
    if p.is_dir():
        jl = p / "rollouts.jsonl"
        # match a log whose argv out-tag matches this run dir's tag
        tag = _RUN_PREFIX.sub("", p.name)
        log = None
        # An empty tag is a substring of every log, so it would "match" an
        # arbitrary file; treat it as "no log found" instead.
        if tag:
            candidates = sorted(Path("logs").glob("*.log"))
            log = next(
                (cand for cand in reversed(candidates) if tag in _log_head(cand)),
                None,
            )
        return (jl if jl.exists() else None), log
    if p.suffix == ".log":
        # Recovering the run dir from the verbose-log line is overkill, so a
        # bare log path yields the log only (no rollouts.jsonl).
        return None, p
    return None, None


# A fixed coherent-vanilla yardstick (Qwen3-4B, sub4 vanilla seed41 step 59):
# real imports, a class, indented code. The audited last-gen should look like
# THIS. If it's punctuation soup instead (job 46 step 14: '####?##%\r\n#_...'),
# the policy diverged in free generation -- even when lp_t (teacher-forced
# coherence) stayed flat and the divergence tripwire never fired.
REFERENCE_HEALTHY = (
    "```python\nfrom typing import List\n\nclass Solution:\n"
    " def maxPoints(self, points: List[List[int]]) -> int:\n"
    " def slope(p1, p2): # reduced (dx,dy) via GCD, no float error\n"
    " if p1 == p2: return (0, 0)"
)

# Marker that separates the logger prefix from the payload of a log line.
_INFO_MARK = "| INFO |"
# lp_t falling this far below its best-so-far value counts as divergence.
_DIVERGENCE_NATS = 5


def _gen(jl: Path) -> None:
    """Print the first and last rollout next to the fixed healthy reference.

    Args:
        jl: Path to a JSON-lines file. Each row is read for the keys ``step``,
            ``reward``, ``gt_pass`` and ``text``; the hack flag comes from
            ``exploited``, falling back to ``hacked_E``.

    Prints ``rollouts.jsonl EMPTY`` and returns when the file has no rows.
    """
    # Close the handle deterministically and skip blank lines (a trailing empty
    # line would otherwise make json.loads raise).
    with jl.open() as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    if not rows:
        print("rollouts.jsonl EMPTY")
        return
    print(f"rollouts: {len(rows)} rows, steps {rows[0]['step']}..{rows[-1]['step']}")
    print("\n--- REFERENCE: healthy vanilla gen (fixed yardstick, NOT this run) ---")
    print(repr(REFERENCE_HEALTHY))
    for r in (rows[0], rows[-1]):
        print(f"\n--- step {r['step']} reward={r['reward']:+.2f} gt_pass={r['gt_pass']} "
              f"hack={r.get('exploited', r.get('hacked_E'))} ---")
        print("SHOULD: read like the REFERENCE above (coherent code); ELSE token salad => diverged")
        # Only the first 400 characters are quoted; repr keeps control chars visible.
        print(repr(r["text"][:400]))


def _frac(token: str) -> float:
    """Convert an ``a/b`` cell to a float ratio; a zero denominator yields 0.0."""
    num, den = token.split("/")
    return int(num) / int(den) if int(den) else 0.0


def _trend(name: str, values: list, lo=None) -> str:
    """Format one summary line: first value vs mean of the last five, plus an arrow.

    Args:
        name: Column name, left-padded to 7 characters in the output.
        values: Non-empty list of parsed column values in log order.
        lo: Optional floor; if any value is below it a warning suffix is added.
    """
    tail = values[-5:]
    first, last5 = values[0], sum(tail) / len(tail)
    if last5 > first + 1e-6:
        arrow = "UP"
    elif last5 < first - 1e-6:
        arrow = "DOWN"
    else:
        arrow = "flat"
    warn = ""
    if lo is not None and min(values) < lo:
        warn = f" <-- dipped below {lo} (min {min(values):.2f})"
    return f"{name:7s} first={first:+.3f} last5={last5:+.3f} [{arrow}]{warn}"


def _fell_from_peak(values: list, margin: float) -> bool:
    """Return True if any value lies more than *margin* below an EARLIER peak.

    Comparing against the running maximum (rather than the global one) keeps a
    run that simply improves from a low start from being reported as a drop.
    """
    peak = float("-inf")
    for v in values:
        if v < peak - margin:
            return True
        peak = max(peak, v)
    return False


def _cols(log: Path) -> None:
    """Summarise the per-step streaming table found in a log file.

    The header is the first ``| INFO |`` line whose payload starts with
    ``step`` and which mentions ``ref_eq``. Data rows are ``| INFO |`` lines
    whose first payload token is all digits and that have at least as many
    tokens as the header. For each of ``hack_s``, ``gt_s``, ``lp_t``, ``gn``
    and ``loss`` that is present, one trend line is printed, followed by a
    divergence verdict for ``lp_t``.

    Args:
        log: Path to the log file (decoded with ``errors="replace"``).
    """
    lines = log.read_text(errors="replace").splitlines()

    hdr = next(
        (ln for ln in lines
         if _INFO_MARK in ln
         and ln.split(_INFO_MARK, 1)[1].split()[:1] == ["step"]
         and "ref_eq" in ln),
        None,
    )
    if hdr is None:
        print("\nno streaming table in log")
        return

    # Normalise header tokens to bare lowercase identifiers before indexing.
    names = [re.sub(r"[^a-z0-9_]", "", tok.lower())
             for tok in hdr.split(_INFO_MARK, 1)[1].split()]
    idx = {n: i for i, n in enumerate(names)}

    rows = []
    for ln in lines:
        if _INFO_MARK not in ln:
            continue
        cells = ln.split(_INFO_MARK, 1)[1].split()
        if cells and cells[0].isdigit() and len(cells) >= len(names):
            rows.append(cells)
    if not rows:
        print("\nno data rows")
        return

    print("\n=== key columns (first vs last-5-mean) ===")
    print("SHOULD interpret: hack_s UP=emerging; gt_s UP=learning to solve;")
    print(" gn stable (not >>clip); lp_t > -3 (coherent, ELSE diverged)")
    for nm, parse, lo in [("hack_s", _frac, None), ("gt_s", _frac, None),
                          ("lp_t", float, -3.0), ("gn", float, None),
                          ("loss", float, None)]:
        if nm in idx:
            print(" " + _trend(nm, [parse(r[idx[nm]]) for r in rows], lo))

    # divergence check on lp_t
    if "lp_t" in idx:
        lpt = [float(r[idx["lp_t"]]) for r in rows]
        best = max(lpt)
        diverged = _fell_from_peak(lpt, _DIVERGENCE_NATS)
        print(f"\nDIVERGENCE: lp_t best={best:+.2f} min={min(lpt):+.2f} -> "
              f"{'DIVERGED (drop >5 nats)' if diverged else 'stable'}")


def main() -> None:
    """CLI entry point: audit the run dir or log given as the single argument.

    Prints the module docstring and exits with status 1 unless exactly one
    argument is supplied.
    """
    if len(sys.argv) != 2:
        print(__doc__)
        sys.exit(1)
    jl, log = _find(sys.argv[1])
    print(f"=== AUDIT {sys.argv[1]} ===")
    if jl:
        _gen(jl)
    else:
        print("(no rollouts.jsonl found)")
    if log:
        _cols(log)
    else:
        print("(no streaming log found)")


if __name__ == "__main__":
    main()