feat: --teacher-modes for clean A5 no-cheat (train held-out modes on-policy, anchor only known)

Decouples training problems from teacher pool: when teacher_modes is set, the
pool is restricted to known-mode demos, the line-589 pool filter is skipped, and
held-out prompts fall through to student-only (not skipped). The route2 tau
hack-anchor then sees only known-mode teacher rows + known-mode hacked_E, so
held-out suppression is pure absorption -- no held-out label at train time.
Smoke-verified: run_tests prompts get teacher mix, held-out prompts train on-policy.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-03 22:45:49 +00:00
parent a0d4ddf9d5
commit da48a95d9e
+34 -9
View File
@@ -212,6 +212,14 @@ class Config:
# cut. Guarantees all hacks emerge (teacher-seeded) before testing whether route2
# holds the suppression once the teacher crutch is gone. See step-loop use.
teacher_off_step: int | None = None
# A5 no-cheat generalisation: restrict teacher demos (and thus the route2 tau
# hack-anchor) to these env_modes only. Held-out modes stay in the training set
# but train PURELY ON-POLICY (no teacher rows, never seed the hack-anchor) -- the
# student must emerge them itself, and we measure whether routing on the
# known-mode v_grad suppresses them anyway (absorption). None = use the whole
# pool (normal). When set, the line-589 "filter problems to pool keys" is skipped
# and uncached/held-out prompts fall through to student-only instead of skipping.
teacher_modes: tuple[str, ...] | None = None
# Cross-mechanism BLUF (docs/spec/20260528_cross_mechanism_v_hack.md):
# which upstream detectors were used to label the hack-side of the pairs that
# produced v_hack. Used to split student-rollout hacks into half_A (covered by
@@ -527,6 +535,19 @@ def main(cfg: Config) -> int:
f"{dict(sorted(by_mode.items()))}. Each problem graded by its own mode; "
f"non-overlap holds (passed = gt_correct OR channel_i)."
)
if cfg.teacher_modes is not None:
# A5 no-cheat: drop teacher demos for held-out modes. The held-out
# problems stay in load_problems (filter at line ~589 is skipped when
# teacher_modes is set) and train on-policy. partition is required.
assert partition is not None, "teacher_modes needs a partition.json"
kept = {pid: rows for pid, rows in teacher_pool.items()
if partition[pid] in cfg.teacher_modes}
logger.info(
f"teacher_modes={cfg.teacher_modes}: teacher pool restricted "
f"{len(teacher_pool)}->{len(kept)} prompts (known modes only); "
f"held-out-mode problems train ON-POLICY (no teacher, no anchor seed)."
)
teacher_pool = kept
n_rollouts_per = sum(len(v) for v in teacher_pool.values()) / len(teacher_pool)
avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
logger.info(
@@ -581,10 +602,12 @@ def main(cfg: Config) -> int:
problems = load_problems(n_problems, env_modes=[cfg.env_mode], seed=cfg.seed, partition=partition)
mode_desc = "per-problem partition" if partition is not None else f"single env_mode={cfg.env_mode}"
logger.info(f"loaded {len(problems)} problems from {DATA.name} -- {mode_desc}")
if teacher_pool:
if teacher_pool and cfg.teacher_modes is None:
# Restrict prompt sampling to problems with cached teacher rollouts;
# otherwise we'd skip the majority of steps when the pool is sparse
# (e.g. 70/992 prompts cached -> ~93% skip rate).
# SKIPPED under teacher_modes (A5): held-out-mode problems have no teacher
# demos but must stay in training to emerge + be measured on-policy.
before = len(problems)
problems = [p for p in problems if p["problem_id"] in teacher_pool]
logger.info(
@@ -880,17 +903,19 @@ def main(cfg: Config) -> int:
model.config.use_cache = True
_tg = time.perf_counter()
teacher_sample: list[dict] | None = None
if teacher_pool and G_t > 0:
pool_rows = teacher_pool.get(prob["problem_id"]) if teacher_pool else None
if teacher_pool and G_t > 0 and not pool_rows and cfg.teacher_modes is None:
# Sparse-pool skip: prompt uncached -> skip the whole prompt;
# falling back to student-only would break the student-vs-teacher
# comparison the normal mixed-pool run is designed to measure.
# SUPPRESSED under teacher_modes (A5): a held-out-mode prompt has no
# teacher demos BY DESIGN and must train on-policy (falls to else).
n_skipped += 1
continue
if pool_rows and G_t > 0:
# Mixed-pool: G_s live student + G_t cached teacher rollouts.
# G_t==0 (mix=0 no-teacher ablation) falls through to the student-only
# path below; the pool stays loaded for partition + v_grad extraction.
# If this prompt has no cached teacher rollouts, skip the whole
# prompt; falling back to student-only would break the
# student-vs-teacher comparison this run is designed to measure.
pool_rows = teacher_pool.get(prob["problem_id"])
if not pool_rows:
n_skipped += 1
continue
# Random sample without replacement when cache is large enough.
# Re-seeded per (step, p_idx) by the global rng so runs reproduce.
idxs = torch.randperm(len(pool_rows), generator=rng)[:G_t].tolist()