mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:48:43 +08:00
plot_deploy_overlay: aggregate seeds per arm, std-dev error bars (n>1)
Groups the per_mode_deploy.json records by arm into per-arm lists and plots the mean ± std across seeds. At n=1 (current A5: seed 41 only) no error bar appears; a TODO in the code points at the queued A5 seeds 42/43 (jobs 107-110) that will populate it. Bar labels show n.
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -21,6 +21,7 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -59,16 +60,25 @@ def _despine(ax):
|
||||
ax.grid(axis="y", lw=0.4, alpha=0.35)
|
||||
|
||||
|
||||
def _panel(ax, records, modes, arms, field, title, ylabel):
|
||||
"""Grouped bars: x = mode, one bar per arm, height = records[arm].by_mode[mode][field]."""
|
||||
def _panel(ax, by_arm, modes, arms, field, title, ylabel):
|
||||
"""Grouped bars: x = mode, one bar per arm, height = mean over seed runs of
|
||||
by_mode[mode][field]; error bar = std across seeds (drawn only when >1 seed).
|
||||
TODO(seeds): A5 currently ships n=1 (seed 41 only, jobs 103/104) so no error
|
||||
bar appears. Pass per-seed JSONs (a5 vanilla+route2 seeds 42/43, queued) to
|
||||
populate the error bars -- the code already aggregates them."""
|
||||
w = 0.8 / len(arms)
|
||||
x = np.arange(len(modes))
|
||||
for i, arm in enumerate(arms):
|
||||
rec = next(r for r in records if r["arm"] == arm)
|
||||
recs = by_arm[arm]
|
||||
label, color = ARM[arm]
|
||||
vals = [rec["by_mode"].get(m, {}).get(field, np.nan) for m in modes]
|
||||
bars = ax.bar(x + i * w, vals, w, label=label, color=color)
|
||||
for b, v in zip(bars, vals):
|
||||
per_mode = [[r["by_mode"].get(m, {}).get(field, np.nan) for r in recs] for m in modes]
|
||||
means = np.array([np.nanmean(v) for v in per_mode])
|
||||
stds = np.array([np.nanstd(v) if len(v) > 1 else 0.0 for v in per_mode])
|
||||
n_seed = len(recs)
|
||||
yerr = stds if (stds > 0).any() else None
|
||||
bars = ax.bar(x + i * w, means, w, label=f"{label} (n={n_seed})", color=color,
|
||||
yerr=yerr, capsize=2, error_kw=dict(lw=0.8, alpha=0.8))
|
||||
for b, v in zip(bars, means):
|
||||
if not np.isnan(v):
|
||||
ax.annotate(f"{v:.2f}", (b.get_x() + b.get_width() / 2, v), fontsize=6,
|
||||
ha="center", va="bottom", color=color)
|
||||
@@ -93,16 +103,17 @@ def main() -> None:
|
||||
if not paths:
|
||||
raise SystemExit("no per_mode_deploy.json found (run the sweep first)")
|
||||
records = load(paths)
|
||||
# dedupe arms (keep latest by file order), then order canonically
|
||||
by_arm = {r["arm"]: r for r in records}
|
||||
# group seed runs per arm (mean+/-std bars), order arms canonically
|
||||
by_arm: dict[str, list[dict]] = defaultdict(list)
|
||||
for r in records:
|
||||
by_arm[r["arm"]].append(r)
|
||||
arms = [a for a in ARM if a in by_arm]
|
||||
records = [by_arm[a] for a in arms]
|
||||
modes = [m for m in MODE_ORDER if any(m in r["by_mode"] for r in records)]
|
||||
|
||||
fig, (a1, a2) = plt.subplots(1, 2, figsize=(5.5 + 1.2 * len(modes), 4.2))
|
||||
_panel(a1, records, modes, arms, "deploy_hack",
|
||||
_panel(a1, by_arm, modes, arms, "deploy_hack",
|
||||
"DEPLOY hack rate by mode (lower = better)", "deploy hack rate")
|
||||
_panel(a2, records, modes, arms, "deploy_solve",
|
||||
_panel(a2, by_arm, modes, arms, "deploy_solve",
|
||||
"DEPLOY solve rate by mode (higher = better)", "deploy solve rate")
|
||||
a1.legend(fontsize=8, frameon=False, loc="upper right")
|
||||
if args.title:
|
||||
|
||||
Reference in New Issue
Block a user