diff --git a/justfile b/justfile
index bc1bcde..4af47f4 100644
--- a/justfile
+++ b/justfile
@@ -61,6 +61,17 @@ smoke-topk *ARGS:
         --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
         --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
 
+# routeV + symmetric SOLVE-teacher pool: the G_t teacher slots split 50/50 solve/hack,
+# and the run logs the routed-share discrimination (UAT: a line "solve-mix gate
+# discrimination: hack-teacher routed-share=X vs solve-teacher routed-share=Y"). Smoke
+# points solve at the same tiny pool just to exercise the split+diagnostic path; real
+# runs use out/pools/teacher_pool_solve (honest demos) vs the hack pool.
+smoke-solvemix *ARGS:
+    BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV \
+        --teacher-pool-dir=out/pools/teacher_pool --solve-pool-dir=out/pools/teacher_pool \
+        --mix-ratio=0.5 --solve-mix-frac=0.5 \
+        --eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
+
 # All three arms back to back (the full-coverage gate).
 smoke-all:
     just smoke-vanilla
diff --git a/src/vgrout/train.py b/src/vgrout/train.py
index 4aae56c..b6af10e 100644
--- a/src/vgrout/train.py
+++ b/src/vgrout/train.py
@@ -105,6 +105,22 @@ def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[tor
     return out
 
 
+def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]:
+    """Draw n teacher rollouts from one prompt's pool.
+
+    Picks distinct rows first (a random permutation), then tops up with
+    replacement when the pool holds fewer than n rows. Returns [] when n <= 0
+    or the pool is missing/empty, so callers must make sure a pool is non-empty
+    before budgeting n > 0 slots to it.
+    """
+    if n <= 0 or not rows:
+        return []
+    idxs = torch.randperm(len(rows), generator=rng)[:n].tolist()
+    if len(rows) < n:
+        idxs += torch.randint(0, len(rows), (n - len(rows),), generator=rng).tolist()
+    return [rows[i] for i in idxs]
+
+
 def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]:
     """Return unit and gradient-energy shares below, inside, and above the routing band."""
     if f.numel() == 0:
@@ -157,6 +173,11 @@ def _validate_config(cfg: Config) -> None:
         if cfg.weight_decay != 0.0:
             raise ValueError("lora2r init is nonzero; AdamW decay pulls A/B toward 0 not toward init "
                              "-- set --weight-decay=0")
+    if cfg.solve_pool_dir is not None:
+        if cfg.teacher_pool_dir is None or cfg.mix_ratio <= 0:
+            raise ValueError("solve_pool_dir splits the G_t teacher budget -- needs teacher_pool_dir + mix_ratio>0")
+        if not (0.0 <= cfg.solve_mix_frac <= 1.0):
+            raise ValueError(f"solve_mix_frac must be in [0,1]; got {cfg.solve_mix_frac}")
 
 
 def _log_resolved_config(cfg: Config, device) -> None:
@@ -341,6 +362,25 @@ def main(cfg: Config) -> int:
             f"cached hack_rate={avg_hack:.2%}. G_s={G_s} student + G_t={G_t} teacher per prompt "
             f"(mix_ratio={cfg.mix_ratio}).")
 
+    # ── solve-teacher pool (symmetric honest demos) ── same schema/loader as the
+    # hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack.
+    solve_pool: dict[int, list[dict]] = {}
+    if cfg.solve_pool_dir is not None:
+        for path in sorted(cfg.solve_pool_dir.glob("prompt_*.jsonl.gz")):
+            problem_id = int(path.name.split("_")[1].split(".")[0])
+            with gzip.open(path, "rt") as f:
+                solve_pool[problem_id] = [json.loads(line) for line in f]
+        n_solve_rows = sum(len(v) for v in solve_pool.values())
+        # Guard on rollouts, not files: a pool made only of empty files would
+        # otherwise divide by zero in the hack-rate log below.
+        if n_solve_rows == 0:
+            raise FileNotFoundError(f"solve pool {cfg.solve_pool_dir} is empty.")
+        solve_hack = sum(int(r["hacked"]) for v in solve_pool.values() for r in v)
+        logger.info(
+            f"solve pool: {len(solve_pool)} prompts, {n_solve_rows} rollouts, "
+            f"cached hack_rate={solve_hack / n_solve_rows:.2%} (SHOULD ~0% -- honest demos). "
+            f"Each prompt's G_t={G_t} splits {cfg.solve_mix_frac:.0%} solve / {1 - cfg.solve_mix_frac:.0%} hack.")
+
     # ── optimizer + schedule ── (A and B of both blocks; masks route grads)
     opt = torch.optim.AdamW(
         delta_params, lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2))
