mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 20:21:41 +08:00
7a55b77786
The audited last-gen alone has no reference. A frozen coherent vanilla snippet (maxPoints step 59) above it makes salad obvious -- e.g. job 46 step 14 is clearly soup next to it, even though lp_t stayed flat and the tripwire missed it. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
127 lines
4.9 KiB
Python
127 lines
4.9 KiB
Python
"""Audit a training run: quote first/last generation (coherence eyeball) + summarise
|
|
the key per-step columns with trend arrows and SHOULD-interpretation hints.
|
|
|
|
Deterministic extraction; the /audit-log command feeds this to the LLM for a verdict.
|
|
|
|
Usage:
|
|
uv run python scripts/audit_log.py out/runs/<ts>_<tag> # run dir
|
|
uv run python scripts/audit_log.py logs/<ts>_<tag>.log # log only (rollouts.jsonl is not located from a log path)
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
def _find(arg: str) -> tuple[Path | None, Path | None]:
|
|
"""Resolve (rollouts.jsonl, streaming .log) from a run-dir or log path."""
|
|
p = Path(arg)
|
|
if p.is_dir():
|
|
jl = p / "rollouts.jsonl"
|
|
# match a log whose argv out-tag matches this run dir's tag
|
|
tag = re.sub(r"^\d{8}T\d{6}_(fast|smoke|full)_", "", p.name)
|
|
logs = sorted(Path("logs").glob("*.log"))
|
|
log = next((l for l in reversed(logs) if tag in l.read_text(errors="replace")[:2000]), None)
|
|
return (jl if jl.exists() else None), log
|
|
if p.suffix == ".log":
|
|
# find the run dir from the verbose-log line is overkill; use jsonl by tag
|
|
return None, p
|
|
return None, None
|
|
|
|
|
|
# A fixed coherent-vanilla yardstick (Qwen3-4B, sub4 vanilla seed41 step 59):
# real imports, a class, indented code. The audited last-gen should look like
# THIS. If it's punctuation soup instead (job 46 step 14: '####?##%\r\n#_...'),
# the policy diverged in free generation -- even when lp_t (teacher-forced
# coherence) stayed flat and the divergence tripwire never fired.
#
# The nested lines carry real 4/8/12-space indentation so that the snippet is
# itself parseable Python; a yardstick whose nesting is flattened would not be
# a fair picture of "coherent code".
REFERENCE_HEALTHY = (
    "```python\nfrom typing import List\n\nclass Solution:\n"
    "    def maxPoints(self, points: List[List[int]]) -> int:\n"
    "        def slope(p1, p2): # reduced (dx,dy) via GCD, no float error\n"
    "            if p1 == p2: return (0, 0)"
)
|
|
|
|
|
|
def _gen(jl: Path) -> None:
    """Print the reference snippet plus the first and last rollout of a run.

    Args:
        jl: Path to a JSON-lines file with one rollout object per line. Each
            quoted row is read for ``step``, ``reward``, ``gt_pass``, ``text``
            and ``exploited`` (falling back to ``hacked_E``).

    Prints ``rollouts.jsonl EMPTY`` and returns when the file holds no rows.
    Otherwise prints the row count and step range, the fixed
    ``REFERENCE_HEALTHY`` yardstick, and the first 400 characters of the
    first and last generations (via ``repr`` so control characters and
    token salad stay visible).
    """
    # Close the handle deterministically, and skip blank lines (e.g. a
    # trailing newline at the end of the file) rather than failing to parse.
    with jl.open() as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    if not rows:
        print("rollouts.jsonl EMPTY")
        return
    print(f"rollouts: {len(rows)} rows, steps {rows[0]['step']}..{rows[-1]['step']}")
    print("\n--- REFERENCE: healthy vanilla gen (fixed yardstick, NOT this run) ---")
    print(repr(REFERENCE_HEALTHY))
    # With a single row, first and last are the same record: quote it once.
    samples = rows[:1] if len(rows) == 1 else [rows[0], rows[-1]]
    for r in samples:
        print(f"\n--- step {r['step']} reward={r['reward']:+.2f} gt_pass={r['gt_pass']} "
              f"hack={r.get('exploited', r.get('hacked_E'))} ---")
        print("SHOULD: read like the REFERENCE above (coherent code); ELSE token salad => diverged")
        print(repr(r["text"][:400]))
|
|
|
|
|
|
def _cols(log: Path) -> None:
    """Summarise the streaming per-step table found in a training log.

    The table header is the first ``| INFO |`` line whose payload starts with
    ``step`` and which mentions ``ref_eq``. Data rows are ``| INFO |`` lines
    whose payload starts with an integer and has at least as many
    whitespace-separated cells as the header.

    Prints, for each of ``hack_s``, ``gt_s``, ``lp_t``, ``gn`` and ``loss``
    present in the header, the first value against the mean of the last five,
    with an UP/DOWN/flat label. Then prints a divergence verdict on ``lp_t``.

    Args:
        log: Path to the streaming log file.
    """
    lines = log.read_text(errors="replace").splitlines()
    hdr = next(
        (line for line in lines
         if "| INFO |" in line
         and line.split("| INFO |", 1)[1].split()[:1] == ["step"]
         and "ref_eq" in line),
        None,
    )
    if hdr is None:
        print("\nno streaming table in log")
        return
    # Normalise header cells to bare lowercase identifiers for lookup.
    names = [re.sub(r"[^a-z0-9_]", "", t.lower()) for t in hdr.split("| INFO |", 1)[1].split()]
    idx = {n: i for i, n in enumerate(names)}
    rows = []
    for line in lines:
        if "| INFO |" not in line:
            continue
        r = line.split("| INFO |", 1)[1].split()
        if r and r[0].isdigit() and len(r) >= len(names):
            rows.append(r)
    if not rows:
        print("\nno data rows")
        return

    def frac(t):
        # Cells such as "3/8" become a ratio; a zero denominator maps to 0.0.
        a, b = t.split("/")
        return int(a) / int(b) if int(b) else 0.0

    def col(name, f):
        # Parsed values of one column, or [] when the column is absent.
        return [f(r[idx[name]]) for r in rows if name in idx]

    def trend(name, f, lo=None):
        v = col(name, f)
        if not v:
            return f"{name}: (absent)"
        first, last5 = v[0], sum(v[-5:]) / len(v[-5:])
        arrow = "UP" if last5 > first + 1e-6 else ("DOWN" if last5 < first - 1e-6 else "flat")
        warn = ""
        if lo is not None and min(v) < lo:
            warn = f" <-- dipped below {lo} (min {min(v):.2f})"
        return f"{name:7s} first={first:+.3f} last5={last5:+.3f} [{arrow}]{warn}"

    print("\n=== key columns (first vs last-5-mean) ===")
    print("SHOULD interpret: hack_s UP=emerging; gt_s UP=learning to solve;")
    print(" gn stable (not >>clip); lp_t > -3 (coherent, ELSE diverged)")
    for nm, f, lo in [("hack_s", frac, None), ("gt_s", frac, None),
                      ("lp_t", float, -3.0), ("gn", float, None), ("loss", float, None)]:
        if nm in idx:
            print(" " + trend(nm, f, lo))
    # divergence check on lp_t
    lpt = col("lp_t", float)
    if lpt:
        best = max(lpt)
        # A divergence is a DROP: a value more than 5 nats below the best seen
        # *before* it. Comparing against the global max would also flag a run
        # that merely started low and then improved.
        diverged = False
        peak = lpt[0]
        for x in lpt:
            if x < peak - 5:
                diverged = True
                break
            peak = max(peak, x)
        print(f"\nDIVERGENCE: lp_t best={best:+.2f} min={min(lpt):+.2f} -> "
              f"{'DIVERGED (drop >5 nats)' if diverged else 'stable'}")
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: audit the run dir or log named by the single argument.

    Prints the module docstring and exits with status 1 unless exactly one
    command-line argument is given. Otherwise prints the rollout section and
    then the column section, with a placeholder line for whichever input
    ``_find`` could not resolve.
    """
    args = sys.argv[1:]
    if len(args) != 1:
        print(__doc__)
        sys.exit(1)

    (target,) = args
    rollouts_path, log_path = _find(target)
    print(f"=== AUDIT {target} ===")

    if not rollouts_path:
        print("(no rollouts.jsonl found)")
    else:
        _gen(rollouts_path)

    if not log_path:
        print("(no streaming log found)")
    else:
        _cols(log_path)


if __name__ == "__main__":
    main()
|