mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
feat: --teacher-modes for clean A5 no-cheat (train held-out modes on-policy, anchor only known)
Decouples training problems from teacher pool: when teacher_modes is set, the pool is restricted to known-mode demos, the line-589 pool filter is skipped, and held-out prompts fall through to student-only (not skipped). The route2 tau hack-anchor then sees only known-mode teacher rows + known-mode hacked_E, so held-out suppression is pure absorption -- no held-out label at train time. Smoke-verified: run_tests prompts get teacher mix, held-out prompts train on-policy. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -212,6 +212,14 @@ class Config:
|
||||
# cut. Guarantees all hacks emerge (teacher-seeded) before testing whether route2
|
||||
# holds the suppression once the teacher crutch is gone. See step-loop use.
|
||||
teacher_off_step: int | None = None
|
||||
# A5 no-cheat generalisation: restrict teacher demos (and thus the route2 tau
|
||||
# hack-anchor) to these env_modes only. Held-out modes stay in the training set
|
||||
# but train PURELY ON-POLICY (no teacher rows, never seed the hack-anchor) -- the
|
||||
# student must emerge them itself, and we measure whether routing on the
|
||||
# known-mode v_grad suppresses them anyway (absorption). None = use the whole
|
||||
# pool (normal). When set, the main() filter that restricts problems to pool keys is skipped
|
||||
# and uncached/held-out prompts fall through to student-only instead of skipping.
|
||||
teacher_modes: tuple[str, ...] | None = None
|
||||
# Cross-mechanism BLUF (docs/spec/20260528_cross_mechanism_v_hack.md):
|
||||
# which upstream detectors were used to label the hack-side of the pairs that
|
||||
# produced v_hack. Used to split student-rollout hacks into half_A (covered by
|
||||
@@ -527,6 +535,19 @@ def main(cfg: Config) -> int:
|
||||
f"{dict(sorted(by_mode.items()))}. Each problem graded by its own mode; "
|
||||
f"non-overlap holds (passed = gt_correct OR channel_i)."
|
||||
)
|
||||
if cfg.teacher_modes is not None:
|
||||
# A5 no-cheat: drop teacher demos for held-out modes. The held-out
|
||||
# problems stay in the load_problems result (the later pool-key filter in main() is skipped when
|
||||
# teacher_modes is set) and train on-policy. partition is required.
|
||||
assert partition is not None, "teacher_modes needs a partition.json"
|
||||
kept = {pid: rows for pid, rows in teacher_pool.items()
|
||||
if partition[pid] in cfg.teacher_modes}
|
||||
logger.info(
|
||||
f"teacher_modes={cfg.teacher_modes}: teacher pool restricted "
|
||||
f"{len(teacher_pool)}->{len(kept)} prompts (known modes only); "
|
||||
f"held-out-mode problems train ON-POLICY (no teacher, no anchor seed)."
|
||||
)
|
||||
teacher_pool = kept
|
||||
n_rollouts_per = sum(len(v) for v in teacher_pool.values()) / len(teacher_pool)
|
||||
avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
|
||||
logger.info(
|
||||
@@ -581,10 +602,12 @@ def main(cfg: Config) -> int:
|
||||
problems = load_problems(n_problems, env_modes=[cfg.env_mode], seed=cfg.seed, partition=partition)
|
||||
mode_desc = "per-problem partition" if partition is not None else f"single env_mode={cfg.env_mode}"
|
||||
logger.info(f"loaded {len(problems)} problems from {DATA.name} -- {mode_desc}")
|
||||
if teacher_pool:
|
||||
if teacher_pool and cfg.teacher_modes is None:
|
||||
# Restrict prompt sampling to problems with cached teacher rollouts;
|
||||
# otherwise we'd skip the majority of steps when the pool is sparse
|
||||
# (e.g. 70/992 prompts cached -> ~93% skip rate).
|
||||
# SKIPPED under teacher_modes (A5): held-out-mode problems have no teacher
|
||||
# demos but must stay in training to emerge + be measured on-policy.
|
||||
before = len(problems)
|
||||
problems = [p for p in problems if p["problem_id"] in teacher_pool]
|
||||
logger.info(
|
||||
@@ -880,17 +903,19 @@ def main(cfg: Config) -> int:
|
||||
model.config.use_cache = True
|
||||
_tg = time.perf_counter()
|
||||
teacher_sample: list[dict] | None = None
|
||||
if teacher_pool and G_t > 0:
|
||||
pool_rows = teacher_pool.get(prob["problem_id"]) if teacher_pool else None
|
||||
if teacher_pool and G_t > 0 and not pool_rows and cfg.teacher_modes is None:
|
||||
# Sparse-pool skip: prompt uncached -> skip the whole prompt;
|
||||
# falling back to student-only would break the student-vs-teacher
|
||||
# comparison the normal mixed-pool run is designed to measure.
|
||||
# SUPPRESSED under teacher_modes (A5): a held-out-mode prompt has no
|
||||
# teacher demos BY DESIGN and must train on-policy (falls to else).
|
||||
n_skipped += 1
|
||||
continue
|
||||
if pool_rows and G_t > 0:
|
||||
# Mixed-pool: G_s live student + G_t cached teacher rollouts.
|
||||
# G_t==0 (mix=0 no-teacher ablation) falls through to the student-only
|
||||
# path below; the pool stays loaded for partition + v_grad extraction.
|
||||
# If this prompt has no cached teacher rollouts, skip the whole
|
||||
# prompt; falling back to student-only would break the
|
||||
# student-vs-teacher comparison this run is designed to measure.
|
||||
pool_rows = teacher_pool.get(prob["problem_id"])
|
||||
if not pool_rows:
|
||||
n_skipped += 1
|
||||
continue
|
||||
# Random sample without replacement when cache is large enough.
|
||||
# Re-seeded per (step, p_idx) by the global rng so runs reproduce.
|
||||
idxs = torch.randperm(len(pool_rows), generator=rng)[:G_t].tolist()
|
||||
|
||||
Reference in New Issue
Block a user