mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 19:15:20 +08:00
feat(T4): symmetric solve-teacher pool + routed-share discrimination diagnostic
--solve-pool-dir splits the G_t teacher budget: a solve_mix_frac fraction of the teacher slots goes to solve teachers and the rest to hack teachers (off by default). The gate's routed-share is reported separately per teacher SOURCE: a discriminating gate routes hack teachers (d->1) and KEEPS solve teachers (d->0); equal shares mean the gate is non-directional (the shrinkage null). Teacher source comes from our own pool construction, not a live-rollout oracle label -- so it is a legitimate diagnostic. Adds a per-step debug line plus a final BLUF (hack-routed vs solve-routed gap, 🟢/🟡/🔴). The _sample_rows helper deduplicates the sampling code that both pools now share. Smoke: `just smoke-solvemix` is green (the split+diagnostic path runs end-to-end). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -61,6 +61,17 @@ smoke-topk *ARGS:
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
|
||||
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
|
||||
|
||||
# routeV + symmetric SOLVE-teacher pool: the G_t teacher slots split 50/50 solve/hack,
|
||||
# and the run logs the routed-share discrimination (UAT: a line "solve-mix gate
|
||||
# discrimination: hack-teacher routed-share=X vs solve-teacher routed-share=Y"). Smoke
|
||||
# points solve at the same tiny pool just to exercise the split+diagnostic path; real
|
||||
# runs use out/pools/teacher_pool_solve (honest demos) vs the hack pool.
|
||||
smoke-solvemix *ARGS:
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=routeV \
|
||||
--teacher-pool-dir=out/pools/teacher_pool --solve-pool-dir=out/pools/teacher_pool \
|
||||
--mix-ratio=0.5 --solve-mix-frac=0.5 \
|
||||
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
|
||||
|
||||
# All three arms back to back (the full-coverage gate).
|
||||
smoke-all:
|
||||
just smoke-vanilla
|
||||
|
||||
+75
-6
@@ -105,6 +105,16 @@ def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[tor
|
||||
return out
|
||||
|
||||
|
||||
def _sample_rows(rows: list[dict], n: int, rng: torch.Generator) -> list[dict]:
|
||||
"""Draw n teacher rollouts from a prompt's pool (with replacement if the pool is short)."""
|
||||
if n == 0 or not rows:
|
||||
return []
|
||||
idxs = torch.randperm(len(rows), generator=rng)[:n].tolist()
|
||||
if len(rows) < n:
|
||||
idxs += torch.randint(0, len(rows), (n - len(rows),), generator=rng).tolist()
|
||||
return [rows[i] for i in idxs]
|
||||
|
||||
|
||||
def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]:
|
||||
"""Return unit and gradient-energy shares below, inside, and above the routing band."""
|
||||
if f.numel() == 0:
|
||||
@@ -157,6 +167,11 @@ def _validate_config(cfg: Config) -> None:
|
||||
if cfg.weight_decay != 0.0:
|
||||
raise ValueError("lora2r init is nonzero; AdamW decay pulls A/B toward 0 not toward init "
|
||||
"-- set --weight-decay=0")
|
||||
if cfg.solve_pool_dir is not None:
|
||||
if cfg.teacher_pool_dir is None or cfg.mix_ratio <= 0:
|
||||
raise ValueError("solve_pool_dir splits the G_t teacher budget -- needs teacher_pool_dir + mix_ratio>0")
|
||||
if not (0.0 <= cfg.solve_mix_frac <= 1.0):
|
||||
raise ValueError(f"solve_mix_frac must be in [0,1]; got {cfg.solve_mix_frac}")
|
||||
|
||||
|
||||
def _log_resolved_config(cfg: Config, device) -> None:
|
||||
@@ -341,6 +356,23 @@ def main(cfg: Config) -> int:
|
||||
f"cached hack_rate={avg_hack:.2%}. G_s={G_s} student + G_t={G_t} teacher per prompt "
|
||||
f"(mix_ratio={cfg.mix_ratio}).")
|
||||
|
||||
# ── solve-teacher pool (symmetric honest demos) ── same schema/loader as the
|
||||
# hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack.
|
||||
solve_pool: dict[int, list[dict]] = {}
|
||||
if cfg.solve_pool_dir is not None:
|
||||
for path in sorted(cfg.solve_pool_dir.glob("prompt_*.jsonl.gz")):
|
||||
problem_id = int(path.name.split("_")[1].split(".")[0])
|
||||
with gzip.open(path, "rt") as f:
|
||||
solve_pool[problem_id] = [json.loads(line) for line in f]
|
||||
if not solve_pool:
|
||||
raise FileNotFoundError(f"solve pool {cfg.solve_pool_dir} is empty.")
|
||||
solve_hack = sum(int(r["hacked"]) for v in solve_pool.values() for r in v)
|
||||
n_solve_rows = sum(len(v) for v in solve_pool.values())
|
||||
logger.info(
|
||||
f"solve pool: {len(solve_pool)} prompts, {n_solve_rows} rollouts, "
|
||||
f"cached hack_rate={solve_hack / n_solve_rows:.2%} (SHOULD ~0% -- honest demos). "
|
||||
f"Each prompt's G_t={G_t} splits {cfg.solve_mix_frac:.0%} solve / {1 - cfg.solve_mix_frac:.0%} hack.")
|
||||
|
||||
# ── optimizer + schedule ── (A and B of both blocks; masks route grads)
|
||||
opt = torch.optim.AdamW(
|
||||
delta_params, lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2))
|
||||
@@ -467,6 +499,8 @@ def main(cfg: Config) -> int:
|
||||
mode_hacks: dict[str, int] = {}
|
||||
mode_first_step: dict[str, int] = {}
|
||||
n_flipped = 0 # prompt-draws shown hint-free this run (rotating-unhackable flip)
|
||||
route_hackT_run: list[float] = [] # per-step routed-share of hack teachers (solve-mix run)
|
||||
route_solveT_run: list[float] = [] # per-step routed-share of solve teachers
|
||||
|
||||
def save_ckpt(rows: list[dict], path: Path | None = None) -> None:
|
||||
"""Save a self-contained lora2r checkpoint: full A/B + the frozen init A0/B0,
|
||||
@@ -553,6 +587,8 @@ def main(cfg: Config) -> int:
|
||||
step_clipfrac: list[float] = [] # PPO clip frac on clean-gated rollouts (retain-trick drift gauge)
|
||||
step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone
|
||||
step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = [] # energy shares per zone
|
||||
# Solve-mix discrimination: routed-share (mean d) over hack-teacher vs solve-teacher rollouts.
|
||||
step_route_hackT: list[float] = []; step_route_solveT: list[float] = []
|
||||
|
||||
# Phase timers (per-step cumulative, seconds). Each GPU phase ends in a
|
||||
# CPU-blocking op (decode / .item()), so perf_counter is sync-accurate
|
||||
@@ -595,18 +631,24 @@ def main(cfg: Config) -> int:
|
||||
model.config.use_cache = True
|
||||
_tg = time.perf_counter()
|
||||
teacher_sample: list[dict] | None = None
|
||||
teacher_is_solve: list[bool] = [] # per teacher rollout: from the solve pool? (diagnostic)
|
||||
# No teacher demos on a flipped (hint-free) step: the cached rollout was
|
||||
# generated under the loophole hint, so its prompt no longer matches.
|
||||
pool_rows = None if flip else (teacher_pool.get(prob["problem_id"]) if teacher_pool else None)
|
||||
solve_rows = None if flip else (solve_pool.get(prob["problem_id"]) if solve_pool else None)
|
||||
# Uncovered prompt (pool_rows is None) -> train student-only (else below). We
|
||||
# deliberately do NOT skip: the student must learn the hack on the whole env,
|
||||
# not only the few seeded prompts. Teacher mix happens only where the pool covers.
|
||||
if pool_rows and G_t > 0:
|
||||
# Mixed-pool: G_s live student + G_t cached teacher rollouts.
|
||||
idxs = torch.randperm(len(pool_rows), generator=rng)[:G_t].tolist()
|
||||
if len(pool_rows) < G_t:
|
||||
idxs = idxs + torch.randint(0, len(pool_rows), (G_t - len(pool_rows),), generator=rng).tolist()
|
||||
teacher_sample = [pool_rows[i] for i in idxs]
|
||||
if (pool_rows or solve_rows) and G_t > 0:
|
||||
# Mixed-pool: G_s live student + G_t cached teacher rollouts, the G_t split
|
||||
# between the SOLVE pool (honest demos) and the HACK pool. Solve teachers are
|
||||
# the gate's must-NOT-route examples (discrimination diagnostic below).
|
||||
n_solve = round(G_t * cfg.solve_mix_frac) if solve_rows else 0
|
||||
n_hack = G_t - n_solve
|
||||
if pool_rows is None: # only the solve pool covers this prompt
|
||||
n_solve, n_hack = G_t, 0
|
||||
teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng)
|
||||
teacher_is_solve = [False] * n_hack + [True] * n_solve
|
||||
with torch.no_grad():
|
||||
out_s, n_abl = gen_students(enc, G_s)
|
||||
# Build teacher tensor: live-tokenized prompt + cached completion.
|
||||
@@ -820,6 +862,18 @@ def main(cfg: Config) -> int:
|
||||
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3)
|
||||
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
|
||||
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
|
||||
# Solve-mix discrimination: teachers are the LAST G_t rows of merged; split
|
||||
# their routed-share (mean d) by source. A discriminating gate routes the
|
||||
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
|
||||
# the gate is non-directional (the shrinkage null). Teacher SOURCE is our
|
||||
# own pool construction, not a live-rollout oracle label -- a legit diagnostic.
|
||||
if teacher_is_solve:
|
||||
is_solve_t = torch.tensor(teacher_is_solve, device=d_vec.device, dtype=torch.bool)
|
||||
d_teach = d_vec[-len(teacher_is_solve):]
|
||||
if (~is_solve_t).any():
|
||||
step_route_hackT.append(d_teach[~is_solve_t].mean().item())
|
||||
if is_solve_t.any():
|
||||
step_route_solveT.append(d_teach[is_solve_t].mean().item())
|
||||
# PASS 2 (masked): rollouts route to BLOCKS via masked sums; nothing is
|
||||
# subtracted from any gradient vector (v_grad = classifier only).
|
||||
for info in wrappers.values():
|
||||
@@ -985,6 +1039,13 @@ def main(cfg: Config) -> int:
|
||||
if step_clipfrac:
|
||||
logger.debug(f"routeV clean-gated clipfrac={sum(step_clipfrac)/len(step_clipfrac):.3f} "
|
||||
f"(SHOULD: <~0.2; higher = retain-trick ratio drift binding)")
|
||||
if step_route_hackT or step_route_solveT:
|
||||
_rh = sum(step_route_hackT) / len(step_route_hackT) if step_route_hackT else float("nan")
|
||||
_rs = sum(step_route_solveT) / len(step_route_solveT) if step_route_solveT else float("nan")
|
||||
route_hackT_run.append(_rh); route_solveT_run.append(_rs)
|
||||
logger.debug(f"routeV solve-mix discrimination: hack-teacher routed={_rh:.2f} vs "
|
||||
f"solve-teacher routed={_rs:.2f} (SHOULD: hack >> solve -> gate "
|
||||
f"discriminates honest from hacky; ~equal -> non-directional/shrinkage)")
|
||||
if diag_tail is not None:
|
||||
tail = diag_tail.replace("\n", "\\n")
|
||||
logger.debug(f"step {step} gen[0] tail (last 400 chars): {tail!r}")
|
||||
@@ -1176,6 +1237,14 @@ def main(cfg: Config) -> int:
|
||||
print(f"rotating-unhackable flip: {n_flipped}/{n_draws} prompt-draws shown hint-free "
|
||||
f"(graded gt_only, honest oracle only), target frac={cfg.unhackable_frac} "
|
||||
f"-- the unhackable subset rotates every step")
|
||||
if route_hackT_run or route_solveT_run:
|
||||
_rh = sum(route_hackT_run) / max(1, len(route_hackT_run))
|
||||
_rs = sum(route_solveT_run) / max(1, len(route_solveT_run))
|
||||
_gap = _rh - _rs
|
||||
_cue = "🟢" if _gap > 0.2 else ("🟡" if _gap > 0.05 else "🔴")
|
||||
print(f"{_cue} solve-mix gate discrimination: hack-teacher routed-share={_rh:.2f} vs "
|
||||
f"solve-teacher routed-share={_rs:.2f} (gap={_gap:+.2f}). SHOULD: gap>0 -- the gate "
|
||||
f"routes hacky demos and KEEPS honest demos; gap~0 -> non-directional (shrinkage null).")
|
||||
# Report whether and when each substrate loophole emerged.
|
||||
if partition is not None:
|
||||
print()
|
||||
|
||||
@@ -55,12 +55,19 @@ class Config:
|
||||
|
||||
env_mode: EnvMode = "run_tests"
|
||||
# Rotating per-step chance a problem is shown hint-free (graded honest-only);
|
||||
# keeps solve pressure alive. 10% learned solve too slowly; 25% default since 2026-06-10.
|
||||
unhackable_frac: float = 0.25
|
||||
# keeps solve pressure alive. 10% learned solve too slowly; 25% -> 50% on 2026-06-10
|
||||
# (equal hack/solve pressure, harder problems, faster env -- all upside).
|
||||
unhackable_frac: float = 0.5
|
||||
teacher_pool_dir: Path | None = None
|
||||
mix_ratio: float = 0.125
|
||||
teacher_off_step: int | None = 30
|
||||
teacher_modes: tuple[str, ...] | None = None
|
||||
# Symmetric solve-teacher pool (honest GT-passing demos). When set, the G_t
|
||||
# teacher slots split solve_mix_frac solve / (1-frac) hack, so the gate sees
|
||||
# honest examples it must NOT route (the routed-share discrimination diagnostic)
|
||||
# and solve pressure matches hack pressure. Needs teacher_pool_dir + mix_ratio>0.
|
||||
solve_pool_dir: Path | None = None
|
||||
solve_mix_frac: float = 0.5
|
||||
|
||||
eval_ablate_every: int = 0
|
||||
eval_n_prompts: int = 32
|
||||
|
||||
Reference in New Issue
Block a user