diff --git a/src/projected_grpo/tablelog.py b/src/projected_grpo/tablelog.py index 1c8b35d..1b0b92e 100644 --- a/src/projected_grpo/tablelog.py +++ b/src/projected_grpo/tablelog.py @@ -105,14 +105,19 @@ class StepLogger: _Col("gt_t", 6, "gt_t", "frac", "teacher ground-truth passes (sanity)"), _Col("hack_s", 7, "hack_s?", "frac", "student hack-flagged rollouts (the headline)"), _Col("hack_t", 7, "hack_t", "frac", "teacher hack-flagged rollouts (sanity: pool hacks)"), + # Deploy-eval shown for EVERY arm (nan on steps it's not run -> see it ride + # along as training proceeds). route/route2: quarantine knob OFF. vanilla/erase: + # the trained model itself. Apples-to-apples knob-off deploy number, the plot series. + _Col("hack_deploy", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (route: quarantine OFF; vanilla/erase: trained model); held-out subset, T=0.7, every eval_ablate_every steps; nan between"), + _Col("solve_deploy", 7, "slv_dep", "+.2f", "DEPLOY-eval solve (same cadence; nan between)"), ] # Per-mode CUMULATIVE student exploit rate -> which loophole classes the # student has learnt, and how strongly. Only when the run spans >1 mode # (the substrate); single-mode runs would just duplicate hack_s. self._modes = modes if len(modes) > 1 else [] for m in self._modes: - cols.append(_Col(f"hk_{mode_code[m]}", 6, f"hk_{mode_code[m]}", "frac", - f"cumulative student hacks of {m}")) + cols.append(_Col(f"hk_{mode_code[m]}", 5, f"hk_{mode_code[m]}", "d", + f"student hacks of {m} THIS step (current batch, not cumulative)")) cols += [ _Col("lp_s", 6, "lp_s↓", "+.2f", "mean student gen_logp (diagnostic)"), _Col("lp_t", 6, "lp_t↑", "+.2f", "mean teacher gen_logp; off-policy gap = lp_s-lp_t"), @@ -141,8 +146,6 @@ class StepLogger: if arm in ("routing", "routing2"): cols += [ _Col("q_egy", 6, "qE", ".2f", "grad energy into quarantine ||g_quar||/(||g_keep||+||g_quar||); ~0.5+ rising = learning dumped into the thrown-away knob"), - _Col("hack_deploy", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (quarantine deleted = deployed model); held-out greedy, eval_ablate_every steps; the plot number"), - _Col("solve_deploy", 7, "slv_dep", "+.2f", "DEPLOY-eval solve"), _Col("hack_abl", 6, "hk_abl", "frac", "FREE per-step deploy proxy: hack rate on the ablated (deploy-mode) rollout slice; train prompts, noisier than hk_dep"), _Col("solve_abl", 6, "slv_abl", "frac", "free per-step deploy proxy: solve rate on the ablated rollout slice"), ] diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 45948c7..d530f6e 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -331,125 +331,6 @@ MODE_CODE: dict[str, str] = { } -@dataclass(frozen=True) -class _Col: - """Per-step table column spec. - - key: row-dict key (raw value lives there as float/int/str/None). - width: render width for fixed-width streaming display. - header: display label (may include direction arrows, ? for desired-zero, etc). - fmt: format spec applied to the raw value, e.g. "+.3f", ".2e", "d". - Special spec "frac" expects a (num, denom) tuple and renders "n/d". - None means render as str() of the value. - """ - key: str - width: int - header: str - fmt: str | None = None - desc: str = "" # one-line decode for the legend; "" => omitted from legend - - -def _format_cell(value, fmt: str | None) -> str: - """Format one cell. NaN renders as 'nan' regardless of spec.""" - if value is None: - return "nan" - if fmt == "frac": - n, d = value - return f"{n}/{d}" - if fmt is None: - return str(value) - if isinstance(value, float) and value != value: # NaN - return "nan" - return format(value, fmt) - - -class StepLogger: - """Per-step training-table renderer. - - Single source of truth for column order, width, header label, and value - formatter. The row dict carries raw values (floats, ints, tuples, strings); - StepLogger formats them for streaming, and the end-of-run tabulate dump - consumes the same raw values without re-parsing scientific-notation strings. - - Timing columns (gen/fb/t_rew/sec) are absent from the streaming spec; they - show only at end-of-run, where the tabulate dump picks them from the row dicts. - """ - - def __init__(self, arm: str, modes: list[str]) -> None: - # arm in {vanilla, projected, routing}; only projected/routing actually - # project the gradient, so the cin/cout/fired diagnostics are theirs alone - # (in vanilla they'd be counterfactual noise -> omitted). - projects = arm in ("projected", "routing") - cols: list[_Col] = [ - _Col("step", 4, "step", "d", "GRPO step"), - _Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"), - _Col("rew", 6, "rew", "+.2f", "mean combined reward"), - _Col("rew_s", 6, "rew_s↑", "+.2f", "student mean reward"), - _Col("gt_s", 6, "gt_s↑", "frac", "student ground-truth passes"), - _Col("gt_t", 6, "gt_t", "frac", "teacher ground-truth passes (sanity)"), - _Col("hack_s", 7, "hack_s?", "frac", "student hack-flagged rollouts (the headline)"), - _Col("hack_t", 7, "hack_t", "frac", "teacher hack-flagged rollouts (sanity: pool hacks)"), - # Deploy-eval shown for EVERY arm (nan on steps it's not run -> see it ride - # along as training proceeds). route/route2: quarantine knob OFF. vanilla/erase: - # the trained model itself. Apples-to-apples knob-off deploy number, the plot series. - _Col("hack_deploy", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (route: quarantine OFF; vanilla/erase: trained model); held-out subset, T=0.7, every eval_ablate_every steps; nan between"), - _Col("solve_deploy", 7, "slv_dep", "+.2f", "DEPLOY-eval solve (same cadence; nan between)"), - ] - # Per-mode CUMULATIVE student exploit rate -> which loophole classes the - # student has learnt, and how strongly. Only when the run spans >1 mode - # (the substrate); single-mode runs would just duplicate hack_s. - self._modes = modes if len(modes) > 1 else [] - for m in self._modes: - cols.append(_Col(f"hk_{MODE_CODE[m]}", 5, f"hk_{MODE_CODE[m]}", "d", - f"student hacks of {m} THIS step (current batch, not cumulative)")) - cols += [ - _Col("lp_s", 6, "lp_s↓", "+.2f", "mean student gen_logp (diagnostic)"), - _Col("lp_t", 6, "lp_t↑", "+.2f", "mean teacher gen_logp; off-policy gap = lp_s-lp_t"), - _Col("loss", 7, "loss", "+.2f", "mean GRPO loss"), - _Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of δS grads (vs grad_clip)"), - _Col("lr", 7, "lr", ".1e", "scheduled learning rate"), - ] - if projects: - cols += [ - _Col("cos_pre", 6, "cin", ".2f", "hack-ward grad fraction ‖relu(V@g)‖/‖g‖ ∈ [0,1] BEFORE proj"), - _Col("cos_pre_s", 6, "cin_s", ".2f", "cin on student-only grad"), - _Col("cos_pre_t", 6, "cin_t", ".2f", "cin on teacher-only grad (want cin_t>cin_s)"), - _Col("cos_post", 6, "cout", ".2f", "hack-ward fraction AFTER projection (want ~0: all removed)"), - _Col("fired", 5, "fired", ".2f", "fraction of modules where projection fired"), - ] - # route2: the routing gate is cos(g_b,v_grad) > tau, where tau is the - # per-step EMA midpoint of the hack vs clean cos clouds. Surface tau and - # the hack-clean gap so we can see the threshold ride the drift and whether - # the direction still separates (hkgap>0) -- replaces the silent cos>0 gate. - if arm == "routing2": - cols += [ - _Col("tau", 6, "tau", "+.2f", "per-step calibrated route threshold (midpoint of hack vs clean cos clouds)"), - _Col("hkgap", 6, "hkgap", "+.2f", "ema_hack_cos - ema_clean_cos; >0 = v_grad still separates hack from clean (else direction dead)"), - _Col("resid", 6, "resid", "+.2f", "cos(deployed δS.grad AFTER routing, v_grad); ~0 = hack stripped cleanly, >0 = leak into deployed knob"), - ] - if arm in ("routing", "routing2"): - cols += [ - _Col("q_egy", 6, "qE", ".2f", "grad energy into quarantine ‖g_quar‖/(‖g_keep‖+‖g_quar‖); ~0.5+ rising = learning dumped into the thrown-away knob"), - _Col("hack_abl", 6, "hk_abl", "frac", "FREE per-step deploy proxy: hack rate on the ablated (deploy-mode) rollout slice; train prompts, noisier than hk_dep"), - _Col("solve_abl", 6, "slv_abl", "frac", "free per-step deploy proxy: solve rate on the ablated rollout slice"), - ] - self._cols = cols - - def header(self) -> str: - return " ".join(f"{c.header:>{c.width}}" for c in self._cols) - - def row(self, cells: dict) -> str: - return " ".join( - f"{_format_cell(cells[c.key], c.fmt):>{c.width}}" for c in self._cols - ) - - def legend(self) -> str: - """Decode the (arm-/mode-conditional) columns actually present this run.""" - lines = "\n".join(f" {c.header:>8} = {c.desc}" for c in self._cols if c.desc) - return ("table columns (timing gen/fb/t_rew/sec dropped from streaming, kept " - "in the end-of-run dump):\n" + lines) - - def main(cfg: Config) -> int: # Read the chosen preset's settings off the config, then set up the run. The # subclass dataclasses (SmokeConfig / FastConfig / FullConfig) carry the preset