diff --git a/scripts/plot_dynamics.py b/scripts/plot_dynamics.py
index 7492024..5a9a1cd 100644
--- a/scripts/plot_dynamics.py
+++ b/scripts/plot_dynamics.py
@@ -283,42 +283,65 @@ def plot(runs: list[dict], out: Path) -> None:
     logger.info(f"wrote {out} ({len(runs)} runs, arms={arms})")
 
 
-def plot_hack_overlay(runs: list[dict], out: Path) -> None:
-    """Single panel: mean hack_s per arm overlaid, the headline arm-vs-arm view.
-    Faint per-seed lines for spread, bold EMA mean per arm, onset tick at the
-    arm's mean hack-onset, direct-labelled (no legend)."""
-    by_arm: dict[str, list[dict]] = defaultdict(list)
-    for r in runs:
-        by_arm[classify(r)].append(r)
-    arms = [a for a in ARM_ORDER if a in by_arm]
-
-    fig, ax = plt.subplots(figsize=(5.2, 3.4))
+def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset):
+    """Overlay one metric (key) per arm on ax: faint per-seed EMA lines + bold
+    EMA mean, optional mean-onset dot. Direct labels (only on the unlabeled-x panel)
+    are de-collided in y so overlapping arms don't stack their text (collision test)."""
+    ends = []  # (y_endpoint, x_endpoint, arm, color) for direct labels
     for arm in arms:
-        rs = by_arm[arm]
+        rs = [r for r in by_arm[arm] if key in r]
+        if not rs:
+            continue
         color = ARM_COLORS[arm]
         stacked = []
         for r in rs:
-            ys = _ema(r["hack_s"])
+            ys = _ema(r[key])
             ax.plot(r["steps"], ys, color=color, lw=0.6, alpha=0.25, solid_capstyle="round")
             stacked.append(ys)
         L = min(len(y) for y in stacked)
         ym = np.nanmean(np.stack([y[:L] for y in stacked]), axis=0)
         xm = rs[0]["steps"][:L]
         ax.plot(xm, ym, color=color, lw=2.0, solid_capstyle="round")
-        onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
-        if onsets:
-            s0 = float(np.mean(onsets))
-            ax.plot(s0, np.interp(s0, xm, ym), marker="o", ms=4, color=color, zorder=3)
-        ax.annotate(arm, (xm[-1], ym[-1]), color=color, fontsize=8,
-                    xytext=(4, 0), textcoords="offset points", va="center")
-
+        if with_onset:
+            onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
+            if onsets:
+                s0 = float(np.mean(onsets))
+                ax.plot(s0, np.interp(s0, xm, ym), marker="o", ms=4, color=color, zorder=3)
+        ends.append((float(ym[-1]), float(xm[-1]), arm, color))
     ax.set_ylim(0, 1)
-    ax.set_xlabel("optimizer step")
-    ax.set_ylabel("student hack rate (hack_s)")
-    ax.spines["top"].set_visible(False)
-    ax.spines["right"].set_visible(False)
+    ax.set_ylabel(label)
+    ax.spines[["top", "right"]].set_visible(False)
     ax.tick_params(labelsize=8)
-    ax.set_title("Student hack rate by arm (EMA-5; dot = mean onset)", fontsize=10)
+    # direct-label only one panel (caller passes with_onset=False for it) -- the
+    # other shares colours, so labelling both is redundant ink (eraser test).
+    if with_onset:
+        return
+    ends.sort(key=lambda e: e[0])  # bottom-to-top by endpoint
+    gap = 0.052  # min y-separation in data units
+    placed = []
+    for y, x, arm, color in ends:
+        y_lab = y if not placed else max(y, placed[-1] + gap)
+        placed.append(y_lab)
+        arrow = dict(arrowstyle="-", color=color, lw=0.5, shrinkA=0, shrinkB=0)
+        ax.annotate(arm, xy=(x, y), xytext=(x + 1.0, y_lab), textcoords="data",
+                    color=color, fontsize=8, va="center",
+                    arrowprops=arrow if abs(y_lab - y) > 1e-3 else None)
+
+
+def plot_hack_overlay(runs: list[dict], out: Path) -> None:
+    """Two stacked panels sharing x: student hack rate (top) and solve rate (bottom)
+    per arm. Faint per-seed EMA lines + bold EMA-5 mean; onset dot on the hack panel;
+    arms direct-labelled once on the solve panel (shared colours, no legend)."""
+    by_arm: dict[str, list[dict]] = defaultdict(list)
+    for r in runs:
+        by_arm[classify(r)].append(r)
+    arms = [a for a in ARM_ORDER if a in by_arm]
+
+    fig, (ax_h, ax_s) = plt.subplots(2, 1, figsize=(5.2, 5.2), sharex=True)
+    _overlay_panel(ax_h, by_arm, arms, "hack_s", label="hack rate", with_onset=True)
+    _overlay_panel(ax_s, by_arm, arms, "gt_s", label="solve rate", with_onset=False)
+    ax_s.set_xlabel("optimizer step")
+    ax_h.set_title("Hack vs solve rate by arm (EMA-5; dot = mean hack onset)", fontsize=10)
     fig.tight_layout()
     out.parent.mkdir(parents=True, exist_ok=True)
     fig.savefig(out, dpi=150, bbox_inches="tight")