mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 19:47:33 +08:00
refactor: extract train_config.py + run_artifacts.py from train.py; slim results scripts
Cleanup by a prior agent, verified green here: 'just smoke' (erase arm) runs end-to-end and all four wired gates pass (verify_rewards 52/52, verify_eval_gap, verify_partition, verify_science_invariants).

- train.py -318 lines: Config dataclass -> train_config.py, checkpoint/deploy-artifact IO -> run_artifacts.py.
- results.py / results_deploy.py / probe_distill.py slimmed.
- drop stale derived csvs under out/figs (a5_generalisation, dyn_*, substrate_aggregate, train_vs_deploy_60).
- gitignore /.pi/ panel scratch.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+40
-35
@@ -88,6 +88,7 @@ def parse_log(path: Path) -> dict | None:
|
||||
# a vertical line / end of the teacher-on shaded region in the 2x2.
|
||||
_toff = grab(r"--teacher-off-step=(\d+)", argv, None)
|
||||
teacher_off = int(_toff) if _toff is not None else None
|
||||
eval_n = int(grab(r"periodic-curve n=(\d+)", txt))
|
||||
|
||||
# header line: the one containing both "step" and "hack_s"
|
||||
hdr = next((l for l in txt.splitlines()
|
||||
@@ -123,8 +124,13 @@ def parse_log(path: Path) -> dict | None:
|
||||
series[col].append(_val(row[idx[col]]))
|
||||
if not steps:
|
||||
return None
|
||||
per_token = "--routeV-per-token" in argv
|
||||
# Logged step k is evaluated after optimizer update k, so the number of
|
||||
# completed updates is k+1. The shared pre-training base point is not logged.
|
||||
steps = np.array(steps) + 1
|
||||
run = dict(arm=arm, refr=refr, seed=seed, vhack=vhack, teacher_off=teacher_off,
|
||||
steps=np.array(steps), **{k: np.array(v, dtype=float) for k, v in series.items()})
|
||||
per_token=per_token, eval_n=eval_n,
|
||||
steps=steps, **{k: np.array(v, dtype=float) for k, v in series.items()})
|
||||
# Normalise missing eval columns to all-nan (absent == all-nan downstream): old logs
|
||||
# that never printed a held-out eval lack the key entirely, which would KeyError the
|
||||
# train-series assignment. A nan column drops the seed out of the mean cleanly.
|
||||
@@ -168,22 +174,23 @@ def classify(run: dict) -> str:
|
||||
return "vanilla"
|
||||
if run["arm"] == "routing":
|
||||
return "routing"
|
||||
if run["arm"] == "routing2":
|
||||
return "routing2"
|
||||
if run["arm"] == "routingV":
|
||||
return "routingV_per_token" if run["per_token"] else "routingV"
|
||||
# arm == projected -> erasure, split by refresh
|
||||
return "online erasure" if run["refr"] > 0 else "static erasure"
|
||||
|
||||
|
||||
# --- plot ------------------------------------------------------------------
|
||||
|
||||
# routing (route v1, single quarantine) is deprecated -- superseded by routing2
|
||||
# (scale-matched quarantine). classify() still tags v1 logs as "routing" so they
|
||||
# don't get misread as erasure, but it's left out of ARM_ORDER so it isn't plotted.
|
||||
ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routing2"]
|
||||
# routing (route v1, single quarantine) and routing2 are deprecated. routeV is
|
||||
# the current scale-matched quarantine method.
|
||||
ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routingV", "routingV_per_token"]
|
||||
# Distinct colour per series -- the two rows measure different things, so they
|
||||
# must not share a palette (hack != teacher-cos). Row 0: red hack vs green
|
||||
# solve. Row 1: blue teacher-cos vs amber student-cos.
|
||||
RATE_COLORS = {"hack_s": "#c1432b", "gt_s": "#2f7d4f"}
|
||||
HACK_YMAX = 0.65
|
||||
SOLVE_YMAX = 0.25
|
||||
# Arm colours for the single-panel hack overlay (arms, not series): grey vanilla
|
||||
# baseline -> amber static -> blue online, ordered by increasing intervention.
|
||||
# TODO(color): make this a quality-ordered red->green ramp instead of fixed
|
||||
@@ -193,7 +200,7 @@ RATE_COLORS = {"hack_s": "#c1432b", "gt_s": "#2f7d4f"}
|
||||
# the reader sees "redder = hacks more" at a glance.
|
||||
ARM_COLORS = {"vanilla": "#7a7a7a", "static erasure": "#c98a2b",
|
||||
"online erasure": "#33508c", "routing": "#2f7d4f",
|
||||
"routing2": "#7d2f6f"}
|
||||
"routingV": "#7d2f6f", "routingV_per_token": "#7d2f6f"}
|
||||
|
||||
|
||||
def _onset(steps: np.ndarray, hack: np.ndarray) -> int | None:
|
||||
@@ -261,13 +268,13 @@ CSV_SERIES = ["hack_s", "gt_s", "hack_train", "solve_train", "hk_dep", "slv_dep"
|
||||
|
||||
def dump_data(runs: list[dict], out: Path) -> Path:
|
||||
csv = out.with_suffix(".csv")
|
||||
lines = ["arm,seed,step," + ",".join(CSV_SERIES)]
|
||||
lines = ["arm,seed,eval_n,step," + ",".join(CSV_SERIES)]
|
||||
for r in runs:
|
||||
arm = classify(r)
|
||||
for i, step in enumerate(r["steps"]):
|
||||
cells = [r[k][i] if (k in r and r[k] is not None and i < len(r[k])) else float("nan")
|
||||
for k in CSV_SERIES]
|
||||
lines.append(f"{arm},{r['seed']},{int(step)}," + ",".join(str(c) for c in cells))
|
||||
lines.append(f"{arm},{r['seed']},{r['eval_n']},{int(step)}," + ",".join(str(c) for c in cells))
|
||||
csv.write_text("\n".join(lines) + "\n")
|
||||
logger.info(f"wrote {csv} ({len(runs)} runs, reproducibility source)")
|
||||
return csv
|
||||
@@ -285,6 +292,7 @@ def load_csv(path: Path) -> list[dict]:
|
||||
key = (row[ci["arm"]], row[ci["seed"]])
|
||||
run = by_key.setdefault(key, {"arm_csv": row[ci["arm"]], "seed": row[ci["seed"]],
|
||||
"refr": 0, "vhack": "-", "teacher_off": None,
|
||||
"eval_n": int(row[ci["eval_n"]]),
|
||||
"steps": [], **{k: [] for k in CSV_SERIES}})
|
||||
run["steps"].append(int(row[ci["step"]]))
|
||||
for k in CSV_SERIES:
|
||||
@@ -316,7 +324,8 @@ def plot(runs: list[dict], out: Path) -> None:
|
||||
# ylim floor slightly below 0 so a pinned-at-zero series (route2 hack) draws
|
||||
# ABOVE the axis line instead of hiding under it -- the whole result is that
|
||||
# red sits on zero, so it must be visible, not absent.
|
||||
_series_panel(ax, rs, RATE_COLS, RATE_COLORS, ylim=(-0.035, 1.0), label_series=(col == 0))
|
||||
_series_panel(ax, rs, RATE_COLS, RATE_COLORS, ylim=(-0.025, HACK_YMAX),
|
||||
label_series=(col == 0))
|
||||
# If hack is pinned at zero all panel, say so -- else "no red line" reads as
|
||||
# a plotting bug rather than the finding.
|
||||
hk = [r["hack_s"] for r in rs if "hack_s" in r]
|
||||
@@ -324,12 +333,12 @@ def plot(runs: list[dict], out: Path) -> None:
|
||||
ax.annotate("hack ≈ 0", (0.04, 0.0), xycoords=("axes fraction", "data"),
|
||||
color=RATE_COLORS["hack_s"], fontsize=8, va="bottom",
|
||||
xytext=(0, 3), textcoords="offset points")
|
||||
ax.set_xlabel("optimizer step")
|
||||
ax.set_xlabel("optimizer updates completed")
|
||||
onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
|
||||
if onsets:
|
||||
s0 = float(np.mean(onsets))
|
||||
ax.axvline(s0, color="0.55", lw=0.8, ls=(0, (4, 3)), zorder=0)
|
||||
ax.annotate("first hack", (s0, 1.0), color="0.4", fontsize=7,
|
||||
ax.annotate("first hack", (s0, HACK_YMAX), color="0.4", fontsize=7,
|
||||
xytext=(2, -2), textcoords="offset points", va="top")
|
||||
|
||||
axes[0][0].set_ylabel("deployed rate")
|
||||
@@ -340,8 +349,10 @@ def plot(runs: list[dict], out: Path) -> None:
|
||||
ax.tick_params(labelsize=8)
|
||||
|
||||
if SHOW_TITLE:
|
||||
eval_ns = sorted({r["eval_n"] for r in runs})
|
||||
fig.suptitle("Training dynamics: deployed hack vs solve by arm "
|
||||
"(deploy-eval n=64 T=0.7; EMA-5; dashed = mean hack onset)", fontsize=10)
|
||||
f"(fixed monitoring subset n={eval_ns}; T=0.7; EMA-5; dashed = mean hack onset)",
|
||||
fontsize=10)
|
||||
fig.tight_layout(rect=(0, 0, 1, 0.96))
|
||||
else:
|
||||
fig.tight_layout()
|
||||
@@ -349,13 +360,12 @@ def plot(runs: list[dict], out: Path) -> None:
|
||||
logger.info(f"wrote {out} ({len(runs)} runs, arms={[arm_label(a) for a in arms]})")
|
||||
|
||||
|
||||
def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset, label_arms, ylim=(0, 1)):
|
||||
def _overlay_panel(ax, by_arm, arms, key, *, label, label_arms, ylim=(0, 1)):
|
||||
"""Overlay one metric (key) per arm on ax: faint per-seed EMA lines + bold
|
||||
EMA mean, optional mean-onset dot. When label_arms, direct-label each arm at its
|
||||
endpoint (de-collided in y). An arm whose mean series sits at zero gets a
|
||||
EMA mean. When label_arms, direct-label each arm at its endpoint (de-collided
|
||||
in y). An arm whose mean series sits at zero gets a
|
||||
"$\\approx 0$" tag so a pinned-at-zero line reads as a finding, not a missing line."""
|
||||
ends = [] # (y_endpoint, x_endpoint, arm, color, is_zero) for direct labels
|
||||
onset_steps = [] # mean-onset across arms -> ONE labeled vertical line (see below)
|
||||
for arm in arms:
|
||||
rs = [r for r in by_arm[arm] if key in r]
|
||||
if not rs:
|
||||
@@ -370,16 +380,7 @@ def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset, label_arms, ylim
|
||||
ym = np.nanmean(np.stack([y[:L] for y in stacked]), axis=0)
|
||||
xm = rs[0]["steps"][:L]
|
||||
ax.plot(xm, ym, color=color, lw=2.0, solid_capstyle="round")
|
||||
if with_onset:
|
||||
onset_steps += [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
|
||||
ends.append((float(ym[-1]), float(xm[-1]), arm, color, float(np.nanmax(ym)) < 0.02))
|
||||
# First-hack as a labeled vertical line (matches the small-multiples), not a dot:
|
||||
# a dashed rule reads as "emergence starts here" across both arms in one mark.
|
||||
if with_onset and onset_steps:
|
||||
s0 = float(np.mean(onset_steps))
|
||||
ax.axvline(s0, color="0.55", lw=0.8, ls=(0, (4, 3)), zorder=0)
|
||||
ax.annotate("first hack", (s0, ylim[1]), color="0.4", fontsize=7,
|
||||
xytext=(2, -2), textcoords="offset points", va="top")
|
||||
ax.set_ylim(*ylim)
|
||||
ax.set_ylabel(label)
|
||||
ax.spines[["top", "right"]].set_visible(False)
|
||||
@@ -407,9 +408,8 @@ def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset, label_arms, ylim
|
||||
|
||||
def plot_hack_overlay(runs: list[dict], out: Path) -> None:
|
||||
"""Two stacked panels sharing x: student hack rate (top) and solve rate (bottom)
|
||||
per arm. Faint per-seed EMA lines + bold EMA-5 mean; onset dot on the hack panel.
|
||||
Arms are direct-labelled on the TOP (hack) panel -- readers scan top-to-bottom, and
|
||||
the hack panel carries the headline (an arm pinned at 0 gets a $\\approx 0$ tag)."""
|
||||
per arm. Faint per-seed EMA lines + bold EMA-5 mean; arms are direct-labelled
|
||||
at their endpoints."""
|
||||
by_arm: dict[str, list[dict]] = defaultdict(list)
|
||||
for r in runs:
|
||||
by_arm[classify(r)].append(r)
|
||||
@@ -418,12 +418,15 @@ def plot_hack_overlay(runs: list[dict], out: Path) -> None:
|
||||
fig, (ax_h, ax_s) = plt.subplots(2, 1, figsize=(5.2, 5.2), sharex=True)
|
||||
# floor the hack panel below 0 so a route line pinned at 0 draws above the axis
|
||||
_overlay_panel(ax_h, by_arm, arms, "hack_s", label="hack rate",
|
||||
with_onset=True, label_arms=True, ylim=(-0.035, 1.0))
|
||||
label_arms=True, ylim=(-0.025, HACK_YMAX))
|
||||
_overlay_panel(ax_s, by_arm, arms, "gt_s", label="solve rate",
|
||||
with_onset=False, label_arms=False, ylim=(0, 1.0))
|
||||
ax_s.set_xlabel("optimizer step")
|
||||
label_arms=True, ylim=(0, SOLVE_YMAX))
|
||||
ax_s.set_xlabel("optimizer updates completed")
|
||||
if SHOW_TITLE:
|
||||
ax_h.set_title("Hack vs solve rate by arm (EMA-5; dot = mean hack onset)", fontsize=10)
|
||||
n_seed = min(len(by_arm[a]) for a in arms)
|
||||
eval_ns = sorted({r["eval_n"] for r in runs})
|
||||
ax_h.set_title(f"Hack vs solve rate on fixed n={eval_ns} monitoring subset "
|
||||
f"(EMA-5; n={n_seed} seed/arm)", fontsize=10)
|
||||
fig.tight_layout()
|
||||
save_fig(fig, out)
|
||||
logger.info(f"wrote {out}")
|
||||
@@ -448,6 +451,7 @@ def plot_train_vs_deploy(runs: list[dict], out: Path) -> None:
|
||||
d = np.abs(ht - hd)
|
||||
return bool(np.isfinite(d).any() and np.nanmax(d) > 0.02)
|
||||
if not any(_has_train_gap(r) for r in runs):
|
||||
out.unlink(missing_ok=True)
|
||||
logger.info(f"skip {out.name}: train==deploy in every run -> no knob-ON contrast to show")
|
||||
return
|
||||
by_arm: dict[str, list[dict]] = defaultdict(list)
|
||||
@@ -588,7 +592,8 @@ def _render_all(runs: list[dict], out: Path) -> None:
|
||||
tvd = out.with_name(out.stem + "_train_deploy.png")
|
||||
plot_train_vs_deploy(runs, tvd) # 2x2 train(on) vs deploy(off)
|
||||
for p in (out, overlay, tvd):
|
||||
logger.info(f"docs/figs latest -> {link_latest(p)}")
|
||||
if p.exists():
|
||||
logger.info(f"docs/figs latest -> {link_latest(p)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user