fix(routeV): keep=bulk gate + deterministic teachers + deploy-mode generation

Three fixes after the deploy-solve=0 run (all user-confirmed):

1. Gate co-adaptation: with q=route_quantile, t_lo sat at quantile(q), so keep was
   only the bottom ~5% and ~95% of rollouts landed above it in mid or rout (the
   quarantine trains on all of them; mid trains both blocks, qmass~0.5).
   Move the keep/mid boundary up: t_lo=quantile(1-2q), t_hi=quantile(1-q), so keep
   is the BULK and mid+rout are the top 2q. Three-zone absorption preserved (not deleted).

2. Deterministic teachers: every teacher-phase prompt is drawn from the
   both-pool-covered set and gets EXACTLY N hack + N solve teachers, with
   N=teacher_n_per_prompt (constant count, no flip/coverage drops). Replaces the
   mix_ratio * _even_split step budget, whose count varied per step. Teacher-phase
   prompts are never flipped (the solve teacher carries solve pressure). mix_ratio>0
   stays the on/off switch. Removed dead _even_split.

3. Deploy-mode generation: student rollouts generate under ablate_quarantine, so the
   behavior policy = the shipped deployed-only model -- the quarantine's learned hack
   can't saturate the rollout distribution and starve honest solve advantage. For
   clean-gated rollouts gen and train forward now match.

Also: FastConfig lr 1e-4->5e-4 (random-init lora2r needs more lr in the short budget).
AGENTS.md: don't bake unconfirmed theories into comments; don't inflate diagnosis
confidence across turns. Smoke + smoke-solvemix green; all verify gates pass.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
wassname
2026-06-11 00:29:12 +00:00
parent 3f2b44452a
commit 97aede8d9c
3 changed files with 79 additions and 72 deletions
+2
@@ -72,6 +72,8 @@ Inherit global rules from `~/.claude/CLAUDE.md`.
   diverse / less overlapping / less off-distribution), i.e. the "Hack A predicts Hack B"
   generalization where every label still lives on pairs we wrote.
 - do not overconfidently diagnose. If you can't think of 3+ plausible hypotheses - including bugs, subtle failures, and you being wrong about concepts - then you have lost perspective and have narrow vision
+- DON'T act on a new theory/diagnosis until I confirm it: don't keep building on it, and don't write it into code comments, docstrings, or docs. An unconfirmed theory baked into a comment misleads every future agent who reads it as fact. State it to me, wait for the OK, THEN write.
+- DON'T inflate a diagnosis's confidence (maybe -> probably -> definitely) across turns or in writing. Keep the hedge you started with unless new evidence justifies the change, and name that evidence. Confidence creep in comments/docs is how a guess becomes "known" with no one having checked.
 - I'm often afk, so don't stop and ask me a question where you know the likely answer, or that I've already indicated or asked for, or where there is only one answer: "waiting for your go ahead". I'd rather you just commit and go ahead
 ## The adapter setup (shapes), and why "same position = shrinkage" is subtle
+69 -72
@@ -106,14 +106,6 @@ def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[tor
     return out
-def _even_split(total: int, parts: int) -> list[int]:
-    """Distribute `total` items across `parts` buckets as evenly as possible, extras first.
-    _even_split(8,4)=[2,2,2,2]; _even_split(2,4)=[1,1,0,0]. Used to spread the STEP-level
-    teacher budget across the step's prompts (T_solve front-loads like T_total, so
-    solve_alloc[i] <= total_alloc[i] holds bucket-wise)."""
-    base, extra = divmod(total, parts)
-    return [base + (1 if i < extra else 0) for i in range(parts)]
 def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]:
     """Draw n teacher rollouts from a prompt's pool (with replacement if the pool is short)."""
@@ -398,9 +390,8 @@ def main(cfg: Config) -> int:
         avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
         logger.info(
             f"teacher pool: {len(teacher_pool)} prompts, ~{n_rollouts_per:.1f} rollouts/prompt, "
-            f"cached hack_rate={avg_hack:.2%}. STEP-level mix_ratio={cfg.mix_ratio} -> "
-            f"{round(prompts_per_step * group * cfg.mix_ratio)} teachers across "
-            f"{prompts_per_step} prompts/step (rest of {prompts_per_step * group} gens are student).")
+            f"cached hack_rate={avg_hack:.2%}. Deterministic: {cfg.teacher_n_per_prompt} hack "
+            f"teacher(s) per teacher-phase prompt (constant count, no mix_ratio budget).")
     # ── solve-teacher pool (symmetric honest demos) ── same schema/loader as the
     # hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack.
@@ -462,10 +453,20 @@ def main(cfg: Config) -> int:
     problems = all_problems[:n_problems]
     mode_desc = "per-problem partition" if partition is not None else f"single env_mode={cfg.env_mode}"
     logger.info(f"loaded {len(problems)} seeded-shuffle problems from {DATA.name} -- {mode_desc}")
+    # Both-pool-covered prompts: the teacher phase samples ONLY these so every prompt can get
+    # a deterministic N hack + N solve. solve_pool ⊂ teacher_pool, so the intersection = the
+    # solve-covered prompts (or just teacher-covered when there is no solve pool).
+    covered_problems = [p for p in problems
+                        if (not teacher_pool or p["problem_id"] in teacher_pool)
+                        and (not solve_pool or p["problem_id"] in solve_pool)]
     if teacher_pool:
         n_cov = sum(1 for p in problems if p["problem_id"] in teacher_pool)
-        logger.info(f"teacher coverage: {n_cov}/{len(problems)} train prompts have cached "
-                    f"teacher hacks (rest train student-only); hack must generalize off the seeds")
+        n_t_per_prompt = cfg.teacher_n_per_prompt * (bool(teacher_pool) + bool(solve_pool))
+        logger.info(f"teacher coverage: {n_cov}/{len(problems)} hack-covered, "
+                    f"{len(covered_problems)}/{len(problems)} both-pool-covered (teacher-phase "
+                    f"sampling pool); hack must generalize off the seeds to the wider on-policy set. "
+                    f"CONSTANT {n_t_per_prompt} teachers/prompt -> {n_t_per_prompt * prompts_per_step}"
+                    f"/{prompts_per_step * group} gens are teacher each teacher-phase step.")
     # Periodic validation and final test are disjoint; final-test results never affect training.
     # Exclude gt_only from hack evaluation unless it is the entire no-loophole ceiling run.
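The coverage filter and the constant teacher count in this hunk can be sketched in isolation. This is a minimal illustration, not the repository's code: the pools and problem IDs are toy data, the helper names `covered` and `teachers_per_prompt` are hypothetical, and `group = 8` is an assumed value (`prompts_per_step = 4` matches FastConfig).

```python
# Toy stand-ins (hypothetical data): each pool maps problem_id -> cached rollouts.
problems = [{"problem_id": i} for i in range(6)]
teacher_pool = {0: ["h0"], 1: ["h1"], 2: ["h2"], 3: ["h3"]}
solve_pool = {1: ["s1"], 3: ["s3"]}


def covered(problems, teacher_pool, solve_pool):
    """Keep prompts present in every non-empty pool; an empty pool adds no constraint."""
    return [p for p in problems
            if (not teacher_pool or p["problem_id"] in teacher_pool)
            and (not solve_pool or p["problem_id"] in solve_pool)]


def teachers_per_prompt(n_per_prompt, teacher_pool, solve_pool):
    """N teachers from each pool that exists (bools add as 0/1)."""
    return n_per_prompt * (bool(teacher_pool) + bool(solve_pool))


both = covered(problems, teacher_pool, solve_pool)
n_t = teachers_per_prompt(1, teacher_pool, solve_pool)

prompts_per_step, group = 4, 8  # group is an assumed value for illustration
teacher_gens = n_t * prompts_per_step
total_gens = prompts_per_step * group
print([p["problem_id"] for p in both], f"{teacher_gens}/{total_gens} gens are teacher")
```

Because every teacher-phase prompt comes from `both`, the teacher share per step depends only on configuration, not on which prompts happened to be sampled.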
@@ -495,18 +496,17 @@ def main(cfg: Config) -> int:
     pad_id = tok.pad_token_id
     def gen_students(enc, n: int) -> tuple[torch.Tensor, int]:
-        """Generate student rollouts, placing any quarantine-ablated samples last."""
-        n_abl = round(n * cfg.rollout_ablate_frac) if has_quarantine else 0
-        parts = []
-        if n - n_abl > 0:
-            parts.append(model.generate(**enc, generation_config=gen_cfg,
-                                        num_return_sequences=n - n_abl).detach())
-        if n_abl > 0:
-            with ablate_quarantine(wrappers):
-                parts.append(model.generate(**enc, generation_config=gen_cfg,
-                                            num_return_sequences=n_abl).detach())
-        L = max(p.shape[1] for p in parts)
-        return torch.cat([F.pad(p, (0, L - p.shape[1]), value=pad_id) for p in parts], dim=0), n_abl
+        """Generate student rollouts in DEPLOY mode (quarantine ablated): the behavior policy
+        = the shipped deployed-only model, so the quarantine's learned hack can't saturate the
+        rollout distribution and starve the honest-vs-honest solve advantage. For clean-gated
+        rollouts the train forward is also quarantine-off, so gen and train now match.
+        (vanilla has no quarantine -> nullcontext, but its quarantine is empty so it is already
+        deploy-mode.)"""
+        ctx = ablate_quarantine(wrappers) if has_quarantine else nullcontext()
+        with ctx:
+            out = model.generate(**enc, generation_config=gen_cfg,
+                                 num_return_sequences=n).detach()
+        return out, 0  # every rollout is deploy-mode now; no separate ablated-proxy subset
     # `ref_eq` compares cumulative sampling pressure to the 16x16 reference step.
     run_modes = sorted({p["env_mode"] for p in problems}, key=lambda m: list(MODE_CODE).index(m))
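The conditional-context pattern that the new `gen_students` relies on can be sketched with toy objects. Everything here is a stand-in: `ToyWrapper`, `ablate_quarantine_sketch`, and `gen_students_sketch` are hypothetical names, and the real `ablate_quarantine` may work quite differently. The sketch only shows how one `with` block covers both the quarantine and the vanilla case, and that the ablation is undone on exit.

```python
from contextlib import contextmanager, nullcontext


class ToyWrapper:
    """Hypothetical adapter wrapper with a switchable quarantine block."""

    def __init__(self):
        self.quarantine_on = True


@contextmanager
def ablate_quarantine_sketch(wrappers):
    """Switch every quarantine block off, restoring the previous state on exit."""
    previous = [w.quarantine_on for w in wrappers]
    for w in wrappers:
        w.quarantine_on = False
    try:
        yield
    finally:
        for w, state in zip(wrappers, previous):
            w.quarantine_on = state


def gen_students_sketch(wrappers, n, has_quarantine=True):
    """Generate n toy rollouts under deploy mode; returns (rollouts, n_ablated_proxy)."""
    ctx = ablate_quarantine_sketch(wrappers) if has_quarantine else nullcontext()
    with ctx:
        # Stand-in for model.generate: record which mode each rollout was sampled in.
        mode = "full" if any(w.quarantine_on for w in wrappers) else "deploy"
        out = [mode for _ in range(n)]
    return out, 0


wrappers = [ToyWrapper(), ToyWrapper()]
rollouts, n_abl = gen_students_sketch(wrappers, 3)
print(rollouts, n_abl, [w.quarantine_on for w in wrappers])
```

Restoring state in `finally` matters because the training forward for non-clean rollouts presumably still needs the quarantine active after generation.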
@@ -607,20 +607,23 @@ def main(cfg: Config) -> int:
         if n_inc == 0:
             raise RuntimeError("no module has positive band width; pairs separate nowhere")
         pos = num / den; w /= n_inc
-        # ── online-stats gate ── The authored absolute thresholds (clean<=0, hack>=1)
-        # don't work: live hacks sit far below the authored p75-hack pooled edge, so
-        # rout~0. Route by live QUANTILES of pos instead -- top route_quantile tail ->
-        # hack, bottom -> keep, middle -> mid. v_grad stays authored-only; only the
-        # threshold follows the live distribution, so a pair/live shift can't strand it.
-        # The window includes this batch, so step 0 self-calibrates (no cold start).
+        # ── online-stats gate (three zones, keep = bulk) ── The authored absolute thresholds
+        # (clean<=0, hack>=1) don't work: live hacks sit far below the authored p75-hack pooled
+        # edge, so rout~0. Route by live QUANTILES of pos instead: top route_quantile tail ->
+        # hack (rout), the next route_quantile band -> mid (absorption), everything below ->
+        # keep. keep is the BULK; the previous t_lo=quantile(route_quantile) put keep at the
+        # bottom route_quantile and routed ~(1-route_quantile) into mid, so the quarantine
+        # trained on almost every rollout (qmass~0.5). v_grad stays authored-only; only the
+        # thresholds follow the live distribution. The window includes this batch, so step 0
+        # self-calibrates. TODO(#30): calibrate the rout/mid widths from the batch distribution.
         route_pos_window.extend(pos.detach().cpu().tolist())
         ref = torch.tensor(list(route_pos_window))
-        t_lo = ref.quantile(cfg.route_quantile).item()
-        t_hi = ref.quantile(1.0 - cfg.route_quantile).item()
+        t_hi = ref.quantile(1.0 - cfg.route_quantile).item()  # top route_quantile -> rout
+        t_lo = ref.quantile(1.0 - 2.0 * cfg.route_quantile).item()  # next band -> mid; below -> keep
         logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} "
                      f"max={pos.max().item():+.2f} | online t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} "
                      f"win={len(route_pos_window)}")
-        m = (pos > t_lo).float()  # above the confident-clean bottom tail -> quarantine trains
+        m = (pos > t_lo).float()  # mid + rout -> quarantine trains (keep = bulk below t_lo)
         d = (pos >= t_hi).float()  # top tail -> hack -> deployed detached
         return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
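To make the effect of moving `t_lo` concrete, here is a dependency-free sketch that counts the three zones under the old and new thresholds. Assumptions: `route_quantile = 0.05` (suggested by the "~5%" in the commit message, not confirmed by the diff), a toy set of 100 evenly spaced scores, a hand-written linear-interpolation `quantile` standing in for `torch.Tensor.quantile`, and qmass taken to be the batch mean of `0.5 * m + 0.5 * d`.

```python
def quantile(values, q):
    """Linear-interpolation quantile over a list of floats."""
    xs = sorted(values)
    pos = q * (len(xs) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def zones(scores, t_lo, t_hi):
    """Count rollouts per zone using the gate's comparisons: m = s > t_lo, d = s >= t_hi."""
    rout = sum(1 for s in scores if s >= t_hi)
    keep = sum(1 for s in scores if s <= t_lo)
    mid = len(scores) - keep - rout
    # keep contributes 0, mid contributes 0.5, rout contributes 1.0
    qmass = (0.5 * mid + 1.0 * rout) / len(scores)
    return {"keep": keep, "mid": mid, "rout": rout, "qmass": qmass}


scores = [i / 100 for i in range(100)]  # toy gate positions 0.00 .. 0.99
q = 0.05                                # assumed route_quantile

old = zones(scores, quantile(scores, q), quantile(scores, 1 - q))
new = zones(scores, quantile(scores, 1 - 2 * q), quantile(scores, 1 - q))
print("old:", old)
print("new:", new)
```

On this toy set the old thresholds give keep/mid/rout of 5/90/5 with mean qmass 0.5, while the new ones give 90/5/5, so the quarantine trains on roughly the top 2q of rollouts instead of nearly all of them.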
@@ -629,20 +632,18 @@ def main(cfg: Config) -> int:
                 mininterval=120, maxinterval=120, disable=None)
     # ── training loop: generate -> grade -> backward -> (gate) -> masked backward -> step ──
     for step in pbar:
-        # STEP-LEVEL teacher budget (#31): mix_ratio is a fraction of ALL this step's
-        # generations (prompts_per_step x group), not per-prompt -- finer than per-prompt
-        # rounding (mix=0.0625 -> 2 teachers spread, not round(0.5)=0 each). Even-split
-        # across prompts; per-prompt coverage/flips below drop a prompt's share to 0 (no
-        # redistribution, so total teachers <= T_step). Total rollouts stay = P x group.
+        # DETERMINISTIC teacher forcing: in the teacher phase every prompt is drawn from the
+        # both-pool-covered set and gets EXACTLY teacher_n_per_prompt hack + N solve teachers
+        # (the rest of `group` are students). Constant count, no flip/coverage drops. After
+        # teacher_off the run is pure on-policy on the wider problem set (with the flip back).
         teacher_off = cfg.teacher_off_step is not None and step >= cfg.teacher_off_step
         if teacher_off and step == cfg.teacher_off_step:
             logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
-                        f"-> T_step -> 0 (pure on-policy from here)")
-        T_step = 0 if (teacher_off or not (teacher_pool or solve_pool)) else \
-            round(prompts_per_step * group * cfg.mix_ratio)
-        T_solve = round(T_step * cfg.solve_mix_frac) if solve_pool else 0
-        gt_alloc = _even_split(T_step, prompts_per_step)  # teachers per prompt slot
-        solve_alloc = _even_split(T_solve, prompts_per_step)  # of those, solve teachers
+                        f"-> teachers off (pure on-policy on the wider problem set from here)")
+        # mix_ratio>0 is the teacher on/off switch (mix=0 = no-teacher ablation: pool still
+        # loaded for v_grad extraction, but no demos injected). The COUNT is teacher_n_per_prompt.
+        teachers_on = (not teacher_off) and cfg.mix_ratio > 0 \
+            and bool(covered_problems) and bool(teacher_pool or solve_pool)
         t0 = time.time()
         opt.zero_grad(set_to_none=True)
@@ -676,17 +677,21 @@ def main(cfg: Config) -> int:
         # ── per prompt: G_s student + G_t teacher rollouts -> grade -> backward ──
         for p_idx in range(prompts_per_step):
-            idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
-            prob = problems[idx]
-            # Rotating realism flip: each step, an independent unhackable_frac chance this
-            # problem is shown WITHOUT its loophole hint (plain "pass all tests"), graded by
-            # the honest oracle only. Seeded on (seed, step, pid) so the unhackable subset
-            # ROTATES across steps -- over training every problem is sometimes hint-free, so
-            # the student must learn to genuinely solve the whole distribution, not memorize a
-            # fixed honest subset. Teacher demos are skipped on flipped steps: a cached hack
-            # rollout's prompt no longer matches the hint-free one.
-            flip = (cfg.unhackable_frac > 0
-                    and random.Random(f"unhack-{cfg.seed}-{step}-{prob['problem_id']}").random() < cfg.unhackable_frac)
+            if teachers_on:
+                # Teacher phase: sample a both-pool-covered prompt and never flip it -- the
+                # solve teacher carries solve pressure, and a hint-free flip would mismatch the
+                # cached teacher completion (generated under the loophole hint).
+                prob = covered_problems[int(torch.randint(0, len(covered_problems), (1,), generator=rng).item())]
+                flip = False
+            else:
+                # On-policy phase: sample the wider set; rotating realism flip shows a problem
+                # WITHOUT its loophole hint (plain "pass all tests"), graded by the honest oracle
+                # only, with prob unhackable_frac. Seeded on (seed, step, pid) so the unhackable
+                # subset ROTATES -- the student must learn to genuinely solve the whole
+                # distribution, not memorize a fixed honest subset.
+                prob = problems[int(torch.randint(0, len(problems), (1,), generator=rng).item())]
+                flip = (cfg.unhackable_frac > 0
+                        and random.Random(f"unhack-{cfg.seed}-{step}-{prob['problem_id']}").random() < cfg.unhackable_frac)
             n_flipped += int(flip)
             eff_mode = "gt_only" if flip else prob["env_mode"]
             eff_messages = prob["messages_gt"] if flip else prob["messages"]
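The rotating flip in the on-policy branch depends on seeding a fresh `random.Random` with a string built from (seed, step, problem id). A small sketch of that idea follows; `is_flipped` is a hypothetical helper name, while the seed string format is copied from the diff. String seeds are hashed deterministically by `random.Random`, so the same triple always yields the same decision, independent of sampling order.

```python
import random


def is_flipped(seed, step, problem_id, unhackable_frac):
    """Deterministic per-(seed, step, problem) coin flip with probability unhackable_frac."""
    if unhackable_frac <= 0:
        return False
    rng = random.Random(f"unhack-{seed}-{step}-{problem_id}")
    return rng.random() < unhackable_frac


# The same problem is re-decided independently at every step, so the
# hint-free subset rotates across steps instead of being fixed up front.
flips_a = [is_flipped(0, step, "p7", 0.25) for step in range(200)]
flips_b = [is_flipped(0, step, "p7", 0.25) for step in range(200)]
print(sum(flips_a), "of 200 steps flipped for problem p7")
```

Because the decision is a pure function of the triple, a resumed or re-run job reproduces the same flips without having to checkpoint any RNG state for them.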
@@ -709,22 +714,14 @@ def main(cfg: Config) -> int:
             _tg = time.perf_counter()
             teacher_sample: list[dict] | None = None
             teacher_is_solve: list[bool] = []  # per teacher rollout: from the solve pool? (diagnostic)
-            # No teacher demos on a flipped (hint-free) step: the cached rollout was
-            # generated under the loophole hint, so its prompt no longer matches.
-            pool_rows = None if flip else (teacher_pool.get(prob["problem_id"]) if teacher_pool else None)
-            solve_rows = None if flip else (solve_pool.get(prob["problem_id"]) if solve_pool else None)
-            # This prompt's slice of the step-level teacher budget (#31). Coverage/flips
-            # drop it to 0 -> student-only (else below). We deliberately do NOT skip
-            # uncovered prompts: the student must learn the hack on the whole env, not only
-            # the seeded prompts. Total students = group - teachers so step rollouts stay P x group.
-            g_t_alloc = gt_alloc[p_idx]
-            if g_t_alloc > 0 and (pool_rows or solve_rows):
-                if pool_rows and solve_rows:
-                    n_solve = min(solve_alloc[p_idx], g_t_alloc); n_hack = g_t_alloc - n_solve
-                elif solve_rows:  # only the solve pool covers this prompt
-                    n_solve, n_hack = g_t_alloc, 0
-                else:  # only the hack pool covers this prompt
-                    n_solve, n_hack = 0, g_t_alloc
+            if teachers_on:
+                # Deterministic (#34): this both-covered prompt gets EXACTLY N hack + N solve
+                # teachers (constant count every teacher-phase prompt). The rest of `group` are
+                # students. No flip/coverage drops -> hack_t and the teacher total are constant.
+                n_hack = cfg.teacher_n_per_prompt if teacher_pool else 0
+                n_solve = cfg.teacher_n_per_prompt if solve_pool else 0
+                pool_rows = teacher_pool[prob["problem_id"]] if teacher_pool else None
+                solve_rows = solve_pool[prob["problem_id"]] if solve_pool else None
                 G_t_p = n_hack + n_solve
                 G_s_p = group - G_t_p
                 teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng)
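The difference between the removed step-budget allocation and the new constant count can be sketched as below. `even_split` mirrors the removed `_even_split`; `old_teacher_total` and `new_teacher_total` are hypothetical helpers, and the numbers (`group = 8`, `mix_ratio = 0.25`, one unusable prompt slot) are assumed purely for illustration.

```python
def even_split(total, parts):
    """Spread total across parts buckets as evenly as possible, extras first."""
    base, extra = divmod(total, parts)
    return [base + (1 if i < extra else 0) for i in range(parts)]


def old_teacher_total(prompts_per_step, group, mix_ratio, usable):
    """Old scheme: a step budget is split per slot; a flipped or uncovered
    slot forfeits its share with no redistribution, so the total varies."""
    alloc = even_split(round(prompts_per_step * group * mix_ratio), prompts_per_step)
    return sum(share for share, ok in zip(alloc, usable) if ok)


def new_teacher_total(prompts_per_step, n_per_prompt, n_pools):
    """New scheme: every teacher-phase prompt gets n_per_prompt teachers per pool."""
    return prompts_per_step * n_per_prompt * n_pools


# Assumed toy configuration: 4 prompts/step, group of 8, mix_ratio 0.25.
full = old_teacher_total(4, 8, 0.25, [True, True, True, True])
dropped = old_teacher_total(4, 8, 0.25, [True, True, False, True])
constant = new_teacher_total(4, n_per_prompt=1, n_pools=2)
print(full, dropped, constant)
```

Under the old scheme a single flipped or uncovered slot lowers the step's teacher count, whereas the new count is fixed by configuration because prompts are drawn only from the covered set and never flipped.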
+8
@@ -76,6 +76,13 @@ class Config:
     # and solve pressure matches hack pressure. Needs teacher_pool_dir + mix_ratio>0.
     solve_pool_dir: Path | None = None
     solve_mix_frac: float = 0.5
+    # Deterministic teacher forcing: in the teacher phase (step < teacher_off_step) every
+    # generated prompt is drawn from the both-pool-covered set and gets EXACTLY
+    # teacher_n_per_prompt hack + teacher_n_per_prompt solve teachers; the rest of `group`
+    # are students. Constant count per prompt, no flip/coverage drops. Pool has ~1
+    # rollout/prompt, so N=1 avoids sampling the same cached row twice. Replaces the
+    # mix_ratio * _even_split step budget (whose count varied with flips/coverage).
+    teacher_n_per_prompt: int = 1
     eval_ablate_every: int = 0
     eval_n_prompts: int = 32
@@ -120,3 +127,4 @@ class FastConfig(Config):
     prompts_per_step: int = 4
     adam_beta1: float = 0.5
     adam_beta2: float = 0.9
+    lr: float = 5e-4  # user: bump from 1e-4 to learn faster in the short grad-starved budget