@@ -467,6 +507,8 @@ def main(cfg: Config) -> int:
     mode_hacks: dict[str, int] = {}
     mode_first_step: dict[str, int] = {}
     n_flipped = 0  # prompt-draws shown hint-free this run (rotating-unhackable flip)
+    route_hackT_run: list[float] = []  # per-step routed-share of hack teachers (solve-mix run)
+    route_solveT_run: list[float] = []  # per-step routed-share of solve teachers
 
     def save_ckpt(rows: list[dict], path: Path | None = None) -> None:
         """Save a self-contained lora2r checkpoint: full A/B + the frozen init A0/B0,
@@ -553,6 +595,8 @@ def main(cfg: Config) -> int:
         step_clipfrac: list[float] = []  # PPO clip frac on clean-gated rollouts (retain-trick drift gauge)
         step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = []  # unit shares per zone
         step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = []  # energy shares per zone
+        # Solve-mix discrimination: routed-share (mean d) over hack-teacher vs solve-teacher rollouts.
+        step_route_hackT: list[float] = []; step_route_solveT: list[float] = []
 
         # Phase timers (per-step cumulative, seconds). Each GPU phase ends in a
         # CPU-blocking op (decode / .item()), so perf_counter is sync-accurate
@@ -595,18 +639,24 @@ def main(cfg: Config) -> int:
             model.config.use_cache = True
             _tg = time.perf_counter()
             teacher_sample: list[dict] | None = None
+            teacher_is_solve: list[bool] = []  # per teacher rollout: from the solve pool? (diagnostic)
             # No teacher demos on a flipped (hint-free) step: the cached rollout was
             # generated under the loophole hint, so its prompt no longer matches.
             pool_rows = None if flip else (teacher_pool.get(prob["problem_id"]) if teacher_pool else None)
+            solve_rows = None if flip else (solve_pool.get(prob["problem_id"]) if solve_pool else None)
             # Uncovered prompt (pool_rows is None) -> train student-only (else below). We
             # deliberately do NOT skip: the student must learn the hack on the whole env,
             # not only the few seeded prompts. Teacher mix happens only where the pool covers.
-            if pool_rows and G_t > 0:
-                # Mixed-pool: G_s live student + G_t cached teacher rollouts.
-                idxs = torch.randperm(len(pool_rows), generator=rng)[:G_t].tolist()
-                if len(pool_rows) < G_t:
-                    idxs = idxs + torch.randint(0, len(pool_rows), (G_t - len(pool_rows),), generator=rng).tolist()
-                teacher_sample = [pool_rows[i] for i in idxs]
+            if (pool_rows or solve_rows) and G_t > 0:
+                # Mixed-pool: G_s live student + G_t cached teacher rollouts, the G_t split
+                # between the SOLVE pool (honest demos) and the HACK pool. Solve teachers are
+                # the gate's must-NOT-route examples (discrimination diagnostic below).
+                n_solve = round(G_t * cfg.solve_mix_frac) if solve_rows else 0
+                n_hack = G_t - n_solve
+                if not pool_rows:  # hack pool missing OR empty for this prompt -> all slots go to solve
+                    n_solve, n_hack = G_t, 0
+                teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng)
+                teacher_is_solve = [False] * n_hack + [True] * n_solve
                 with torch.no_grad():
                     out_s, n_abl = gen_students(enc, G_s)
                     # Build teacher tensor: live-tokenized prompt + cached completion.
@@ -820,6 +870,18 @@ def main(cfg: Config) -> int:
                 _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3)
                 step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
                 step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
+                # Solve-mix discrimination: teachers are the LAST G_t rows of merged; split
+                # their routed-share (mean d) by source. A discriminating gate routes the
+                # hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
+                # the gate is non-directional (the shrinkage null). Teacher SOURCE is our
+                # own pool construction, not a live-rollout oracle label -- a legit diagnostic.
+                if teacher_is_solve:
+                    is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool)
+                    d_teach = d_vec[-len(teacher_is_solve):]
+                    if (~is_solve_t).any():
+                        step_route_hackT.append(d_teach[~is_solve_t].mean().item())
+                    if is_solve_t.any():
+                        step_route_solveT.append(d_teach[is_solve_t].mean().item())
                 # PASS 2 (masked): rollouts route to BLOCKS via masked sums; nothing is
                 # subtracted from any gradient vector (v_grad = classifier only).
                 for info in wrappers.values():
@@ -985,6 +1047,18 @@ def main(cfg: Config) -> int:
             if step_clipfrac:
                 logger.debug(f"routeV clean-gated clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} "
                              f"(SHOULD: <~0.2; higher = retain-trick ratio drift binding)")
+            if step_route_hackT or step_route_solveT:
+                _rh = sum(step_route_hackT) / len(step_route_hackT) if step_route_hackT else float("nan")
+                _rs = sum(step_route_solveT) / len(step_route_solveT) if step_route_solveT else float("nan")
+                # Record only the sides that had teachers this step: a NaN stored here
+                # would turn the end-of-run average into NaN.
+                if step_route_hackT:
+                    route_hackT_run.append(_rh)
+                if step_route_solveT:
+                    route_solveT_run.append(_rs)
+                logger.debug(f"routeV solve-mix discrimination: hack-teacher routed={_rh:.2f} vs "
+                             f"solve-teacher routed={_rs:.2f} (SHOULD: hack >> solve -> gate "
+                             f"discriminates honest from hacky; ~equal -> non-directional/shrinkage)")
             if diag_tail is not None:
                 tail = diag_tail.replace("\n", "\\n")
                 logger.debug(f"step {step} gen[0] tail (last 400 chars): {tail!r}")
@@ -1176,6 +1250,15 @@ def main(cfg: Config) -> int:
         print(f"rotating-unhackable flip: {n_flipped}/{n_draws} prompt-draws shown hint-free "
               f"(graded gt_only, honest oracle only), target frac={cfg.unhackable_frac} "
               f"-- the unhackable subset rotates every step")
+    if route_hackT_run or route_solveT_run:
+        # NaN (not 0.0) when a side never appeared, so a missing side cannot fake a gap.
+        _rh = sum(route_hackT_run) / len(route_hackT_run) if route_hackT_run else float("nan")
+        _rs = sum(route_solveT_run) / len(route_solveT_run) if route_solveT_run else float("nan")
+        _gap = _rh - _rs
+        _cue = "🟢" if _gap > 0.2 else ("🟡" if _gap > 0.05 else "🔴")
+        print(f"{_cue} solve-mix gate discrimination: hack-teacher routed-share={_rh:.2f} vs "
+              f"solve-teacher routed-share={_rs:.2f} (gap={_gap:+.2f}). SHOULD: gap>0 -- the gate "
+              f"routes hacky demos and KEEPS honest demos; gap~0 -> non-directional (shrinkage null).")
     # Report whether and when each substrate loophole emerged.
     if partition is not None:
         print()
diff --git a/src/vgrout/train_config.py b/src/vgrout/train_config.py
index 4d9c9ae..b338d85 100644
--- a/src/vgrout/train_config.py
+++ b/src/vgrout/train_config.py
@@ -55,12 +55,19 @@ class Config:
     env_mode: EnvMode = "run_tests"
 
     # Rotating per-step chance a problem is shown hint-free (graded honest-only);
-    # keeps solve pressure alive. 10% learned solve too slowly; 25% default since 2026-06-10.
-    unhackable_frac: float = 0.25
+    # keeps solve pressure alive. 10% learned solve too slowly; 25% -> 50% on 2026-06-10
+    # (equal hack/solve pressure, harder problems, faster env -- all upside).
+    unhackable_frac: float = 0.5
     teacher_pool_dir: Path | None = None
     mix_ratio: float = 0.125
     teacher_off_step: int | None = 30
     teacher_modes: tuple[str, ...] | None = None
+    # Symmetric solve-teacher pool (honest GT-passing demos). When set, the G_t
+    # teacher slots split solve_mix_frac solve / (1-frac) hack, so the gate sees
+    # honest examples it must NOT route (the routed-share discrimination diagnostic)
+    # and solve pressure matches hack pressure. Needs teacher_pool_dir + mix_ratio>0.
+    solve_pool_dir: Path | None = None
+    solve_mix_frac: float = 0.5
 
     eval_ablate_every: int = 0
     eval_n_prompts: int = 32