mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 22:37:22 +08:00
ad048e59c6
Old GT_S=6/HACK_S=8 were the pre-sprd/N layout; current table is gt_s=4 hack_s=6, so newer logs were silently mis-read and old distill logs crashed _frac on a non-fraction token. Now locate the train.py streaming header (first token 'step' + 'ref_eq' present) and map columns by name. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
190 lines
8.9 KiB
Python
190 lines
8.9 KiB
Python
"""Aggregate all train.py runs from logs/*.log into one sorted/grouped table.
|
|
|
|
Durable source: each run writes logs/<ts>_<preset>_<arm>_seed<seed>_<tag>.log
|
|
with an `argv:` line (config) and per-step rows. We parse those directly and
|
|
recompute the metrics ourselves, so this survives `pueue reset` and doesn't
|
|
depend on the BLUF line.
|
|
|
|
Headline metric is mean-of-last-5-steps (noise-robust; the converged regime),
|
|
shown for BOTH hack_s (reward hacks) and gt_s (ground-truth solves) on the
|
|
STUDENT rollouts. Whole-run means are kept as a secondary column because the
|
|
blog Table 1 uses whole-run and the two conventions disagree.
|
|
|
|
just results   # prints: all runs sorted by time, grouped-by-config, and paired deltas vs same-seed vanilla
|
|
"""
|
|
from __future__ import annotations
|
|
import re
|
|
from pathlib import Path
|
|
import polars as pl
|
|
from tabulate import tabulate
|
|
|
|
# Directory that main() globs for "*.log" run logs (a relative path, so it is
# resolved against the current working directory).
LOG_DIR: Path = Path("logs")
# Run timestamp embedded in a log's file name: 8 digits, "T", 6 digits
# (e.g. "20250101T120000"). parse_log stores it as `time`, which main() sorts on.
TS_RE: re.Pattern[str] = re.compile(r"(\d{8}T\d{6})")
# Column positions are read from the header row by NAME, not hardcoded -- the
# per-step table layout has changed over time (sprd/N dropped, cin/cout/hk_dep
# added) so fixed indices silently mis-read newer logs and crash on smoke logs.
|
|
|
|
|
|
def _colname(tok: str) -> str:
|
|
# header tokens carry direction glyphs / markers: "gt_s↑", "hack_s?" -> "gt_s", "hack_s"
|
|
return re.sub(r"[^a-z0-9_]", "", tok.lower())
|
|
|
|
|
|
def _frac(tok: str) -> float | None:
|
|
a, b = tok.split("/")
|
|
return int(a) / int(b) if int(b) else None
|
|
|
|
|
|
def _cfg(argv: str, preset_line: str) -> dict:
|
|
def grab(pat, s, default="-"):
|
|
# LAST match wins: recipes set a default flag then runs override it
|
|
# (e.g. --v-hack-path twice, --mix-ratio twice); tyro takes the last.
|
|
ms = re.findall(pat, s)
|
|
return ms[-1] if ms else default
|
|
return dict(
|
|
# arm is the derived display name printed in the preset line
|
|
# (vanilla/projected/routing). Read it from there, not the CLI flag:
|
|
# old logs passed --arm, new logs pass --intervention, but BOTH print
|
|
# `arm=<name>` in the preset line, so this one source covers all runs.
|
|
arm=grab(r"\barm=(\w+)", preset_line),
|
|
preset=grab(r"preset=(\w+)", preset_line),
|
|
model=grab(r"model=(\S+)", preset_line),
|
|
seed=grab(r"seed=(\d+)", preset_line, "?"), # preset= line always prints it
|
|
mix=grab(r"--mix-ratio=([\d.]+)", argv, "0.5"),
|
|
refr=grab(r"--vhack-refresh-every=(\d+)", argv),
|
|
over=grab(r"--project-overshoot=([\d.]+)", argv, "1.0"),
|
|
gate=grab(r"--gate-mode=(\w+)", argv, "one_sided"),
|
|
k=grab(r"--v-hack-k=(\d+)", argv, "5"),
|
|
dropf=grab(r"--v-hack-drop-bottom-frac=([\d.]+)", argv, "0.25"),
|
|
vhack=grab(r"v-hack-path=out/(?:vhack/)?(\S+?)\.safetensors", argv),
|
|
tag=grab(r"--out-tag=(\S+)", argv, ""),
|
|
# full CLI args (after train.py) — the ground-truth provenance; any flag
|
|
# not parsed into a column above is still visible here.
|
|
argv=argv.split("train.py ", 1)[-1].strip() if "train.py " in argv else argv.strip(),
|
|
)
|
|
|
|
|
|
def _find_table_header(lines: list[str]) -> tuple[int, list[str]] | None:
    """Locate the train.py per-step table header in a log's lines.

    The train.py streaming table header is the first ``| INFO |`` line whose
    normalised tokens start with "step" and include "ref_eq". That signature
    excludes the old distill_* logs, which also print "step ..." lines but in
    a different (hack=.. pass=..) format.

    Returns:
        (index of the header in ``lines``, normalised column names), or None
        when no such header exists.
    """
    for i, line in enumerate(lines):
        _, sep, body = line.partition("| INFO |")
        if not sep:
            continue
        names = [_colname(t) for t in body.split()]
        if names[:1] == ["step"] and "ref_eq" in names:
            return i, names
    return None


