mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:59:35 +08:00
fix(plots): drop deprecated routing arm; plot_substrate reads per-batch counts
- plot_dynamics: routing (route v1) out of ARM_ORDER -- superseded by routing2. - plot_substrate: per-mode hk_* are now plain per-batch counts (streaming log dropped the /denominator); parse the count, plot it (EMA or cumsum); skip old n/d-format logs (incompatible units). Y-axis hacks/batch, count annotations. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -156,7 +156,10 @@ def classify(run: dict) -> str:
|
||||
|
||||
# --- plot ------------------------------------------------------------------
|
||||
|
||||
ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routing", "routing2"]
|
||||
# routing (route v1, single quarantine) is deprecated -- superseded by routing2
|
||||
# (scale-matched quarantine). classify() still tags v1 logs as "routing" so they
|
||||
# don't get misread as erasure, but it's left out of ARM_ORDER so it isn't plotted.
|
||||
ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routing2"]
|
||||
# Distinct colour per series -- the two rows measure different things, so they
|
||||
# must not share a palette (hack != teacher-cos). Row 0: red hack vs green
|
||||
# solve. Row 1: blue teacher-cos vs amber student-cos.
|
||||
|
||||
+18
-19
@@ -80,18 +80,22 @@ def parse_hk(path: Path) -> dict | None:
|
||||
names = [_HDR_TOK.match(t).group(0) for t in hdr.split("| INFO |", 1)[1].split()]
|
||||
idx = {n: i for i, n in enumerate(names)}
|
||||
present = [k for k in HK if k in idx] # 4-mode substrate dropped hk_eq; plot only what's logged
|
||||
steps, nd = [], {k: ([], []) for k in present}
|
||||
steps, counts = [], {k: [] for k in present}
|
||||
for line in txt.splitlines():
|
||||
if "| INFO |" not in line:
|
||||
continue
|
||||
row = line.split("| INFO |", 1)[1].split()
|
||||
if not row or not row[0].isdigit() or len(row) < len(names):
|
||||
continue
|
||||
# hk_<mode> is now the per-batch hack COUNT (current step, not cumulative,
|
||||
# no /denominator) -- the streaming log dropped the rollout denominator, so
|
||||
# a per-mode RATE is no longer recoverable. We plot the count directly. Old
|
||||
# logs print "n/d" (cumulative): incompatible units, skip the whole log.
|
||||
if "/" in row[idx[present[0]]]:
|
||||
return None
|
||||
steps.append(int(row[idx["step"]]))
|
||||
for k in present:
|
||||
n, d = row[idx[k]].split("/")
|
||||
nd[k][0].append(int(n))
|
||||
nd[k][1].append(int(d))
|
||||
counts[k].append(int(row[idx[k]]))
|
||||
if not steps:
|
||||
return None # header present but no parseable per-step rows (e.g. diverged/aborted)
|
||||
m = re.search(r"seed(\d+)", path.name) or re.search(r"_s(\d+)", path.name)
|
||||
@@ -99,7 +103,7 @@ def parse_hk(path: Path) -> dict | None:
|
||||
method=classify(txt),
|
||||
seed=m.group(1) if m else "?",
|
||||
steps=np.array(steps),
|
||||
**{k: (np.array(v[0]), np.array(v[1])) for k, v in nd.items()},
|
||||
**{k: np.array(v) for k, v in counts.items()},
|
||||
)
|
||||
|
||||
|
||||
@@ -117,17 +121,12 @@ def ema(y: np.ndarray, span: int) -> np.ndarray:
|
||||
return out
|
||||
|
||||
|
||||
def rate(n: np.ndarray, d: np.ndarray, *, cumulative: bool, span: int) -> np.ndarray:
|
||||
"""Per-step hack rate of one mode. cumulative=running N/M; else EMA of the
|
||||
instantaneous batch rate dN/dM (NaN where the batch saw no rollouts of this mode)."""
|
||||
def rate(count: np.ndarray, *, cumulative: bool, span: int) -> np.ndarray:
|
||||
"""Per-step hacks of one mode. count is the per-batch hack count (instantaneous).
|
||||
cumulative=running total (cumsum); else EMA-smoothed per-batch count."""
|
||||
if cumulative:
|
||||
return np.where(d > 0, n / np.where(d == 0, 1, d), np.nan)
|
||||
dn = np.diff(n)
|
||||
dd = np.diff(d)
|
||||
inst = np.empty(len(n))
|
||||
inst[0] = n[0] / d[0] if d[0] > 0 else np.nan
|
||||
inst[1:] = np.where(dd > 0, dn / np.where(dd == 0, 1, dd), np.nan)
|
||||
return ema(inst, span)
|
||||
return np.cumsum(count)
|
||||
return ema(count.astype(float), span)
|
||||
|
||||
|
||||
def _despine(ax):
|
||||
@@ -154,14 +153,14 @@ def plot_by_method(runs, ylabel, cumulative, span, out: Path):
|
||||
n_learned = 0
|
||||
for k in modes:
|
||||
mode, color = HK[k]
|
||||
stk = np.stack([rate(*r[k], cumulative=cumulative, span=span)[:L] for r in grp])
|
||||
stk = np.stack([rate(r[k], cumulative=cumulative, span=span)[:L] for r in grp])
|
||||
ymean = np.nanmean(stk, axis=0)
|
||||
for ys in stk if len(grp) > 1 else []:
|
||||
ax.plot(x, ys, color=color, lw=0.6, alpha=0.30)
|
||||
ax.plot(x, ymean, color=color, lw=1.8, solid_capstyle="round")
|
||||
on = _onset(x, ymean)
|
||||
n_learned += on is not None
|
||||
ax.annotate(f"{mode} {np.nan_to_num(ymean[-1])*100:.0f}%", (x[-1], np.nan_to_num(ymean[-1])),
|
||||
ax.annotate(f"{mode} {np.nan_to_num(ymean[-1]):.1f}", (x[-1], np.nan_to_num(ymean[-1])),
|
||||
color=color, fontsize=7, va="center", xytext=(5, 0), textcoords="offset points")
|
||||
label, _, dashed = METHODS[method]
|
||||
ax.set_title(f"{label} ({n_learned}/{len(modes)} learned)", fontsize=9)
|
||||
@@ -192,7 +191,7 @@ def plot_by_hack(runs, ylabel, cumulative, span, out: Path):
|
||||
grp = [r for r in runs if r["method"] == method]
|
||||
L = min(len(r["steps"]) for r in grp)
|
||||
x = grp[0]["steps"][:L]
|
||||
stk = np.stack([rate(*r[k], cumulative=cumulative, span=span)[:L] for r in grp])
|
||||
stk = np.stack([rate(r[k], cumulative=cumulative, span=span)[:L] for r in grp])
|
||||
ymean = np.nanmean(stk, axis=0)
|
||||
label, color, dashed = METHODS[method]
|
||||
for ys in stk if len(grp) > 1 else []:
|
||||
@@ -249,7 +248,7 @@ def main() -> None:
|
||||
if not runs:
|
||||
raise SystemExit("no substrate runs in the glob (need hk_rt columns)")
|
||||
logger.info(f"parsed {len(runs)} runs: " + ", ".join(f"{r['method']}/s{r['seed']}" for r in runs))
|
||||
ylabel = "cumulative hack rate" if args.cumulative else f"hack rate (EMA span {args.ema_span})"
|
||||
ylabel = "cumulative hacks" if args.cumulative else f"hacks/batch (EMA span {args.ema_span})"
|
||||
plot_by_method(runs, ylabel, args.cumulative, args.ema_span, stem.with_name(stem.name + "_by_method.png"))
|
||||
plot_by_hack(runs, ylabel, args.cumulative, args.ema_span, stem.with_name(stem.name + "_by_hack.png"))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user