diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 453162f..911dac5 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -212,6 +212,14 @@ class Config: # cut. Guarantees all hacks emerge (teacher-seeded) before testing whether route2 # holds the suppression once the teacher crutch is gone. See step-loop use. teacher_off_step: int | None = None + # A5 no-cheat generalisation: restrict teacher demos (and thus the route2 tau + # hack-anchor) to these env_modes only. Held-out modes stay in the training set + # but train PURELY ON-POLICY (no teacher rows, never seed the hack-anchor) -- the + # student must emerge them itself, and we measure whether routing on the + # known-mode v_grad suppresses them anyway (absorption). None = use the whole + # pool (normal). When set, the line-589 "filter problems to pool keys" is skipped + # and uncached/held-out prompts fall through to student-only instead of skipping. + teacher_modes: tuple[str, ...] | None = None # Cross-mechanism BLUF (docs/spec/20260528_cross_mechanism_v_hack.md): # which upstream detectors were used to label the hack-side of the pairs that # produced v_hack. Used to split student-rollout hacks into half_A (covered by @@ -527,6 +535,19 @@ def main(cfg: Config) -> int: f"{dict(sorted(by_mode.items()))}. Each problem graded by its own mode; " f"non-overlap holds (passed = gt_correct OR channel_i)." ) + if cfg.teacher_modes is not None: + # A5 no-cheat: drop teacher demos for held-out modes. The held-out + # problems stay in load_problems (filter at line ~589 is skipped when + # teacher_modes is set) and train on-policy. partition is required. + assert partition is not None, "teacher_modes needs a partition.json" + kept = {pid: rows for pid, rows in teacher_pool.items() + if partition[pid] in cfg.teacher_modes} + logger.info( + f"teacher_modes={cfg.teacher_modes}: teacher pool restricted " + f"{len(teacher_pool)}->{len(kept)} prompts (known modes only); " + f"held-out-mode problems train ON-POLICY (no teacher, no anchor seed)." + ) + teacher_pool = kept n_rollouts_per = sum(len(v) for v in teacher_pool.values()) / len(teacher_pool) avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values()) logger.info( @@ -581,10 +602,12 @@ def main(cfg: Config) -> int: problems = load_problems(n_problems, env_modes=[cfg.env_mode], seed=cfg.seed, partition=partition) mode_desc = "per-problem partition" if partition is not None else f"single env_mode={cfg.env_mode}" logger.info(f"loaded {len(problems)} problems from {DATA.name} -- {mode_desc}") - if teacher_pool: + if teacher_pool and cfg.teacher_modes is None: # Restrict prompt sampling to problems with cached teacher rollouts; # otherwise we'd skip the majority of steps when the pool is sparse # (e.g. 70/992 prompts cached -> ~93% skip rate). + # SKIPPED under teacher_modes (A5): held-out-mode problems have no teacher + # demos but must stay in training to emerge + be measured on-policy. before = len(problems) problems = [p for p in problems if p["problem_id"] in teacher_pool] logger.info( @@ -880,17 +903,19 @@ def main(cfg: Config) -> int: model.config.use_cache = True _tg = time.perf_counter() teacher_sample: list[dict] | None = None - if teacher_pool and G_t > 0: + pool_rows = teacher_pool.get(prob["problem_id"]) if teacher_pool else None + if teacher_pool and G_t > 0 and not pool_rows and cfg.teacher_modes is None: + # Sparse-pool skip: prompt uncached -> skip the whole prompt; + # falling back to student-only would break the student-vs-teacher + # comparison the normal mixed-pool run is designed to measure. + # SUPPRESSED under teacher_modes (A5): a held-out-mode prompt has no + # teacher demos BY DESIGN and must train on-policy (falls to else). + n_skipped += 1 + continue + if pool_rows and G_t > 0: # Mixed-pool: G_s live student + G_t cached teacher rollouts. # G_t==0 (mix=0 no-teacher ablation) falls through to the student-only # path below; the pool stays loaded for partition + v_grad extraction. - # If this prompt has no cached teacher rollouts, skip the whole - # prompt; falling back to student-only would break the - # student-vs-teacher comparison this run is designed to measure. - pool_rows = teacher_pool.get(prob["problem_id"]) - if not pool_rows: - n_skipped += 1 - continue # Random sample without replacement when cache is large enough. # Re-seeded per (step, p_idx) by the global rng so runs reproduce. idxs = torch.randperm(len(pool_rows), generator=rng)[:G_t].tolist()