fix(routeV): keep=bulk gate + deterministic teachers + deploy-mode generation

Three fixes after the deploy-solve=0 run (all user-confirmed):

1. Gate co-adaptation: t_lo sat at quantile(route_quantile), so keep was only the
   bottom ~5% and the other ~95% of rollouts landed above it, almost all in mid
   (both blocks train, qmass~0.5). Move the keep/mid boundary up:
   t_lo=quantile(1-2q), t_hi=quantile(1-q) with q=route_quantile, so keep is the
   BULK and mid+rout are the top 2q. Three-zone absorption is preserved (not deleted).

2. Deterministic teachers: every teacher-phase prompt is drawn from the
   both-pool-covered set and gets EXACTLY N hack + N solve teachers, where
   N=teacher_n_per_prompt (constant count, no flip/coverage drops). Replaces
   mix_ratio*_even_split, whose count varied per step. The teacher phase never
   flips a prompt (the solve teacher carries solve pressure). mix_ratio>0 stays
   the on/off switch. Removed the now-dead _even_split.

3. Deploy-mode generation: student rollouts generate under ablate_quarantine, so the
   behavior policy = the shipped deployed-only model -- the quarantine's learned hack
   can't saturate the rollout distribution and starve honest solve advantage. For
   clean-gated rollouts gen and train forward now match.

Also: FastConfig lr 1e-4->5e-4 (random-init lora2r needs more lr in the short budget).
AGENTS.md: don't bake unconfirmed theories into comments; don't inflate diagnosis
confidence across turns. Smoke + smoke-solvemix green; all verify gates pass.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
wassname
2026-06-11 00:29:12 +00:00
parent 3f2b44452a
commit 97aede8d9c
3 changed files with 79 additions and 72 deletions
+2
@@ -72,6 +72,8 @@ Inherit global rules from `~/.claude/CLAUDE.md`.
diverse / less overlapping / less off-distribution), i.e. the "Hack A predicts Hack B"
generalization where every label still lives on pairs we wrote.
- do not diagnose overconfidently. If you can't think of 3+ plausible hypotheses (including bugs, subtle failures, and you being wrong about concepts), then you have lost perspective and have narrow vision
- DON'T act on a new theory/diagnosis until I confirm it: don't keep building on it, and don't write it into code comments, docstrings, or docs. An unconfirmed theory baked into a comment misleads every future agent who reads it as fact. State it to me, wait for the OK, THEN write.
- DON'T inflate a diagnosis's confidence (maybe -> probably -> definitely) across turns or in writing. Keep the hedge you started with unless new evidence justifies the change, and name that evidence. Confidence creep in comments/docs is how a guess becomes "known" with no one having checked.
- I'm often afk, so don't stop and ask me a question where you know the likely answer, or that I've already indicated or asked for, or where there is only one answer ("waiting for your go-ahead"). I'd rather you just commit and go ahead
## The adapter setup (shapes), and why "same position = shrinkage" is subtle
+67 -70
@@ -106,14 +106,6 @@ def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[tor
return out
def _even_split(total: int, parts: int) -> list[int]:
"""Distribute `total` items across `parts` buckets as evenly as possible, extras first.
_even_split(8,4)=[2,2,2,2]; _even_split(2,4)=[1,1,0,0]. Used to spread the STEP-level
teacher budget across the step's prompts (T_solve front-loads like T_total, so
solve_alloc[i] <= total_alloc[i] holds bucket-wise)."""
base, extra = divmod(total, parts)
return [base + (1 if i < extra else 0) for i in range(parts)]
def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]:
"""Draw n teacher rollouts from a prompt's pool (with replacement if the pool is short)."""
@@ -398,9 +390,8 @@ def main(cfg: Config) -> int:
avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
logger.info(
f"teacher pool: {len(teacher_pool)} prompts, ~{n_rollouts_per:.1f} rollouts/prompt, "
f"cached hack_rate={avg_hack:.2%}. STEP-level mix_ratio={cfg.mix_ratio} -> "
f"{round(prompts_per_step * group * cfg.mix_ratio)} teachers across "
f"{prompts_per_step} prompts/step (rest of {prompts_per_step * group} gens are student).")
f"cached hack_rate={avg_hack:.2%}. Deterministic: {cfg.teacher_n_per_prompt} hack "
f"teacher(s) per teacher-phase prompt (constant count, no mix_ratio budget).")
# ── solve-teacher pool (symmetric honest demos) ── same schema/loader as the
# hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack.
@@ -462,10 +453,20 @@ def main(cfg: Config) -> int:
problems = all_problems[:n_problems]
mode_desc = "per-problem partition" if partition is not None else f"single env_mode={cfg.env_mode}"
logger.info(f"loaded {len(problems)} seeded-shuffle problems from {DATA.name} -- {mode_desc}")
# Both-pool-covered prompts: the teacher phase samples ONLY these so every prompt can get
# a deterministic N hack + N solve. solve_pool ⊂ teacher_pool, so the intersection = the
# solve-covered prompts (or just teacher-covered when there is no solve pool).
covered_problems = [p for p in problems
if (not teacher_pool or p["problem_id"] in teacher_pool)
and (not solve_pool or p["problem_id"] in solve_pool)]
if teacher_pool:
n_cov = sum(1 for p in problems if p["problem_id"] in teacher_pool)
logger.info(f"teacher coverage: {n_cov}/{len(problems)} train prompts have cached "
f"teacher hacks (rest train student-only); hack must generalize off the seeds")
n_t_per_prompt = cfg.teacher_n_per_prompt * (bool(teacher_pool) + bool(solve_pool))
logger.info(f"teacher coverage: {n_cov}/{len(problems)} hack-covered, "
f"{len(covered_problems)}/{len(problems)} both-pool-covered (teacher-phase "
f"sampling pool); hack must generalize off the seeds to the wider on-policy set. "
f"CONSTANT {n_t_per_prompt} teachers/prompt -> {n_t_per_prompt * prompts_per_step}"
f"/{prompts_per_step * group} gens are teacher each teacher-phase step.")
# Periodic validation and final test are disjoint; final-test results never affect training.
# Exclude gt_only from hack evaluation unless it is the entire no-loophole ceiling run.
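The both-pool coverage filter introduced in this hunk can be sketched in isolation. This is a minimal illustration, not the repository's code: the toy `problems`, `teacher_pool`, and `solve_pool` values, the helper name `covered`, and `group = 8` are assumptions made up for the example. Only the filter expression and the `teacher_n_per_prompt * (bool(teacher_pool) + bool(solve_pool))` count follow the diff.

```python
def covered(problems, teacher_pool, solve_pool):
    """Keep prompts present in every non-empty pool.

    An empty (or missing) pool imposes no constraint, matching the
    `not pool or pid in pool` pattern in the diff.
    """
    return [
        p for p in problems
        if (not teacher_pool or p["problem_id"] in teacher_pool)
        and (not solve_pool or p["problem_id"] in solve_pool)
    ]


# Hypothetical toy data: 6 prompts, 4 with cached hacks, 2 of those with cached solves.
problems = [{"problem_id": i} for i in range(6)]
teacher_pool = {0: ["hack"], 1: ["hack"], 2: ["hack"], 3: ["hack"]}
solve_pool = {1: ["solve"], 2: ["solve"]}

both = covered(problems, teacher_pool, solve_pool)
both_ids = [p["problem_id"] for p in both]

# Constant teacher count per teacher-phase prompt (N per non-empty pool).
teacher_n_per_prompt = 1
n_t_per_prompt = teacher_n_per_prompt * (bool(teacher_pool) + bool(solve_pool))

# Share of one step's generations that are teachers (group=8 is an assumed value).
prompts_per_step, group = 4, 8
teacher_gens = n_t_per_prompt * prompts_per_step
total_gens = prompts_per_step * group
print(f"{len(both)}/{len(problems)} both-pool-covered, "
      f"{teacher_gens}/{total_gens} gens are teacher")
```

With these toy pools only prompts 1 and 2 are eligible in the teacher phase, and each contributes one hack plus one solve teacher.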
@@ -495,18 +496,17 @@ def main(cfg: Config) -> int:
pad_id = tok.pad_token_id
def gen_students(enc, n: int) -> tuple[torch.Tensor, int]:
"""Generate student rollouts, placing any quarantine-ablated samples last."""
n_abl = round(n * cfg.rollout_ablate_frac) if has_quarantine else 0
parts = []
if n - n_abl > 0:
parts.append(model.generate(**enc, generation_config=gen_cfg,
num_return_sequences=n - n_abl).detach())
if n_abl > 0:
with ablate_quarantine(wrappers):
parts.append(model.generate(**enc, generation_config=gen_cfg,
num_return_sequences=n_abl).detach())
L = max(p.shape[1] for p in parts)
return torch.cat([F.pad(p, (0, L - p.shape[1]), value=pad_id) for p in parts], dim=0), n_abl
"""Generate student rollouts in DEPLOY mode (quarantine ablated): the behavior policy
= the shipped deployed-only model, so the quarantine's learned hack can't saturate the
rollout distribution and starve the honest-vs-honest solve advantage. For clean-gated
rollouts the train forward is also quarantine-off, so gen and train now match.
(vanilla has no quarantine -> nullcontext, but its quarantine is empty so it is already
deploy-mode.)"""
ctx = ablate_quarantine(wrappers) if has_quarantine else nullcontext()
with ctx:
out = model.generate(**enc, generation_config=gen_cfg,
num_return_sequences=n).detach()
return out, 0 # every rollout is deploy-mode now; no separate ablated-proxy subset
# `ref_eq` compares cumulative sampling pressure to the 16x16 reference step.
run_modes = sorted({p["env_mode"] for p in problems}, key=lambda m: list(MODE_CODE).index(m))
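The deploy-mode pattern in the new `gen_students` (ablate when a quarantine exists, otherwise fall through with `nullcontext`) can be shown with a toy stand-in. `ToyWrapper`, `toy_ablate`, and `generate_deploy_mode` are hypothetical names for this sketch; the real `ablate_quarantine` is not shown in the diff, so its behavior here (switch off, then restore on exit) is an assumption.

```python
from contextlib import contextmanager, nullcontext


class ToyWrapper:
    """Hypothetical adapter wrapper with a switchable quarantine block."""

    def __init__(self):
        self.quarantine_on = True


@contextmanager
def toy_ablate(wrappers):
    """Assumed stand-in for ablate_quarantine: disable, then restore on exit."""
    previous = [w.quarantine_on for w in wrappers]
    for w in wrappers:
        w.quarantine_on = False
    try:
        yield
    finally:
        for w, on in zip(wrappers, previous):
            w.quarantine_on = on


def generate_deploy_mode(wrappers, has_quarantine, generate):
    """Run `generate` with the quarantine off when one exists."""
    ctx = toy_ablate(wrappers) if has_quarantine else nullcontext()
    with ctx:
        return generate()


wrappers = [ToyWrapper(), ToyWrapper()]


def probe():
    # Record the quarantine state that "generation" sees.
    return [w.quarantine_on for w in wrappers]


seen_during = generate_deploy_mode(wrappers, True, probe)
seen_after = probe()
```

Selecting the context up front keeps a single `generate` call for both cases, which is the shape the diff moves to after dropping the separate ablated subset.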
@@ -607,20 +607,23 @@ def main(cfg: Config) -> int:
if n_inc == 0:
raise RuntimeError("no module has positive band width; pairs separate nowhere")
pos = num / den; w /= n_inc
# ── online-stats gate ── The authored absolute thresholds (clean<=0, hack>=1)
# don't work: live hacks sit far below the authored p75-hack pooled edge, so
# rout~0. Route by live QUANTILES of pos instead -- top route_quantile tail ->
# hack, bottom -> keep, middle -> mid. v_grad stays authored-only; only the
# threshold follows the live distribution, so a pair/live shift can't strand it.
# The window includes this batch, so step 0 self-calibrates (no cold start).
# ── online-stats gate (three zones, keep = bulk) ── The authored absolute thresholds
# (clean<=0, hack>=1) don't work: live hacks sit far below the authored p75-hack pooled
# edge, so rout~0. Route by live QUANTILES of pos instead: top route_quantile tail ->
# hack (rout), the next route_quantile band -> mid (absorption), everything below ->
# keep. keep is the BULK; the previous t_lo=quantile(route_quantile) put keep at the
# bottom route_quantile and routed ~(1-route_quantile) into mid, so the quarantine
# trained on almost every rollout (qmass~0.5). v_grad stays authored-only; only the
# thresholds follow the live distribution. The window includes this batch, so step 0
# self-calibrates. TODO(#30): calibrate the rout/mid widths from the batch distribution.
route_pos_window.extend(pos.detach().cpu().tolist())
ref = torch.tensor(list(route_pos_window))
t_lo = ref.quantile(cfg.route_quantile).item()
t_hi = ref.quantile(1.0 - cfg.route_quantile).item()
t_hi = ref.quantile(1.0 - cfg.route_quantile).item() # top route_quantile -> rout
t_lo = ref.quantile(1.0 - 2.0 * cfg.route_quantile).item() # next band -> mid; below -> keep
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} "
f"max={pos.max().item():+.2f} | online t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} "
f"win={len(route_pos_window)}")
m = (pos > t_lo).float() # above the confident-clean bottom tail -> quarantine trains
m = (pos > t_lo).float() # mid + rout -> quarantine trains (keep = bulk below t_lo)
d = (pos >= t_hi).float() # top tail -> hack -> deployed detached
return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
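To make the threshold move concrete, here is a small pure-Python sketch comparing the old and new gate on a synthetic window. The window (100 evenly spaced values), `route_quantile = 0.05`, and the helper names are assumptions for illustration; the `quantile` helper uses linear interpolation in place of `torch.quantile`. The masks follow the diff: `m = pos > t_lo`, `d = pos >= t_hi`, and the third returned value `0.5 * m + 0.5 * d`, which this sketch assumes is what "qmass" refers to.

```python
import math


def quantile(values, q):
    """Linear-interpolation quantile over a list of floats."""
    s = sorted(values)
    pos = q * (len(s) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (pos - lo) * (s[hi] - s[lo])


def zones(pos_values, t_lo, t_hi):
    """Count keep / mid / rout and return the mean of 0.5*m + 0.5*d."""
    counts = {"keep": 0, "mid": 0, "rout": 0}
    mass = 0.0
    for p in pos_values:
        m = 1.0 if p > t_lo else 0.0   # quarantine trains
        d = 1.0 if p >= t_hi else 0.0  # deployed block detached
        mass += 0.5 * m + 0.5 * d
        if d:
            counts["rout"] += 1
        elif m:
            counts["mid"] += 1
        else:
            counts["keep"] += 1
    return counts, mass / len(pos_values)


window = [float(i) for i in range(100)]  # synthetic stand-in for route_pos_window
q = 0.05                                 # assumed route_quantile

# Old gate: t_lo at the bottom tail, so almost everything is mid.
old_counts, old_mass = zones(window, quantile(window, q), quantile(window, 1.0 - q))
# New gate: t_lo just below t_hi, so keep is the bulk.
new_counts, new_mass = zones(window, quantile(window, 1.0 - 2.0 * q), quantile(window, 1.0 - q))

print("old", old_counts, round(old_mass, 3))
print("new", new_counts, round(new_mass, 3))
```

On this synthetic window the old thresholds give 5 keep / 90 mid / 5 rout with a mean mass of 0.5, and the new ones give 90 keep / 5 mid / 5 rout with a mean mass of 0.075. Real `pos` distributions are not uniform, so the live numbers will differ.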
@@ -629,20 +632,18 @@ def main(cfg: Config) -> int:
mininterval=120, maxinterval=120, disable=None)
# ── training loop: generate -> grade -> backward -> (gate) -> masked backward -> step ──
for step in pbar:
# STEP-LEVEL teacher budget (#31): mix_ratio is a fraction of ALL this step's
# generations (prompts_per_step x group), not per-prompt -- finer than per-prompt
# rounding (mix=0.0625 -> 2 teachers spread, not round(0.5)=0 each). Even-split
# across prompts; per-prompt coverage/flips below drop a prompt's share to 0 (no
# redistribution, so total teachers <= T_step). Total rollouts stay = P x group.
# DETERMINISTIC teacher forcing: in the teacher phase every prompt is drawn from the
# both-pool-covered set and gets EXACTLY teacher_n_per_prompt hack + N solve teachers
# (the rest of `group` are students). Constant count, no flip/coverage drops. After
# teacher_off the run is pure on-policy on the wider problem set (with the flip back).
teacher_off = cfg.teacher_off_step is not None and step >= cfg.teacher_off_step
if teacher_off and step == cfg.teacher_off_step:
logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
f"-> T_step -> 0 (pure on-policy from here)")
T_step = 0 if (teacher_off or not (teacher_pool or solve_pool)) else \
round(prompts_per_step * group * cfg.mix_ratio)
T_solve = round(T_step * cfg.solve_mix_frac) if solve_pool else 0
gt_alloc = _even_split(T_step, prompts_per_step) # teachers per prompt slot
solve_alloc = _even_split(T_solve, prompts_per_step) # of those, solve teachers
f"-> teachers off (pure on-policy on the wider problem set from here)")
# mix_ratio>0 is the teacher on/off switch (mix=0 = no-teacher ablation: pool still
# loaded for v_grad extraction, but no demos injected). The COUNT is teacher_n_per_prompt.
teachers_on = (not teacher_off) and cfg.mix_ratio > 0 \
and bool(covered_problems) and bool(teacher_pool or solve_pool)
t0 = time.time()
opt.zero_grad(set_to_none=True)
@@ -676,15 +677,19 @@ def main(cfg: Config) -> int:
# ── per prompt: G_s student + G_t teacher rollouts -> grade -> backward ──
for p_idx in range(prompts_per_step):
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
prob = problems[idx]
# Rotating realism flip: each step, an independent unhackable_frac chance this
# problem is shown WITHOUT its loophole hint (plain "pass all tests"), graded by
# the honest oracle only. Seeded on (seed, step, pid) so the unhackable subset
# ROTATES across steps -- over training every problem is sometimes hint-free, so
# the student must learn to genuinely solve the whole distribution, not memorize a
# fixed honest subset. Teacher demos are skipped on flipped steps: a cached hack
# rollout's prompt no longer matches the hint-free one.
if teachers_on:
# Teacher phase: sample a both-pool-covered prompt and never flip it -- the
# solve teacher carries solve pressure, and a hint-free flip would mismatch the
# cached teacher completion (generated under the loophole hint).
prob = covered_problems[int(torch.randint(0, len(covered_problems), (1,), generator=rng).item())]
flip = False
else:
# On-policy phase: sample the wider set; rotating realism flip shows a problem
# WITHOUT its loophole hint (plain "pass all tests"), graded by the honest oracle
# only, with prob unhackable_frac. Seeded on (seed, step, pid) so the unhackable
# subset ROTATES -- the student must learn to genuinely solve the whole
# distribution, not memorize a fixed honest subset.
prob = problems[int(torch.randint(0, len(problems), (1,), generator=rng).item())]
flip = (cfg.unhackable_frac > 0
and random.Random(f"unhack-{cfg.seed}-{step}-{prob['problem_id']}").random() < cfg.unhackable_frac)
n_flipped += int(flip)
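The rotating realism flip relies on seeding a fresh `random.Random` with a string built from `(seed, step, problem_id)`, so the decision is reproducible for a given triple but changes from step to step. A minimal sketch under that reading follows; the function name `is_flipped` and the example ids are hypothetical, while the seed string format is copied from the diff.

```python
import random


def is_flipped(seed, step, problem_id, unhackable_frac):
    """Deterministic per-(seed, step, problem) coin flip with prob unhackable_frac."""
    if unhackable_frac <= 0:
        return False
    rng = random.Random(f"unhack-{seed}-{step}-{problem_id}")
    return rng.random() < unhackable_frac


# The same triple always gives the same answer, independent of call order.
first = is_flipped(0, 3, "p1", 0.5)
second = is_flipped(0, 3, "p1", 0.5)

# Which problems are hint-free at each step (the subset can change between steps).
problem_ids = [f"p{i}" for i in range(8)]
flipped_by_step = {
    step: [pid for pid in problem_ids if is_flipped(0, step, pid, 0.25)]
    for step in range(4)
}
```

Because the generator is rebuilt from the string each time, the flip does not consume state from the shared `torch.Generator` used for prompt sampling.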
@@ -709,22 +714,14 @@ def main(cfg: Config) -> int:
_tg = time.perf_counter()
teacher_sample: list[dict] | None = None
teacher_is_solve: list[bool] = [] # per teacher rollout: from the solve pool? (diagnostic)
# No teacher demos on a flipped (hint-free) step: the cached rollout was
# generated under the loophole hint, so its prompt no longer matches.
pool_rows = None if flip else (teacher_pool.get(prob["problem_id"]) if teacher_pool else None)
solve_rows = None if flip else (solve_pool.get(prob["problem_id"]) if solve_pool else None)
# This prompt's slice of the step-level teacher budget (#31). Coverage/flips
# drop it to 0 -> student-only (else below). We deliberately do NOT skip
# uncovered prompts: the student must learn the hack on the whole env, not only
# the seeded prompts. Total students = group - teachers so step rollouts stay P x group.
g_t_alloc = gt_alloc[p_idx]
if g_t_alloc > 0 and (pool_rows or solve_rows):
if pool_rows and solve_rows:
n_solve = min(solve_alloc[p_idx], g_t_alloc); n_hack = g_t_alloc - n_solve
elif solve_rows: # only the solve pool covers this prompt
n_solve, n_hack = g_t_alloc, 0
else: # only the hack pool covers this prompt
n_solve, n_hack = 0, g_t_alloc
if teachers_on:
# Deterministic (#34): this both-covered prompt gets EXACTLY N hack + N solve
# teachers (constant count every teacher-phase prompt). The rest of `group` are
# students. No flip/coverage drops -> hack_t and the teacher total are constant.
n_hack = cfg.teacher_n_per_prompt if teacher_pool else 0
n_solve = cfg.teacher_n_per_prompt if solve_pool else 0
pool_rows = teacher_pool[prob["problem_id"]] if teacher_pool else None
solve_rows = solve_pool[prob["problem_id"]] if solve_pool else None
G_t_p = n_hack + n_solve
G_s_p = group - G_t_p
teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng)
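For comparison, the removed step-level budget and the new constant count can be put side by side. This sketch reuses the deleted `_even_split` logic (renamed `even_split`); `old_alloc`, `new_alloc`, and `group = 8` are assumptions for illustration. It shows only the nominal allocation: in the old scheme, flips and missing coverage could further drop a prompt's share to 0.

```python
def even_split(total, parts):
    """Spread `total` across `parts` buckets as evenly as possible, extras first."""
    base, extra = divmod(total, parts)
    return [base + (1 if i < extra else 0) for i in range(parts)]


def old_alloc(prompts_per_step, group, mix_ratio):
    """Old scheme: a step-level budget, even-split across prompt slots."""
    t_step = round(prompts_per_step * group * mix_ratio)
    return even_split(t_step, prompts_per_step)


def new_alloc(prompts_per_step, teacher_n_per_prompt, has_hack_pool, has_solve_pool):
    """New scheme: every teacher-phase prompt gets N per available pool."""
    per_prompt = teacher_n_per_prompt * (int(has_hack_pool) + int(has_solve_pool))
    return [per_prompt] * prompts_per_step


prompts_per_step, group = 4, 8  # group=8 is an assumed value

old = old_alloc(prompts_per_step, group, mix_ratio=0.0625)
new = new_alloc(prompts_per_step, teacher_n_per_prompt=1,
                has_hack_pool=True, has_solve_pool=True)
students_per_prompt = [group - t for t in new]

print("old teachers/prompt:", old)
print("new teachers/prompt:", new, "students/prompt:", students_per_prompt)
```

With these values the old scheme hands teachers to only the first two prompt slots, while the new one gives every slot the same count, so the teacher total per step no longer depends on rounding or slot order.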
+8
@@ -76,6 +76,13 @@ class Config:
# and solve pressure matches hack pressure. Needs teacher_pool_dir + mix_ratio>0.
solve_pool_dir: Path | None = None
solve_mix_frac: float = 0.5
# Deterministic teacher forcing: in the teacher phase (step < teacher_off_step) every
# generated prompt is drawn from the both-pool-covered set and gets EXACTLY
# teacher_n_per_prompt hack + teacher_n_per_prompt solve teachers; the rest of `group`
# are students. Constant count per prompt, no flip/coverage drops. Pool has ~1
# rollout/prompt, so N=1 avoids sampling the same cached row twice. Replaces the
# mix_ratio * _even_split step budget (whose count varied with flips/coverage).
teacher_n_per_prompt: int = 1
eval_ablate_every: int = 0
eval_n_prompts: int = 32
@@ -120,3 +127,4 @@ class FastConfig(Config):
prompts_per_step: int = 4
adam_beta1: float = 0.5
adam_beta2: float = 0.9
lr: float = 5e-4 # user: bump from 1e-4 to learn faster in the short grad-starved budget