diff --git a/scripts/plot_dynamics.py b/scripts/plot_dynamics.py index 7e4dbf3..4c2329f 100644 --- a/scripts/plot_dynamics.py +++ b/scripts/plot_dynamics.py @@ -98,8 +98,14 @@ def parse_log(path: Path) -> dict | None: # Also parse the route DEPLOY-eval columns when present (non-route logs lack # them -> skip). For routing we plot THESE (deployed model = quarantine deleted), # not the training-time hack_s. - deploy = {"hk_dep", "slv_dep"} & set(idx) - wanted = {**RATE_COLS, **COS_COLS, **{c: c for c in deploy}} + # hk_abl/slv_abl = the FREE per-step deploy proxy (ablated rollout slice, + # rollout_ablate_frac>0); hk_dep/slv_dep = the held-out greedy eval, only on + # eval_ablate_every steps. Prefer the dense proxy for the curve (see below). + deploy = {"hk_dep", "slv_dep", "hk_abl", "slv_abl"} & set(idx) + # Only parse columns this log actually has: non-projecting arms (vanilla, + # routing2) lack cin_t/cin_s, so gate by presence rather than KeyError. + wanted = {k: v for k, v in {**RATE_COLS, **COS_COLS}.items() if k in idx} + wanted.update({c: c for c in deploy}) for line in txt.splitlines(): if "| INFO |" not in line: continue @@ -113,13 +119,20 @@ def parse_log(path: Path) -> dict | None: return None run = dict(arm=arm, refr=refr, seed=seed, vhack=vhack, steps=np.array(steps), **{k: np.array(v, dtype=float) for k, v in series.items()}) - # COHERENCE-GAP FIX: route's training-time hack_s looks vanilla (the routed - # forward still hacks); routing's benefit only shows on the DEPLOYED model - # (quarantine knob deleted). So for routing, plot the deploy series under the - # hack_s/gt_s keys -> all downstream (panels, onset, overlay) reads it. - if arm == "routing" and "hk_dep" in run: - run["hack_s"] = run["hk_dep"] - run["gt_s"] = run["slv_dep"] + # COHERENCE-GAP FIX: routing's training-time hack_s looks vanilla (the routed + # forward still hacks); the benefit only shows on the DEPLOYED model + # (quarantine knob deleted). So for routing/routing2, plot the deploy series + # under the hack_s/gt_s keys -> all downstream (panels, onset, overlay) reads + # it. Prefer the DENSE per-step proxy (hk_abl, every step) over the sparse + # held-out eval (hk_dep, every eval_ablate_every steps); fall back to hk_dep + # for older runs that predate the proxy. + if arm in ("routing", "routing2"): + if "hk_abl" in run: + run["hack_s"] = run["hk_abl"] + run["gt_s"] = run["slv_abl"] + elif "hk_dep" in run: + run["hack_s"] = run["hk_dep"] + run["gt_s"] = run["slv_dep"] return run @@ -128,13 +141,15 @@ def classify(run: dict) -> str: return "vanilla" if run["arm"] == "routing": return "routing" + if run["arm"] == "routing2": + return "routing2" # arm == projected -> erasure, split by refresh return "online erasure" if run["refr"] > 0 else "static erasure" # --- plot ------------------------------------------------------------------ -ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routing"] +ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routing", "routing2"] # Distinct colour per series -- the two rows measure different things, so they # must not share a palette (hack != teacher-cos). Row 0: red hack vs green # solve. Row 1: blue teacher-cos vs amber student-cos. @@ -148,7 +163,8 @@ COS_COLORS = {"cin_t": "#33508c", "cin_s": "#c98a2b"} # confessions), assign colour by method rank along a perceptual RdYlGn ramp so # the reader sees "redder = hacks more" at a glance. ARM_COLORS = {"vanilla": "#7a7a7a", "static erasure": "#c98a2b", - "online erasure": "#33508c", "routing": "#2f7d4f"} + "online erasure": "#33508c", "routing": "#2f7d4f", + "routing2": "#7d2f6f"} def _onset(steps: np.ndarray, hack: np.ndarray) -> int | None: @@ -179,7 +195,10 @@ def _series_panel(ax, runs, cols, colors, ylim, label_series=False): for col, label in cols.items(): color = colors[col] stacked = [] - for r in runs: + present = [r for r in runs if col in r] + if not present: # arm lacks this series (e.g. no cos cols for routing2/vanilla) + continue + for r in present: ys = _ema(r[col]) ax.plot(r["steps"], ys, color=color, lw=0.7, alpha=0.35, solid_capstyle="round") stacked.append(ys) @@ -214,7 +233,8 @@ def plot(runs: list[dict], out: Path) -> None: fig, axes = plt.subplots(2, len(arms), figsize=(3.0 * len(arms), 4.4), sharex=True, sharey="row", squeeze=False) - cos_lo = min(np.nanmin(r[c]) for r in runs for c in COS_COLS) + _cos_vals = [np.nanmin(r[c]) for r in runs for c in COS_COLS if c in r] + cos_lo = min(_cos_vals) if _cos_vals else 0.0 for col, arm in enumerate(arms): rs = by_arm[arm] n_seed = len({r["seed"] for r in rs})