def parse_log(path: Path) -> dict | None:
    """Parse one train.py log into a single results row, or None to skip it.

    The log is skipped (None) when any of these hold:
      - it has no ``argv:`` line;
      - it has no train.py per-step table, or that table has no hack_s or
        gt_s column;
      - no step had a scorable hack_s;
      - it is a CPU smoke run or a "probe" run;
      - it logged fewer step rows than the ``steps=`` value in its preset
        line (an in-progress or aborted run).

    Returns:
        A dict with these entries:
          - time: the timestamp from the file name;
          - the _cfg() columns minus ``model``, with ``mix`` replaced by the
            printed ``mix_ratio=`` when present;
          - L5_hack / L5_solve: mean of the last 5 per-step hack_s / gt_s
            fractions;
          - WH_hack: the whole-run mean of hack_s;
          - n: the number of hack_s samples;
          - log: the file name, so every number traces back to its file.
    """
    txt = path.read_text(encoding="utf-8", errors="replace")
    lines = txt.splitlines()
    argv = next((l for l in lines if "argv:" in l), None)
    if argv is None:
        return None
    preset_line = next((l for l in lines if "preset=" in l and "arm=" in l), "")

    # Map gt_s/hack_s columns by NAME from the table header (layouts changed).
    found = _find_table_header(lines)
    if found is None:
        return None  # not a train.py streaming run
    header_i, names = found
    if "hack_s" not in names or "gt_s" not in names:
        return None  # a layout we cannot score; skip instead of ValueError
    idx_hack, idx_gt = names.index("hack_s"), names.index("gt_s")
    # Guard BOTH indices: with name-based mapping either column may be last.
    min_len = max(idx_hack, idx_gt) + 1

    hs, gts = [], []
    n_steps = 0  # logged step rows, including ones with a 0/0 hack cell
    for line in lines[header_i + 1:]:
        _, sep, body = line.partition("| INFO |")
        if not sep:
            continue
        row = body.split()
        if not row or not row[0].isdigit() or len(row) < min_len:
            continue
        # Digit-led INFO lines that are not table rows lack "a/b" cells here.
        if "/" not in row[idx_hack] or "/" not in row[idx_gt]:
            continue
        n_steps += 1
        h, g = _frac(row[idx_hack]), _frac(row[idx_gt])
        if h is not None:
            hs.append(h)
        if g is not None:
            gts.append(g)
    if not hs:
        return None

    cfg = _cfg(argv, preset_line)
    # GROUND TRUTH mix: train.py prints `mix_ratio=<x>` in the pool INFO line
    # (what the run actually used). Many runs rely on the preset default and
    # pass no --mix-ratio flag, so the argv-based grab in _cfg defaults to the
    # wrong value (0.5) and mis-keys them. Override with the printed value.
    m_mix = re.search(r"mix_ratio=([\d.]+)", txt)
    if m_mix:
        cfg["mix"] = m_mix.group(1)
    if "tiny-random" in cfg["model"] or cfg["preset"] == "smoke":
        return None  # CPU smoke runs, not real results
    if "probe" in cfg["tag"]:
        return None  # early feasibility / lr-sweep probes, not comparable baselines
    # Exclude in-progress / aborted runs: a partial log has only the early
    # (low-hack) steps, which would read as an impossibly-good result. A run
    # is complete when it logged all `steps` per-step ROWS. Count rows, not
    # scorable hack values: a 0/0 step is still a logged step.
    m = re.search(r"steps=(\d+)", preset_line)
    if m and n_steps < int(m.group(1)):
        return None

    def mean(vals: list[float]) -> float | None:
        return sum(vals) / len(vals) if vals else None

    ts = TS_RE.search(path.name)
    cfg.pop("model")
    return dict(
        time=ts.group(1) if ts else "?",
        **cfg,
        L5_hack=mean(hs[-5:]), L5_solve=mean(gts[-5:]),
        WH_hack=mean(hs), n=len(hs),
        log=path.name,  # provenance: every number traces back to this file
    )
