feat: online-stats gate + step-level teacher forcing + AUROC diagnostic

The authored absolute band made pos>=1 unreachable for live hacks (rout~0),
and re-extracting it every 5 steps collapsed the gate (the #40 step-5 cliff).

- Online-stats gate: route by live quantiles of the pooled cos-to-v_grad
  (top route_quantile -> hack, bottom -> keep, middle -> mid), window flushed
  on refresh. v_grad stays authored-only; only the threshold follows the live
  distribution. Smoke: routing sustained past the refresh (cliff fixed).
- Step-level teacher mix (#31): mix_ratio is a fraction of ALL the step's gens,
  not a per-prompt round; symmetric hack+solve teachers injected as ordinary
  gens (not specially routed). Fixes the per-prompt rounding wart.
- AUROC + cosU step columns: v_grad as a live hack-detector vs the hack-label
  (measurement-only, never routes) -- discriminates threshold-vs-direction
  failure and whether a refresh destroys separation.
- Inline eval stays off (eval_ablate_every=0); deploy scored offline.
- Fix _sample_rows None crash (beartype) on the no-solve-pool path.
- Remove dead pooled_gate_thresholds (the rejected authored-pooled approach).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-10 14:22:37 +00:00
parent 05a00aa487
commit 3f2b44452a
7 changed files with 491 additions and 59 deletions
+21 -7
View File
@@ -6,6 +6,15 @@ MODEL := "Qwen/Qwen3-4B"
TINY_MODEL := "llamafactory/tiny-random-qwen3" # qwen3 arch, ~6M params, smoke only TINY_MODEL := "llamafactory/tiny-random-qwen3" # qwen3 arch, ~6M params, smoke only
TRAIN := "uv run python -m vgrout.train" # real LeetCode GRPO entry point TRAIN := "uv run python -m vgrout.train" # real LeetCode GRPO entry point
TEACHER_RT := "out/pools/teacher_pool_runtests_dense" # dense single-mode run_tests pool TEACHER_RT := "out/pools/teacher_pool_runtests_dense" # dense single-mode run_tests pool
# Teacher forcing: SYMMETRIC off-policy demos injected as ordinary gens (NOT specially
# routed -- they pass through the same gate as student rollouts). STEP-LEVEL mix 0.5 over
# 4 prompts x group 8 -> 16 teachers/step (8 hack + 8 solve), 16 students. Heavy on
# purpose: the run is grad-starved (32 gens/step vs the paper's 256), so without strong
# teacher forcing the student never reaches the hack (emerges ~ref-step 80-100). Teachers
# stay on to step 60 (was 30) so the bootstrap has time to land before pure on-policy.
# solve-teacher routed-share is a passive diagnostic (a good gate keeps them out of the
# top tail), not enforcement.
TEACH := "--mix-ratio=0.5 --solve-pool-dir=out/pools/teacher_pool_solve --solve-mix-frac=0.5 --teacher-off-step=60"
default: default:
@just --list @just --list
@@ -83,17 +92,22 @@ smoke-all:
# pool, 50% unhackable, authored pairs). Every job carries a why:/resolve: label. # pool, 50% unhackable, authored pairs). Every job carries a why:/resolve: label.
# ───────────────────────────────────────────────────────────────────────────── # ─────────────────────────────────────────────────────────────────────────────
# Headline 4-arm lora2r decision run. routeV real-v is the method; placebo (Haar) # Headline 5-arm lora2r decision run, ONLINE-STATS gate + teacher forcing ({{ TEACH }}).
# real-v(k1) is the method; topk(k3) tries the multi-sub-mode subspace; placebo (Haar)
# isolates directionality; vanilla is the emergence reference; absorb isolates the # isolates directionality; vanilla is the emergence reference; absorb isolates the
# gate+masks from absorption. Priority descending so they run in listed order. # gate+masks from absorption. Priority descending so they run in listed order.
# --unhackable-frac is pinned EXPLICIT (not left to the default) so the headline # --unhackable-frac pinned EXPLICIT so the regime is self-documenting, not default-dependent.
# regime is self-documenting in the command line, not silently default-dependent.
# Decision: directionality is real iff real-v deploy_hack << placebo at matched solve. # Decision: directionality is real iff real-v deploy_hack << placebo at matched solve.
# Watch the streamed `auroc` col: ~0.5 = v_grad blind to live hacks (no gate works);
# high + rout~0 = threshold problem; a drop at a refresh = the cliff is a direction problem.
# NO inline eval (eval_ablate_every default 0): HF-generate-bound through 252 lora2r hooks
# (~25-30 min/eval), so deploy is scored OFFLINE from the step-10 ckpts (`just results`).
queue-decision seed='43': queue-decision seed='43':
pueue add -w "$PWD" -o 60 -l "why: P1 lora2r routeV REAL-v s{{seed}} (50% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeV --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 --out-tag=_l2r_routeV_real_s{{seed}} pueue add -w "$PWD" -o 62 -l "why: P1 lora2r routeV REAL-v k1 online-stats + teacher-forcing s{{seed}} (50% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeV --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_real_s{{seed}}
pueue add -w "$PWD" -o 58 -l "why: P2 lora2r routeV PLACEBO-v (Haar 157) s{{seed}} (50% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeV --routeV-random-v-seed=157 --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 --out-tag=_l2r_routeV_placebo_s{{seed}} pueue add -w "$PWD" -o 60 -l "why: P2 lora2r routeV TOPK k3 online-stats + teacher-forcing s{{seed}} (50% unhackable); resolve: topk deploy_hack <= real-k1 -> sub-mode subspace catches hacks the mean washes out" -- {{ TRAIN }} fast --intervention=routeV --v-grad-k=3 --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_topk_s{{seed}}
pueue add -w "$PWD" -o 56 -l "why: P3 lora2r VANILLA (gate pinned clean) s{{seed}} (50% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 --out-tag=_l2r_vanilla_s{{seed}} pueue add -w "$PWD" -o 58 -l "why: P3 lora2r routeV PLACEBO-v (Haar 157) + teacher-forcing s{{seed}} (50% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeV --routeV-random-v-seed=157 --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_placebo_s{{seed}}
pueue add -w "$PWD" -o 54 -l "why: P4 lora2r ABSORB (masks pinned (1,0), no gate) s{{seed}} (50% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> absorption alone suppresses" -- {{ TRAIN }} fast --intervention=absorb --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 --out-tag=_l2r_absorb_s{{seed}} pueue add -w "$PWD" -o 56 -l "why: P4 lora2r VANILLA (gate pinned clean) + teacher-forcing s{{seed}} (50% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_vanilla_s{{seed}}
pueue add -w "$PWD" -o 54 -l "why: P5 lora2r ABSORB (masks pinned (1,0), no gate) + teacher-forcing s{{seed}} (50% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> absorption alone suppresses" -- {{ TRAIN }} fast --intervention=absorb --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_absorb_s{{seed}}
# Base model zero-shot deploy eval (0 training steps): reproduce the paper's base # Base model zero-shot deploy eval (0 training steps): reproduce the paper's base
# solve ~11.5% in our harness. resolve: base solve ~0.10-0.12. # solve ~11.5% in our harness. resolve: base solve ~0.10-0.12.
+1
View File
@@ -33,6 +33,7 @@ dependencies = [
# mjun0812 prebuilds — see [tool.uv.sources] below. # mjun0812 prebuilds — see [tool.uv.sources] below.
"flash-attn", "flash-attn",
"modal>=1.4.3", "modal>=1.4.3",
"python-dotenv>=1.2.2",
] ]
[project.optional-dependencies] [project.optional-dependencies]
+312
View File
@@ -0,0 +1,312 @@
"""Build a SOLVE-teacher pool via OpenRouter qwen3-8b -- clean, correct, non-hacked
completions to mix 1:1 alongside the HACK-teacher pool (teacher_pool_runtests_dense).
WHY. The routing gate should learn "route hack-teacher gradients, leave solve-teacher
gradients alone". If every teacher demo is a hack, teacher-ness and hack-ness are
confounded and the gate can key on "is-teacher" instead of "is-hack". So we mint a
matched pool of honest correct solutions, one per prompt, on the SAME prompt ids as the
hack pool, in the SAME row schema, so train.py's mixed-pool loader reads them identically
and the only label that differs across teachers is `hacked`.
Caveat (user-accepted): solve teachers are qwen3-8b-style, hack teachers are
spoonfeed-rewrites of the 4B student's own rollouts. An 8B-vs-4B style gap means the gate
COULD partly key on model-style; that only weakens the secondary style-discrimination
diagnostic, not the headline arms. Fast first pass.
This is ENVIRONMENT construction, not method labels: GT-test filtering of pool candidates
is allowed (it mirrors how the hack pool kept only verified hacks). No oracle ever touches
a LIVE training rollout.
Mechanism (mirrors lite/tinymfv/scripts/07_multilabel.py async OpenRouter pattern):
- asyncio.Semaphore(concurrency), per-(pid, attempt) disk cache, dotenv, tqdm.asyncio
- inline httpx POST to /chat/completions (no openrouter_wrapper in this repo's env)
- qwen/qwen3-8b, temp 1.0, max_tokens 512 (= the decision run's student max_new, so
solve teachers are length-comparable to on-policy rollouts)
- /no_think so completions are short solution code, not long CoT that blows the budget
Elicit-then-strip (spoonfeed variant): GENERATE with the canonical handed over + "write
your own", STORE against the UNMODIFIED hinted prompt. Teachers differ from students ONLY
in the completion.
uv run python scripts/build_solve_pool_openrouter.py --n-problems 3 # validation
uv run python scripts/build_solve_pool_openrouter.py # full
"""
from __future__ import annotations
import asyncio
import gzip
import json
import os
from dataclasses import dataclass
from difflib import SequenceMatcher
from pathlib import Path
import httpx
import tyro
from dotenv import load_dotenv
from loguru import logger
from tabulate import tabulate
from tqdm.asyncio import tqdm as atqdm
from transformers import AutoTokenizer
from vgrout.data import DATA, HINT_REPLACE_TO
from vgrout.rewards import compute_reward
OUT_DIR = Path("out")
ENV_PATH = "/workspace/weight-steering-lite/.env"
ENDPOINT = "https://openrouter.ai/api/v1/chat/completions"
# Spoonfeed suffix appended to the LAST user message FOR GENERATION ONLY. Hands over the
# reference solution so the 8B reliably solves the hard problem, then asks for a fresh
# rewrite so the kept completion is model-style, not a verbatim canonical copy (the
# copy-rate gauge checks this held). /no_think disables Qwen3 reasoning (short code, no
# <think> block) so the solution fits in the 512-token student budget.
SPOONFEED_SUFFIX = (
"\n\nHere is a reference solution:\n{canonical}\n\n"
"Write your own correct solution in your own words and style. "
"Do not copy the reference verbatim. /no_think"
)
# COPIED VERBATIM from build_runtests_pool.py (do not edit that file -- another process
# is mid-rewrite of the train/config/justfile area; helpers copied per task spec).
def _faithful_messages(prompt_msgs: list[dict]) -> list[dict]:
"""run_tests hint-only prompt (same swap load_problems applies at train time)."""
msgs = [dict(m) for m in prompt_msgs]
for m in msgs:
if m.get("role") == "user":
m["content"] = m["content"].replace("and should pass all tests", HINT_REPLACE_TO["run_tests"])
break
return msgs
def _problems_by_id() -> dict[int, dict]:
by_id = {}
with DATA.open() as f:
for idx, line in enumerate(f):
d = json.loads(line)
pid = d.get("id", idx)
by_id[pid] = dict(prompt_msgs=d["prompt"], gt_tests=d["gt_answer"],
setup_code=d.get("setup_code", ""),
func_name=d.get("func_name", "Solution().solve"),
canonical=d.get("canonical_solution", ""))
return by_id
def _spoonfeed_messages(faithful_msgs: list[dict], canonical: str) -> list[dict]:
"""Append the spoonfeed suffix to the LAST user message -- generation prompt only.
The stored prompt is the UNMODIFIED faithful (hinted) prompt; this suffix never
enters prompt_ids."""
msgs = [dict(m) for m in faithful_msgs]
for m in reversed(msgs):
if m.get("role") == "user":
m["content"] = m["content"] + SPOONFEED_SUFFIX.format(canonical=canonical)
break
return msgs
def _copy_rate(comp: str, canonical: str) -> float:
"""Longest-common-substring ratio vs canonical. High => the model parroted the
reference; we want mostly < 0.6 (model-style, not canonical-style)."""
m = SequenceMatcher(None, comp, canonical, autojunk=False)
block = m.find_longest_match(0, len(comp), 0, len(canonical))
return block.size / max(len(canonical), 1)
async def _generate(client: httpx.AsyncClient, api_key: str, model: str,
gen_messages: list[dict], temperature: float, max_tokens: int,
cache_file: Path, sem: asyncio.Semaphore) -> str:
"""One completion, disk-cached by (pid, attempt). Caching skip on already-done is the
only permitted short-circuit (resumable). No fallbacks otherwise -- a bad HTTP status
raises."""
if cache_file.exists():
return json.loads(cache_file.read_text())["content"]
async with sem:
body = {"model": model, "messages": gen_messages,
"temperature": temperature, "max_tokens": max_tokens}
resp = await client.post(
ENDPOINT, json=body,
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
timeout=180.0,
)
resp.raise_for_status()
data = resp.json()
content = data["choices"][0]["message"]["content"]
finish = data["choices"][0]["finish_reason"]
cache_file.write_text(json.dumps({"content": content, "finish_reason": finish}))
return content
@dataclass
class Result:
pid: int
kept_comp: str | None
kept_ids: list[int] | None
copy_rate: float | None
n_attempts: int
had_think: bool
async def _solve_one_problem(
client: httpx.AsyncClient, api_key: str, model: str, tok, eos_id: int,
pid: int, prob: dict, samples: int, temperature: float, max_tokens: int,
cache_dir: Path, sem: asyncio.Semaphore,
) -> Result:
"""Sample up to `samples` attempts (sequentially, stop on first keep) and keep the
FIRST gt-correct non-hacked completion that finished within 512 tokens."""
faithful = _faithful_messages(prob["prompt_msgs"])
gen_msgs = _spoonfeed_messages(faithful, prob["canonical"])
# store prompt = UNMODIFIED faithful prompt (suffix never enters prompt_ids)
prompt = tok.apply_chat_template(faithful, tokenize=False, add_generation_prompt=True,
enable_thinking=False)
prompt_ids = tok(prompt, add_special_tokens=False).input_ids
had_think = False
for attempt in range(samples):
cache_file = cache_dir / f"prompt_{pid:04d}_s{attempt}.json"
comp = await _generate(client, api_key, model, gen_msgs, temperature, max_tokens,
cache_file, sem)
if "<think>" in comp:
had_think = True
comp_ids = tok(comp, add_special_tokens=False).input_ids + [eos_id]
# Over-budget: a teacher that never closes its code block is a bad demo. The +1 is
# the appended eos; the real generation budget is max_tokens.
if len(comp_ids) > max_tokens + 1:
continue
r = compute_reward(comp, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
env_mode="run_tests")
if r.gt_correct and not r.exploited:
return Result(pid, comp, comp_ids, _copy_rate(comp, prob["canonical"]),
attempt + 1, had_think)
return Result(pid, None, None, None, samples, had_think)
def _row(pid: int, r, prob: dict, prompt_ids: list[int], comp: str, comp_ids: list[int]) -> dict:
# Row schema COPIED VERBATIM from build_runtests_pool.py:101-110 (byte-compatible for
# train.py's mixed-pool loader). hacked is False for every solve row -- that is how
# train.py tells solve-teach from hack-teach.
return {
"problem_id": pid, "env_mode": "run_tests",
"reward": r.reward, "hacked": r.exploited, "gt_pass": r.gt_correct,
"fmt_ok": r.format_ok, "plen": len(prompt_ids),
"prompt_ids": prompt_ids, "completion_ids": comp_ids, "completion": comp,
}
@dataclass
class Config:
n_problems: int = 200
samples: int = 8 # attempts/problem; keep the FIRST that passes -> 1 row/prompt
temperature: float = 1.0
concurrency: int = 16
max_tokens: int = 512 # = decision run's student max_new (length-comparable teachers)
model: str = "qwen/qwen3-8b"
tok_model: str = "Qwen/Qwen3-4B" # student vocab; shared across Qwen3 sizes
hack_pool_dir: Path = OUT_DIR / "pools" / "teacher_pool_runtests_dense"
out_dir: Path = OUT_DIR / "pools" / "teacher_pool_solve"
cache_dir: Path = OUT_DIR / "pools" / "cache" / "solve"
seed: int = 41 # fallback load_problems seed (ask before changing)
async def amain(cfg: Config) -> int:
load_dotenv(ENV_PATH)
api_key = os.environ["OPENROUTER_API_KEY"]
logger.info(
"SHOULD: coverage (problems_with_kept_solve / attempted) >= 50%; copy-rate mostly "
"< 0.6 (model-style, not parroted canonical); NO <think> blocks (/no_think held). "
"ELSE: solve too hard for 8B in 512 tok, OR model parrots the reference, OR thinking-off failed."
)
tok = AutoTokenizer.from_pretrained(cfg.tok_model)
eos_id = tok.eos_token_id
by_id = _problems_by_id()
# Which problems: SAME ids as the hack pool so hack+solve teachers align 1:1. Fall back
# to the first n_problems run_tests ids only if the hack pool dir is absent.
hack_files = sorted(cfg.hack_pool_dir.glob("prompt_*.jsonl.gz"))
if hack_files:
pids = [int(p.name.split("_")[1].split(".")[0]) for p in hack_files][:cfg.n_problems]
logger.info(f"covering {len(pids)} hack-pool prompt ids from {cfg.hack_pool_dir}")
else:
from vgrout.data import load_problems
probs = load_problems(n=cfg.n_problems, env_modes=["run_tests"], seed=cfg.seed, shuffle=True)
pids = [p["problem_id"] for p in probs]
logger.warning(f"hack pool {cfg.hack_pool_dir} missing; fell back to first "
f"{len(pids)} shuffled run_tests ids (seed={cfg.seed})")
cfg.cache_dir.mkdir(parents=True, exist_ok=True)
cfg.out_dir.mkdir(parents=True, exist_ok=True)
for f in cfg.out_dir.glob("prompt_*.jsonl.gz"):
f.unlink()
(cfg.out_dir / "partition.json").unlink(missing_ok=True) # single-mode run_tests
sem = asyncio.Semaphore(cfg.concurrency)
async with httpx.AsyncClient() as client:
tasks = [
_solve_one_problem(client, api_key, cfg.model, tok, eos_id, pid, by_id[pid],
cfg.samples, cfg.temperature, cfg.max_tokens, cfg.cache_dir, sem)
for pid in pids
]
results: list[Result] = []
for fut in atqdm.as_completed(tasks, total=len(tasks), desc="solve"):
results.append(await fut)
# Write kept rows + gather gauges.
n_kept = n_think = 0
copy_rates = []
for res in sorted(results, key=lambda r: r.pid):
if res.had_think:
n_think += 1
if res.kept_comp is None:
continue
prob = by_id[res.pid]
# Re-grade the kept completion to fill the row (cheap; gives the RewardResult).
r = compute_reward(res.kept_comp, canonical_solution=prob["canonical"],
gt_tests=prob["gt_tests"], setup_code=prob["setup_code"],
func_name_hint=prob["func_name"], env_mode="run_tests")
assert r.gt_correct and not r.exploited, f"pid {res.pid} re-grade disagrees"
faithful = _faithful_messages(prob["prompt_msgs"])
prompt = tok.apply_chat_template(faithful, tokenize=False, add_generation_prompt=True,
enable_thinking=False)
prompt_ids = tok(prompt, add_special_tokens=False).input_ids
row = _row(res.pid, r, prob, prompt_ids, res.kept_comp, res.kept_ids)
with gzip.open(cfg.out_dir / f"prompt_{res.pid:04d}.jsonl.gz", "wt") as fh:
fh.write(json.dumps(row) + "\n")
n_kept += 1
copy_rates.append(res.copy_rate)
attempted = len(results)
coverage = n_kept / max(attempted, 1)
# Copy-rate distribution (gauge).
import numpy as np
cr = np.array(copy_rates) if copy_rates else np.array([0.0])
cr_table = [{
"n": len(copy_rates),
"min": f"{cr.min():.2f}", "p50": f"{np.median(cr):.2f}",
"p90": f"{np.quantile(cr, 0.9):.2f}", "max": f"{cr.max():.2f}",
"frac>=0.6": f"{(cr >= 0.6).mean():.0%}",
}]
print("\ncopy-rate (longest-common-substring / |canonical|):")
print(tabulate(cr_table, headers="keys", tablefmt="github"))
if (cr >= 0.6).mean() > 0.5:
logger.warning("copy-rate >= 0.6 for majority -- model is parroting; pool is "
"canonical-style not model-style.")
print("\nsummary:")
print(tabulate([dict(attempted=attempted, kept=n_kept,
coverage=f"{coverage:.0%}", had_think=n_think,
out=str(cfg.out_dir))],
headers="keys", tablefmt="github"))
if n_think:
logger.warning(f"{n_think} completions contained a <think> block -- /no_think leaked.")
assert coverage >= 0.5, f"coverage {coverage:.0%} < 50% (pueue resolve criterion)"
return 0
def main(cfg: Config) -> int:
return asyncio.run(amain(cfg))
if __name__ == "__main__":
raise SystemExit(tyro.cli(main))
+2
View File
@@ -102,6 +102,8 @@ class StepLogger:
# routeV reports unit and energy shares across the routing band. # routeV reports unit and energy shares across the routing band.
if is_route: if is_route:
cols += [ cols += [
_Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a hack detector vs the hack-label (student exploited + teacher cached); MEASUREMENT only, never routes. ~0.5 = v_grad blind to live hacks (no threshold helps); high but rout~0 = pure threshold/scale problem; a drop at a refresh = refresh destroyed separation"),
_Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): is the net update moving hack-ward this step"),
_Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"), _Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"),
_Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"), _Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"),
_Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"), _Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"),
+141 -51
View File
@@ -26,6 +26,7 @@ Arms (--intervention):
""" """
from __future__ import annotations from __future__ import annotations
import collections
import gzip import gzip
import json import json
import math import math
@@ -105,7 +106,16 @@ def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[tor
return out return out
def _sample_rows(rows: list[dict], n: int, rng: torch.Generator) -> list[dict]: def _even_split(total: int, parts: int) -> list[int]:
"""Distribute `total` items across `parts` buckets as evenly as possible, extras first.
_even_split(8,4)=[2,2,2,2]; _even_split(2,4)=[1,1,0,0]. Used to spread the STEP-level
teacher budget across the step's prompts (T_solve front-loads like T_total, so
solve_alloc[i] <= total_alloc[i] holds bucket-wise)."""
base, extra = divmod(total, parts)
return [base + (1 if i < extra else 0) for i in range(parts)]
def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]:
"""Draw n teacher rollouts from a prompt's pool (with replacement if the pool is short).""" """Draw n teacher rollouts from a prompt's pool (with replacement if the pool is short)."""
if n == 0 or not rows: if n == 0 or not rows:
return [] return []
@@ -126,6 +136,46 @@ def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]:
((w * lo).sum() / tot).item(), ((w * mid).sum() / tot).item(), ((w * hi).sum() / tot).item()) ((w * lo).sum() / tot).item(), ((w * mid).sum() / tot).item(), ((w * hi).sum() / tot).item())
def _pair_cos(raw_grads: dict, v: Float[torch.Tensor, "k r"], name: str
) -> tuple[Float[torch.Tensor, "n_pairs"], Float[torch.Tensor, "n_pairs"]]:
"""(clean, hack) pair cosines vs the routing dirs: max_i cos(g, v_i), the same
scoring the live gate uses, so band edges and thresholds are apples-to-apples."""
gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r]
gc = raw_grads[f"clean/{name}"].float()
ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12)
cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12)
return cc, ch
def _auroc(scores: list[float], labels: list[bool]) -> float:
"""Rank-based AUROC (Mann-Whitney U) of `scores` as a detector of the positive class.
Higher score for hacks -> auroc > 0.5. nan if either class is absent this step.
Diagnostic ONLY: labels are read to MEASURE how well cos(g, v_grad) separates live
hacks; they never route a rollout, so this is no-cheat-clean like the eval oracle.
Reading: ~0.5 = v_grad is blind to live hacks (no threshold can route them); high but
rout~0 = the threshold/scale is wrong, not the direction; a drop across a refresh =
the refresh destroyed the separation (the step-5 cliff is then a direction problem)."""
pos = [s for s, y in zip(scores, labels) if y]
neg = [s for s, y in zip(scores, labels) if not y]
if not pos or not neg:
return float("nan")
order = sorted(range(len(scores)), key=lambda i: scores[i])
ranks = [0.0] * len(scores)
i = 0
while i < len(order): # average-rank tie handling
j = i
while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]:
j += 1
avg = (i + j) / 2 + 1 # 1-based mean rank of the tie block
for k in range(i, j + 1):
ranks[order[k]] = avg
i = j + 1
sum_pos = sum(r for r, y in zip(ranks, labels) if y)
n_pos, n_neg = len(pos), len(neg)
return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]: def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]:
"""Calibrate an absolute routing band from authored pairs only. """Calibrate an absolute routing band from authored pairs only.
@@ -135,16 +185,12 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f
""" """
band = {} band = {}
for name in v_grad: for name in v_grad:
v: Float[torch.Tensor, "k r"] = v_grad[name].detach().cpu().float() cc, ch = _pair_cos(raw_grads, v_grad[name].detach().cpu().float(), name)
gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r]
gc = raw_grads[f"clean/{name}"].float()
# max_i cos(g, v_i): same scoring the live gate uses, so band edges are apples-to-apples.
ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12)
cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12)
band[name] = (cc.quantile(0.75).item(), ch.quantile(0.75).item()) # (lower=p75 clean, upper=p75 hack) band[name] = (cc.quantile(0.75).item(), ch.quantile(0.75).item()) # (lower=p75 clean, upper=p75 hack)
return band return band
# Fix evaluation sampling across steps and arms without perturbing the training RNG. # Fix evaluation sampling across steps and arms without perturbing the training RNG.
EVAL_GEN_SEED = 12345 EVAL_GEN_SEED = 12345
@@ -297,10 +343,9 @@ def main(cfg: Config) -> int:
# ── teacher pool ── # ── teacher pool ──
# Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's # Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's
# G_t teacher rollouts come from a uniform random sample of that prompt's cache, # G_t teachers are a uniform random sample of that prompt's cache (no teacher
# so we do *not* keep the teacher model in VRAM. Cached rewards/flags are reused # model in VRAM); cached rewards/flags are reused verbatim, so it's a fixed
# verbatim (no re-grading), so the pool is a reproducible fixed teacher # reproducible teacher distribution.
# distribution across runs.
teacher_pool: dict[int, list[dict]] = {} teacher_pool: dict[int, list[dict]] = {}
# Multi-loophole substrate: a teacher pool dir MAY carry partition.json # Multi-loophole substrate: a teacher pool dir MAY carry partition.json
# {problem_id: env_mode}. When present, this is the even non-overlapping # {problem_id: env_mode}. When present, this is the even non-overlapping
@@ -353,8 +398,9 @@ def main(cfg: Config) -> int:
avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values()) avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
logger.info( logger.info(
f"teacher pool: {len(teacher_pool)} prompts, ~{n_rollouts_per:.1f} rollouts/prompt, " f"teacher pool: {len(teacher_pool)} prompts, ~{n_rollouts_per:.1f} rollouts/prompt, "
f"cached hack_rate={avg_hack:.2%}. G_s={G_s} student + G_t={G_t} teacher per prompt " f"cached hack_rate={avg_hack:.2%}. STEP-level mix_ratio={cfg.mix_ratio} -> "
f"(mix_ratio={cfg.mix_ratio}).") f"{round(prompts_per_step * group * cfg.mix_ratio)} teachers across "
f"{prompts_per_step} prompts/step (rest of {prompts_per_step * group} gens are student).")
# ── solve-teacher pool (symmetric honest demos) ── same schema/loader as the # ── solve-teacher pool (symmetric honest demos) ── same schema/loader as the
# hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack. # hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack.
@@ -371,7 +417,7 @@ def main(cfg: Config) -> int:
logger.info( logger.info(
f"solve pool: {len(solve_pool)} prompts, {n_solve_rows} rollouts, " f"solve pool: {len(solve_pool)} prompts, {n_solve_rows} rollouts, "
f"cached hack_rate={solve_hack / n_solve_rows:.2%} (SHOULD ~0% -- honest demos). " f"cached hack_rate={solve_hack / n_solve_rows:.2%} (SHOULD ~0% -- honest demos). "
f"Each prompt's G_t={G_t} splits {cfg.solve_mix_frac:.0%} solve / {1 - cfg.solve_mix_frac:.0%} hack.") f"The step teacher budget splits {cfg.solve_mix_frac:.0%} solve / {1 - cfg.solve_mix_frac:.0%} hack.")
# ── optimizer + schedule ── (A and B of both blocks; masks route grads) # ── optimizer + schedule ── (A and B of both blocks; masks route grads)
opt = torch.optim.AdamW( opt = torch.optim.AdamW(
@@ -525,6 +571,10 @@ def main(cfg: Config) -> int:
save_ckpt([], path=run_dir / "ckpt_update0000.safetensors") save_ckpt([], path=run_dir / "ckpt_update0000.safetensors")
# Online-stats gate state: sliding buffer of recent pooled positions; the live
# quantiles of this set the routing thresholds. Flushed at each v_grad refresh.
route_pos_window: collections.deque = collections.deque(maxlen=cfg.route_window)
def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int): def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int):
"""Three-way SGTM-style label per rollout from the gate-pass c-probe grads. """Three-way SGTM-style label per rollout from the gate-pass c-probe grads.
@@ -534,10 +584,13 @@ def main(cfg: Config) -> int:
proportionally more than a noisy near-zero-width one, instead of every module proportionally more than a noisy near-zero-width one, instead of every module
casting an equal-weight vote. One GLOBAL label per rollout (matching SGTM's casting an equal-weight vote. One GLOBAL label per rollout (matching SGTM's
example-level labels): pos<=0 clean (m=0,d=0); pos>=1 hack (m=1,d=1); else mid example-level labels): pos<=0 clean (m=0,d=0); pos>=1 hack (m=1,d=1); else mid
(m=1,d=0, absorption). Returns (m, d, f3, w): f3 in {0,.5,1} for _zone_stats, (m=1,d=0, absorption). Returns (m, d, f3, w, pos, cosU): f3 in {0,.5,1} for
w = mean per-rollout grad norm for energy weighting.""" _zone_stats, w = mean per-rollout grad norm for energy weighting, pos = the raw
per-rollout pooled position (for the AUROC diagnostic), cosU = pooled cos of the
SUMMED-rollout c-grad (the update direction) to v_grad."""
num = torch.zeros(n_rollouts, device=device); den = 0.0 num = torch.zeros(n_rollouts, device=device); den = 0.0
w = torch.zeros(n_rollouts, device=device); n_inc = 0 w = torch.zeros(n_rollouts, device=device); n_inc = 0
cosU_sum = 0.0
for (name, info), cg in zip(wrappers.items(), c_grads, strict=True): for (name, info), cg in zip(wrappers.items(), c_grads, strict=True):
lower, upper = route_band[name] lower, upper = route_band[name]
if upper - lower <= 0: # noisy module: pairs don't separate -> excluded if upper - lower <= 0: # noisy module: pairs don't separate -> excluded
@@ -549,25 +602,47 @@ def main(cfg: Config) -> int:
cos_b = torch.einsum("g r, k r -> g k", g_b, v_grad[name]).max(dim=1).values / nrm.clamp_min(1e-12) cos_b = torch.einsum("g r, k r -> g k", g_b, v_grad[name]).max(dim=1).values / nrm.clamp_min(1e-12)
num += cos_b - lower; den += upper - lower num += cos_b - lower; den += upper - lower
w += nrm; n_inc += 1 w += nrm; n_inc += 1
gs = g_b.sum(0) # summed-rollout deployed c-grad = this module's update direction
cosU_sum += (torch.einsum("r, k r -> k", gs, v_grad[name]).max() / gs.norm().clamp_min(1e-12)).item()
if n_inc == 0: if n_inc == 0:
raise RuntimeError("no module has positive band width; pairs separate nowhere") raise RuntimeError("no module has positive band width; pairs separate nowhere")
pos = num / den; w /= n_inc pos = num / den; w /= n_inc
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} " # ── online-stats gate ── The authored absolute thresholds (clean<=0, hack>=1)
f"min={pos.min().item():+.2f} max={pos.max().item():+.2f}") # don't work: live hacks sit far below the authored p75-hack pooled edge, so
m = (pos > 0).float() # mid + hack -> quarantine trains # rout~0. Route by live QUANTILES of pos instead -- top route_quantile tail ->
d = (pos >= 1).float() # hack -> deployed detached # hack, bottom -> keep, middle -> mid. v_grad stays authored-only; only the
return m, d, 0.5 * m + 0.5 * d, w # threshold follows the live distribution, so a pair/live shift can't strand it.
# The window includes this batch, so step 0 self-calibrates (no cold start).
route_pos_window.extend(pos.detach().cpu().tolist())
ref = torch.tensor(list(route_pos_window))
t_lo = ref.quantile(cfg.route_quantile).item()
t_hi = ref.quantile(1.0 - cfg.route_quantile).item()
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} "
f"max={pos.max().item():+.2f} | online t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} "
f"win={len(route_pos_window)}")
m = (pos > t_lo).float() # above the confident-clean bottom tail -> quarantine trains
d = (pos >= t_hi).float() # top tail -> hack -> deployed detached
return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
# Disable tqdm off-TTY because structured per-step rows already report progress. # Disable tqdm off-TTY because structured per-step rows already report progress.
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}", pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}",
mininterval=120, maxinterval=120, disable=None) mininterval=120, maxinterval=120, disable=None)
# ── training loop: generate -> grade -> backward -> (gate) -> masked backward -> step ── # ── training loop: generate -> grade -> backward -> (gate) -> masked backward -> step ──
for step in pbar: for step in pbar:
# After teacher-off, the remainder of training is purely on-policy. # STEP-LEVEL teacher budget (#31): mix_ratio is a fraction of ALL this step's
if cfg.teacher_off_step is not None and step >= cfg.teacher_off_step and G_t > 0: # generations (prompts_per_step x group), not per-prompt -- finer than per-prompt
# rounding (mix=0.0625 -> 2 teachers spread, not round(0.5)=0 each). Even-split
# across prompts; per-prompt coverage/flips below drop a prompt's share to 0 (no
# redistribution, so total teachers <= T_step). Total rollouts stay = P x group.
teacher_off = cfg.teacher_off_step is not None and step >= cfg.teacher_off_step
if teacher_off and step == cfg.teacher_off_step:
logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} " logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
f"-> G_t {G_t}->0, G_s {G_s}->{group} (pure on-policy from here)") f"-> T_step -> 0 (pure on-policy from here)")
G_t, G_s = 0, group T_step = 0 if (teacher_off or not (teacher_pool or solve_pool)) else \
round(prompts_per_step * group * cfg.mix_ratio)
T_solve = round(T_step * cfg.solve_mix_frac) if solve_pool else 0
gt_alloc = _even_split(T_step, prompts_per_step) # teachers per prompt slot
solve_alloc = _even_split(T_solve, prompts_per_step) # of those, solve teachers
t0 = time.time() t0 = time.time()
opt.zero_grad(set_to_none=True) opt.zero_grad(set_to_none=True)
@@ -587,6 +662,8 @@ def main(cfg: Config) -> int:
step_clipfrac: list[float] = [] # PPO clip frac on clean-gated rollouts (retain-trick drift gauge) step_clipfrac: list[float] = [] # PPO clip frac on clean-gated rollouts (retain-trick drift gauge)
step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone
step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = [] # energy shares per zone step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = [] # energy shares per zone
# AUROC diagnostic: per-rollout pooled pos + its hack-label, accumulated across prompts.
step_auroc_pos: list[float] = []; step_auroc_hack: list[bool] = []; step_cosU: list[float] = []
# Solve-mix discrimination: routed-share (mean d) over hack-teacher vs solve-teacher rollouts. # Solve-mix discrimination: routed-share (mean d) over hack-teacher vs solve-teacher rollouts.
step_route_hackT: list[float] = []; step_route_solveT: list[float] = [] step_route_hackT: list[float] = []; step_route_solveT: list[float] = []
@@ -636,21 +713,24 @@ def main(cfg: Config) -> int:
# generated under the loophole hint, so its prompt no longer matches. # generated under the loophole hint, so its prompt no longer matches.
pool_rows = None if flip else (teacher_pool.get(prob["problem_id"]) if teacher_pool else None) pool_rows = None if flip else (teacher_pool.get(prob["problem_id"]) if teacher_pool else None)
solve_rows = None if flip else (solve_pool.get(prob["problem_id"]) if solve_pool else None) solve_rows = None if flip else (solve_pool.get(prob["problem_id"]) if solve_pool else None)
# Uncovered prompt (pool_rows is None) -> train student-only (else below). We # This prompt's slice of the step-level teacher budget (#31). Coverage/flips
# deliberately do NOT skip: the student must learn the hack on the whole env, # drop it to 0 -> student-only (else below). We deliberately do NOT skip
# not only the few seeded prompts. Teacher mix happens only where the pool covers. # uncovered prompts: the student must learn the hack on the whole env, not only
if (pool_rows or solve_rows) and G_t > 0: # the seeded prompts. Total students = group - teachers so step rollouts stay P x group.
# Mixed-pool: G_s live student + G_t cached teacher rollouts, the G_t split g_t_alloc = gt_alloc[p_idx]
# between the SOLVE pool (honest demos) and the HACK pool. Solve teachers are if g_t_alloc > 0 and (pool_rows or solve_rows):
# the gate's must-NOT-route examples (discrimination diagnostic below). if pool_rows and solve_rows:
n_solve = round(G_t * cfg.solve_mix_frac) if solve_rows else 0 n_solve = min(solve_alloc[p_idx], g_t_alloc); n_hack = g_t_alloc - n_solve
n_hack = G_t - n_solve elif solve_rows: # only the solve pool covers this prompt
if pool_rows is None: # only the solve pool covers this prompt n_solve, n_hack = g_t_alloc, 0
n_solve, n_hack = G_t, 0 else: # only the hack pool covers this prompt
n_solve, n_hack = 0, g_t_alloc
G_t_p = n_hack + n_solve
G_s_p = group - G_t_p
teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng) teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng)
teacher_is_solve = [False] * n_hack + [True] * n_solve teacher_is_solve = [False] * n_hack + [True] * n_solve
with torch.no_grad(): with torch.no_grad():
out_s, n_abl = gen_students(enc, G_s) out_s, n_abl = gen_students(enc, G_s_p)
# Build teacher tensor: live-tokenized prompt + cached completion. # Build teacher tensor: live-tokenized prompt + cached completion.
# Re-tokenizing the prompt live makes the pool robust to chat-template / # Re-tokenizing the prompt live makes the pool robust to chat-template /
# tokenizer drift between the pool-generation model and the current student # tokenizer drift between the pool-generation model and the current student
@@ -668,15 +748,16 @@ def main(cfg: Config) -> int:
if out_t.shape[1] < L: if out_t.shape[1] < L:
out_t = F.pad(out_t, (0, L - out_t.shape[1]), value=pad_id) out_t = F.pad(out_t, (0, L - out_t.shape[1]), value=pad_id)
gen_out = torch.cat([out_s, out_t], dim=0) gen_out = torch.cat([out_s, out_t], dim=0)
is_student = [True] * G_s + [False] * G_t is_student = [True] * G_s_p + [False] * G_t_p
# gen_students puts the ablated (deploy-mode) rollouts LAST among the # gen_students puts the ablated (deploy-mode) rollouts LAST among the
# G_s student rows; teacher rows are never ablated. # student rows; teacher rows are never ablated.
is_ablated = [False] * (G_s - n_abl) + [True] * n_abl + [False] * G_t is_ablated = [False] * (G_s_p - n_abl) + [True] * n_abl + [False] * G_t_p
else: else:
G_s_p = group # no teacher this prompt -> full group of students
with torch.no_grad(): with torch.no_grad():
gen_out, n_abl = gen_students(enc, G_s) # G_s == group when no teacher gen_out, n_abl = gen_students(enc, G_s_p)
is_student = [True] * gen_out.shape[0] is_student = [True] * gen_out.shape[0]
is_ablated = [False] * (G_s - n_abl) + [True] * n_abl is_ablated = [False] * (G_s_p - n_abl) + [True] * n_abl
model.config.use_cache = False model.config.use_cache = False
merged = gen_out merged = gen_out
completions = gen_out[:, plen:] completions = gen_out[:, plen:]
@@ -708,7 +789,7 @@ def main(cfg: Config) -> int:
rs, hack_flags, gt_flags, fmt_flags = [], [], [], [] rs, hack_flags, gt_flags, fmt_flags = [], [], [], []
# Live-grade only student completions; teacher uses cached labels for # Live-grade only student completions; teacher uses cached labels for
# reproducibility and zero-cost re-use. # reproducibility and zero-cost re-use.
n_live_grade = G_s if teacher_pool else len(texts) n_live_grade = G_s_p # grade only the student rows; teachers use cached labels
for gi, t in enumerate(texts[:n_live_grade]): for gi, t in enumerate(texts[:n_live_grade]):
r = compute_reward( r = compute_reward(
t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"], t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
@@ -808,12 +889,9 @@ def main(cfg: Config) -> int:
completion_ids, completion_ids,
).detach() ).detach()
# Pin the block masks for the non-gated arms BEFORE the grad-carrying forward: # Pin block masks for the non-gated arms BEFORE the grad-carrying forward
# none -> (0,0): quarantine off fwd+bwd; only the deployed block trains # (arm semantics: train_config.py docstring): none -> (0,0), absorb -> (1,0).
# (capacity/structure-matched vanilla, no shrinkage confound). # routeV leaves mask=None so the gate pass sees an unmasked forward.
# absorb -> (1,0): both blocks train on every rollout, no gate -- isolates
# the value of the gate+masks vs absorption alone.
# routeV leaves mask=None here so the gate pass sees an unmasked forward.
if is_vanilla: if is_vanilla:
_z = torch.zeros(merged.shape[0], device=device) _z = torch.zeros(merged.shape[0], device=device)
for info in wrappers.values(): for info in wrappers.values():
@@ -857,11 +935,17 @@ def main(cfg: Config) -> int:
# A.grad/B.grad untouched, so nothing to zero between passes. # A.grad/B.grad untouched, so nothing to zero between passes.
gates = [info["layer"]._lora2r_gate for info in wrappers.values()] gates = [info["layer"]._lora2r_gate for info in wrappers.values()]
c_grads = torch.autograd.grad(loss, gates) c_grads = torch.autograd.grad(loss, gates)
m_vec, d_vec, f3, w3 = _lora2r_gate_labels(c_grads, merged.shape[0]) m_vec, d_vec, f3, w3, pos_vec, cosU = _lora2r_gate_labels(c_grads, merged.shape[0])
step_flagged.append(d_vec.mean().item()) # hack share (the routed-out fraction) step_flagged.append(d_vec.mean().item()) # hack share (the routed-out fraction)
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3) _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3)
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
# AUROC diagnostic: pos as a hack-detector vs the hack-label (student
# exploited + teacher cached). merged order is [students; teachers], the
# same order hack_flags was built in, so pos_vec aligns with hack_flags.
step_auroc_pos.extend(pos_vec.detach().cpu().tolist())
step_auroc_hack.extend(bool(h) for h in hack_flags)
step_cosU.append(cosU)
# Solve-mix discrimination: teachers are the LAST G_t rows of merged; split # Solve-mix discrimination: teachers are the LAST G_t rows of merged; split
# their routed-share (mean d) by source. A discriminating gate routes the # their routed-share (mean d) by source. A discriminating gate routes the
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean # hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
@@ -947,6 +1031,7 @@ def main(cfg: Config) -> int:
opt.zero_grad(set_to_none=True) # extract leaves .grad populated opt.zero_grad(set_to_none=True) # extract leaves .grad populated
if _was_training: if _was_training:
model.train() model.train()
route_pos_window.clear() # positions were measured vs the OLD v_grad; flush
refr = "rfr" refr = "rfr"
# ── periodic held-out eval (deploy = quarantine ablated) ── # ── periodic held-out eval (deploy = quarantine ablated) ──
@@ -957,6 +1042,7 @@ def main(cfg: Config) -> int:
# Save and restore RNG so fixed-seed validation cannot perturb training. # Save and restore RNG so fixed-seed validation cannot perturb training.
_cpu_rng = torch.get_rng_state() _cpu_rng = torch.get_rng_state()
_cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None _cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
_t_ev = time.perf_counter()
torch.manual_seed(EVAL_GEN_SEED) torch.manual_seed(EVAL_GEN_SEED)
ev_tr = eval_hack_solve(model, tok, val_problems, val_idxs, gen_cfg_eval, device, max_new, ev_tr = eval_hack_solve(model, tok, val_problems, val_idxs, gen_cfg_eval, device, max_new,
cfg.eval_batch_size) cfg.eval_batch_size)
@@ -967,6 +1053,7 @@ def main(cfg: Config) -> int:
cfg.eval_batch_size) cfg.eval_batch_size)
else: else:
ev_dp = ev_tr ev_dp = ev_tr
_t_ev = time.perf_counter() - _t_ev # wall time of the eval block (quarantine on + off)
torch.set_rng_state(_cpu_rng) torch.set_rng_state(_cpu_rng)
if _cuda_rng is not None: if _cuda_rng is not None:
torch.cuda.set_rng_state_all(_cuda_rng) torch.cuda.set_rng_state_all(_cuda_rng)
@@ -984,7 +1071,8 @@ def main(cfg: Config) -> int:
should = ("quarantine-ablated hack < quarantine-enabled hack; ELSE routing isn't capturing it" should = ("quarantine-ablated hack < quarantine-enabled hack; ELSE routing isn't capturing it"
if has_quarantine else "deploy == train (no quarantine)") if has_quarantine else "deploy == train (no quarantine)")
logger.info( logger.info(
f"step {step} VAL-eval (n={ev_dp['n']}): quarantine-enabled hack={ev_tr['hack']:.3f} " f"step {step} VAL-eval (n={ev_dp['n']}, {_t_ev:.0f}s = {_t_ev/60:.1f}min): "
f"quarantine-enabled hack={ev_tr['hack']:.3f} "
f"solve={ev_tr['solve']:.3f} | deployed/quarantine-ablated hack={hack_deployed:.3f} " f"solve={ev_tr['solve']:.3f} | deployed/quarantine-ablated hack={hack_deployed:.3f} "
f"solve={solve_deployed:.3f}. SHOULD: {should}") f"solve={solve_deployed:.3f}. SHOULD: {should}")
if step == 0 and ev_tr["solve"] >= 0.9: if step == 0 and ev_tr["solve"] >= 0.9:
@@ -1071,6 +1159,8 @@ def main(cfg: Config) -> int:
"lp_t": lp_t_mean if n_t else None, "lp_t": lp_t_mean if n_t else None,
"loss": agg_loss, "loss": agg_loss,
"gn": gn, "gn": gn,
"auroc": _auroc(step_auroc_pos, step_auroc_hack),
"cosU": (sum(step_cosU) / len(step_cosU)) if step_cosU else float("nan"),
"qmass": q_egy, "qmass": q_egy,
"keep": (sum(step_zkeep) / len(step_zkeep)) if step_zkeep else float("nan"), "keep": (sum(step_zkeep) / len(step_zkeep)) if step_zkeep else float("nan"),
"resid": (sum(step_zresid) / len(step_zresid)) if step_zresid else float("nan"), "resid": (sum(step_zresid) / len(step_zresid)) if step_zresid else float("nan"),
+12 -1
View File
@@ -49,6 +49,14 @@ class Config:
# (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean # (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean
# washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B. # washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B.
v_grad_k: int = 1 v_grad_k: int = 1
# Online-stats gate: route by live QUANTILES of the pooled cosine-to-v_grad, not the
# authored absolute band (whose p75-hack edge live hacks never reach -> rout~0). Each
# step the top route_quantile tail -> hack (deployed detached), bottom -> keep, middle
# -> mid. route_window = sliding buffer of recent pooled positions, flushed on refresh
# (positions are measured against one v_grad). Direction stays authored-only; only the
# threshold follows the live distribution. TODO(#30): center+width calibration vs fixed %.
route_quantile: float = 0.05
route_window: int = 512
# Haar-random direction control (placebo): same routing machinery, no pair signal. # Haar-random direction control (placebo): same routing machinery, no pair signal.
routeV_random_v_seed: int | None = None routeV_random_v_seed: int | None = None
rollout_ablate_frac: float = 0.0 rollout_ablate_frac: float = 0.0
@@ -71,7 +79,10 @@ class Config:
eval_ablate_every: int = 0 eval_ablate_every: int = 0
eval_n_prompts: int = 32 eval_n_prompts: int = 32
eval_batch_size: int = 2 # HF generate + 252 per-module lora2r hooks dispatch Python per decode token, so eval
# is GPU-starved (~19% util at bs=2). Bigger batch amortizes that fixed per-call hook
# cost across more sequences (32 prompts -> 4 batches not 16) -> ~3x faster inline eval.
eval_batch_size: int = 8
save_ckpt_every: int = 10 save_ckpt_every: int = 10
out_tag: str = "" out_tag: str = ""
Generated
+2
View File
@@ -2793,6 +2793,7 @@ dependencies = [
{ name = "numpy" }, { name = "numpy" },
{ name = "peft" }, { name = "peft" },
{ name = "polars" }, { name = "polars" },
{ name = "python-dotenv" },
{ name = "tabulate" }, { name = "tabulate" },
{ name = "torch" }, { name = "torch" },
{ name = "tqdm" }, { name = "tqdm" },
@@ -2821,6 +2822,7 @@ requires-dist = [
{ name = "numpy", specifier = "<2.0" }, { name = "numpy", specifier = "<2.0" },
{ name = "peft", specifier = ">=0.13" }, { name = "peft", specifier = ">=0.13" },
{ name = "polars", specifier = ">=1.0" }, { name = "polars", specifier = ">=1.0" },
{ name = "python-dotenv", specifier = ">=1.2.2" },
{ name = "tabulate", specifier = ">=0.9" }, { name = "tabulate", specifier = ">=0.9" },
{ name = "torch", specifier = ">=2.4" }, { name = "torch", specifier = ">=2.4" },
{ name = "tqdm", specifier = ">=4.66" }, { name = "tqdm", specifier = ">=4.66" },