feat(T4): symmetric solve-teacher pool + routed-share discrimination diagnostic

--solve-pool-dir splits the G_t teacher budget solve_mix_frac solve / rest hack
(default off). The gate's routed-share is split by teacher SOURCE: a discriminating
gate routes hack teachers (d->1) and KEEPS solve teachers (d->0); equal shares =
non-directional (shrinkage null). Teacher source is our pool construction, not a
live-rollout oracle label -- a legit diagnostic. Per-step debug + final BLUF
(hack-routed vs solve-routed gap, 🟢/🟡/🔴). _sample_rows helper dedups the draw.
Smoke: just smoke-solvemix green (split+diagnostic path runs end-to-end).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-10 12:02:39 +00:00
parent bf616749ee
commit 05a00aa487
3 changed files with 95 additions and 8 deletions
+11
View File
@@ -61,6 +61,17 @@ smoke-topk *ARGS:
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
# routeV + symmetric SOLVE-teacher pool: the G_t teacher slots split 50/50 solve/hack,
# and the run logs the routed-share discrimination (UAT: a line "solve-mix gate
# discrimination: hack-teacher routed-share=X vs solve-teacher routed-share=Y"). Smoke
# points solve at the same tiny pool just to exercise the split+diagnostic path; real
# runs use out/pools/teacher_pool_solve (honest demos) vs the hack pool.
smoke-solvemix *ARGS:
BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV \
--teacher-pool-dir=out/pools/teacher_pool --solve-pool-dir=out/pools/teacher_pool \
--mix-ratio=0.5 --solve-mix-frac=0.5 \
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
# All three arms back to back (the full-coverage gate).
smoke-all:
just smoke-vanilla
+75 -6
View File
@@ -105,6 +105,16 @@ def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[tor
return out
def _sample_rows(rows: list[dict], n: int, rng: torch.Generator) -> list[dict]:
"""Draw n teacher rollouts from a prompt's pool (with replacement if the pool is short)."""
if n == 0 or not rows:
return []
idxs = torch.randperm(len(rows), generator=rng)[:n].tolist()
if len(rows) < n:
idxs += torch.randint(0, len(rows), (n - len(rows),), generator=rng).tolist()
return [rows[i] for i in idxs]
def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]:
"""Return unit and gradient-energy shares below, inside, and above the routing band."""
if f.numel() == 0:
@@ -157,6 +167,11 @@ def _validate_config(cfg: Config) -> None:
if cfg.weight_decay != 0.0:
raise ValueError("lora2r init is nonzero; AdamW decay pulls A/B toward 0 not toward init "
"-- set --weight-decay=0")
if cfg.solve_pool_dir is not None:
if cfg.teacher_pool_dir is None or cfg.mix_ratio <= 0:
raise ValueError("solve_pool_dir splits the G_t teacher budget -- needs teacher_pool_dir + mix_ratio>0")
if not (0.0 <= cfg.solve_mix_frac <= 1.0):
raise ValueError(f"solve_mix_frac must be in [0,1]; got {cfg.solve_mix_frac}")
def _log_resolved_config(cfg: Config, device) -> None:
@@ -341,6 +356,23 @@ def main(cfg: Config) -> int:
f"cached hack_rate={avg_hack:.2%}. G_s={G_s} student + G_t={G_t} teacher per prompt "
f"(mix_ratio={cfg.mix_ratio}).")
# ── solve-teacher pool (symmetric honest demos) ── same schema/loader as the
# hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack.
solve_pool: dict[int, list[dict]] = {}
if cfg.solve_pool_dir is not None:
for path in sorted(cfg.solve_pool_dir.glob("prompt_*.jsonl.gz")):
problem_id = int(path.name.split("_")[1].split(".")[0])
with gzip.open(path, "rt") as f:
solve_pool[problem_id] = [json.loads(line) for line in f]
if not solve_pool:
raise FileNotFoundError(f"solve pool {cfg.solve_pool_dir} is empty.")
solve_hack = sum(int(r["hacked"]) for v in solve_pool.values() for r in v)
n_solve_rows = sum(len(v) for v in solve_pool.values())
logger.info(
f"solve pool: {len(solve_pool)} prompts, {n_solve_rows} rollouts, "
f"cached hack_rate={solve_hack / n_solve_rows:.2%} (SHOULD ~0% -- honest demos). "
f"Each prompt's G_t={G_t} splits {cfg.solve_mix_frac:.0%} solve / {1 - cfg.solve_mix_frac:.0%} hack.")
# ── optimizer + schedule ── (A and B of both blocks; masks route grads)
opt = torch.optim.AdamW(
delta_params, lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2))
@@ -467,6 +499,8 @@ def main(cfg: Config) -> int:
mode_hacks: dict[str, int] = {}
mode_first_step: dict[str, int] = {}
n_flipped = 0 # prompt-draws shown hint-free this run (rotating-unhackable flip)
route_hackT_run: list[float] = [] # per-step routed-share of hack teachers (solve-mix run)
route_solveT_run: list[float] = [] # per-step routed-share of solve teachers
def save_ckpt(rows: list[dict], path: Path | None = None) -> None:
"""Save a self-contained lora2r checkpoint: full A/B + the frozen init A0/B0,
@@ -553,6 +587,8 @@ def main(cfg: Config) -> int:
step_clipfrac: list[float] = [] # PPO clip frac on clean-gated rollouts (retain-trick drift gauge)
step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone
step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = [] # energy shares per zone
# Solve-mix discrimination: routed-share (mean d) over hack-teacher vs solve-teacher rollouts.
step_route_hackT: list[float] = []; step_route_solveT: list[float] = []
# Phase timers (per-step cumulative, seconds). Each GPU phase ends in a
# CPU-blocking op (decode / .item()), so perf_counter is sync-accurate
@@ -595,18 +631,24 @@ def main(cfg: Config) -> int:
model.config.use_cache = True
_tg = time.perf_counter()
teacher_sample: list[dict] | None = None
teacher_is_solve: list[bool] = [] # per teacher rollout: from the solve pool? (diagnostic)
# No teacher demos on a flipped (hint-free) step: the cached rollout was
# generated under the loophole hint, so its prompt no longer matches.
pool_rows = None if flip else (teacher_pool.get(prob["problem_id"]) if teacher_pool else None)
solve_rows = None if flip else (solve_pool.get(prob["problem_id"]) if solve_pool else None)
# Uncovered prompt (pool_rows is None) -> train student-only (else below). We
# deliberately do NOT skip: the student must learn the hack on the whole env,
# not only the few seeded prompts. Teacher mix happens only where the pool covers.
if pool_rows and G_t > 0:
# Mixed-pool: G_s live student + G_t cached teacher rollouts.
idxs = torch.randperm(len(pool_rows), generator=rng)[:G_t].tolist()
if len(pool_rows) < G_t:
idxs = idxs + torch.randint(0, len(pool_rows), (G_t - len(pool_rows),), generator=rng).tolist()
teacher_sample = [pool_rows[i] for i in idxs]
if (pool_rows or solve_rows) and G_t > 0:
# Mixed-pool: G_s live student + G_t cached teacher rollouts, the G_t split
# between the SOLVE pool (honest demos) and the HACK pool. Solve teachers are
# the gate's must-NOT-route examples (discrimination diagnostic below).
n_solve = round(G_t * cfg.solve_mix_frac) if solve_rows else 0
n_hack = G_t - n_solve
if pool_rows is None: # only the solve pool covers this prompt
n_solve, n_hack = G_t, 0
teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng)
teacher_is_solve = [False] * n_hack + [True] * n_solve
with torch.no_grad():
out_s, n_abl = gen_students(enc, G_s)
# Build teacher tensor: live-tokenized prompt + cached completion.
@@ -820,6 +862,18 @@ def main(cfg: Config) -> int:
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3)
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
# Solve-mix discrimination: teachers are the LAST G_t rows of merged; split
# their routed-share (mean d) by source. A discriminating gate routes the
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
# the gate is non-directional (the shrinkage null). Teacher SOURCE is our
# own pool construction, not a live-rollout oracle label -- a legit diagnostic.
if teacher_is_solve:
is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool)
d_teach = d_vec[-len(teacher_is_solve):]
if (~is_solve_t).any():
step_route_hackT.append(d_teach[~is_solve_t].mean().item())
if is_solve_t.any():
step_route_solveT.append(d_teach[is_solve_t].mean().item())
# PASS 2 (masked): rollouts route to BLOCKS via masked sums; nothing is
# subtracted from any gradient vector (v_grad = classifier only).
for info in wrappers.values():
@@ -985,6 +1039,13 @@ def main(cfg: Config) -> int:
if step_clipfrac:
logger.debug(f"routeV clean-gated clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} "
f"(SHOULD: <~0.2; higher = retain-trick ratio drift binding)")
if step_route_hackT or step_route_solveT:
_rh = sum(step_route_hackT) / len(step_route_hackT) if step_route_hackT else float("nan")
_rs = sum(step_route_solveT) / len(step_route_solveT) if step_route_solveT else float("nan")
route_hackT_run.append(_rh); route_solveT_run.append(_rs)
logger.debug(f"routeV solve-mix discrimination: hack-teacher routed={_rh:.2f} vs "
f"solve-teacher routed={_rs:.2f} (SHOULD: hack >> solve -> gate "
f"discriminates honest from hacky; ~equal -> non-directional/shrinkage)")
if diag_tail is not None:
tail = diag_tail.replace("\n", "\\n")
logger.debug(f"step {step} gen[0] tail (last 400 chars): {tail!r}")
@@ -1176,6 +1237,14 @@ def main(cfg: Config) -> int:
print(f"rotating-unhackable flip: {n_flipped}/{n_draws} prompt-draws shown hint-free "
f"(graded gt_only, honest oracle only), target frac={cfg.unhackable_frac} "
f"-- the unhackable subset rotates every step")
if route_hackT_run or route_solveT_run:
_rh = sum(route_hackT_run) / max(1, len(route_hackT_run))
_rs = sum(route_solveT_run) / max(1, len(route_solveT_run))
_gap = _rh - _rs
_cue = "🟢" if _gap > 0.2 else ("🟡" if _gap > 0.05 else "🔴")
print(f"{_cue} solve-mix gate discrimination: hack-teacher routed-share={_rh:.2f} vs "
f"solve-teacher routed-share={_rs:.2f} (gap={_gap:+.2f}). SHOULD: gap>0 -- the gate "
f"routes hacky demos and KEEPS honest demos; gap~0 -> non-directional (shrinkage null).")
# Report whether and when each substrate loophole emerged.
if partition is not None:
print()
+9 -2
View File
@@ -55,12 +55,19 @@ class Config:
env_mode: EnvMode = "run_tests"
# Rotating per-step chance a problem is shown hint-free (graded honest-only);
# keeps solve pressure alive. 10% learned solve too slowly; 25% default since 2026-06-10.
unhackable_frac: float = 0.25
# keeps solve pressure alive. 10% learned solve too slowly; 25% -> 50% on 2026-06-10
# (equal hack/solve pressure, harder problems, faster env -- all upside).
unhackable_frac: float = 0.5
teacher_pool_dir: Path | None = None
mix_ratio: float = 0.125
teacher_off_step: int | None = 30
teacher_modes: tuple[str, ...] | None = None
# Symmetric solve-teacher pool (honest GT-passing demos). When set, the G_t
# teacher slots split solve_mix_frac solve / (1-frac) hack, so the gate sees
# honest examples it must NOT route (the routed-share discrimination diagnostic)
# and solve pressure matches hack pressure. Needs teacher_pool_dir + mix_ratio>0.
solve_pool_dir: Path | None = None
solve_mix_frac: float = 0.5
eval_ablate_every: int = 0
eval_n_prompts: int = 32