|
|
|
|
|
|
def main() -> None:
    """Print three pipe-format tables summarising every parseable run in LOG_DIR.

    1. All runs, sorted by the timestamp taken from the log file name.
    2. Runs grouped by config with seeds collapsed: mean and sample std of
       the last-5-step hack/solve rates.
    3. Paired deltas: each projected/routing run minus the vanilla baseline
       at the same (mix, seed), aggregated per config.
    """
    rows = [r for p in sorted(LOG_DIR.glob("*.log")) if (r := parse_log(p))]
    if not rows:
        print("no parseable runs in logs/")
        return
    df = pl.DataFrame(rows).sort("time")

    cols = ["arm", "seed", "mix", "refr", "over", "gate", "k", "dropf",
            "vhack", "L5_hack", "L5_solve", "WH_hack", "n", "log"]
    print("\n## All runs (sorted by time)\n")
    print(tabulate(df.select(cols).rows(), headers=cols, tablefmt="pipe", floatfmt=".3f"))

    # Grouped by config (collapse seeds): mean +/- std across seeds. Key on
    # every config dim that changes the experiment so non-comparable runs
    # don't merge. std is null for n=1 (undefined).
    key = ["arm", "mix", "refr", "over", "gate", "k", "dropf", "vhack"]
    g = (df.group_by(key)
         .agg(pl.col("L5_hack").mean().alias("hack"),
              pl.col("L5_hack").std().alias("hack_sd"),
              pl.col("L5_solve").mean().alias("solve"),
              pl.col("L5_solve").std().alias("solve_sd"),
              pl.len().alias("n"),
              pl.col("seed").sort().str.join(",").alias("seeds"))
         # group_by output order is arbitrary; sorting on the FULL key keeps
         # the printed table deterministic even when a key prefix ties.
         .sort(["mix", "arm", "refr", "over", "gate", "k", "dropf", "vhack"]))
    gcols = key + ["hack", "hack_sd", "solve", "solve_sd", "n", "seeds"]
    print("\n## Grouped by config (mean +/- std over seeds)\n")
    print(tabulate(g.select(gcols).rows(), headers=gcols, tablefmt="pipe", floatfmt=".3f"))

    # Paired vs same-seed vanilla (matched mix): the only honest way to read a
    # delta. Join each projected run to the vanilla run at the SAME (mix, seed),
    # take per-seed deltas, then mean +/- std of the delta over shared seeds.
    # Collapse repeated vanilla runs at one (mix, seed) to their mean first:
    # otherwise the inner join fans each intervention run out once per
    # vanilla duplicate and double-counts that seed in n / the std.
    van = (df.filter(pl.col("arm") == "vanilla")
           .group_by(["mix", "seed"])
           .agg(pl.col("L5_hack").mean().alias("v_hack"),
                pl.col("L5_solve").mean().alias("v_solve")))
    # Both intervention arms compare against the same-seed vanilla. routing is a
    # first-class arm now, so include it (keyed on `arm` below so it doesn't
    # merge with projected). NOTE: routing's L5_hack here is the TRAINING-time
    # hack (the routed forward still hacks); the deployment number is the
    # deploy-eval (ROUTE EVAL BLUF / hack_deploy), not this column.
    j = (df.filter(pl.col("arm").is_in(["projected", "routing"]))
         .join(van, on=["mix", "seed"], how="inner")
         .with_columns((pl.col("L5_hack") - pl.col("v_hack")).alias("dh"),
                       (pl.col("L5_solve") - pl.col("v_solve")).alias("ds")))
    # Same comparability key as the grouped table (dropf included) so runs
    # built with different v_hack drop fractions are not averaged together.
    pkey = ["arm", "mix", "refr", "over", "gate", "k", "dropf", "vhack"]
    pj = (j.group_by(pkey)
          .agg(pl.col("dh").mean().alias("Dhack"),
               pl.col("dh").std().alias("Dhack_sd"),
               pl.col("ds").mean().alias("Dsolve"),
               pl.len().alias("n"),
               pl.col("seed").sort().str.join(",").alias("shared_seeds"))
          .sort(["mix", "vhack", "refr", "gate", "over", "arm", "k", "dropf"]))
    pcols = pkey + ["Dhack", "Dhack_sd", "Dsolve", "n", "shared_seeds"]
    print("\n## Paired delta vs same-seed vanilla (matched mix; negative = less hacking)\n")
    print(tabulate(pj.select(pcols).rows(), headers=pcols, tablefmt="pipe", floatfmt="+.3f"))
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: print the aggregated results tables built from logs/*.log.
    main()
|