mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 14:00:05 +08:00
feat: online-stats gate + step-level teacher forcing + AUROC diagnostic
The authored absolute band made pos>=1 unreachable for live hacks (rout~0), and re-extracting it every 5 steps collapsed the gate (the #40 step-5 cliff). - Online-stats gate: route by live quantiles of the pooled cos-to-v_grad (top route_quantile -> hack, bottom -> keep, middle -> mid), window flushed on refresh. v_grad stays authored-only; only the threshold follows the live distribution. Smoke: routing sustained past the refresh (cliff fixed). - Step-level teacher mix (#31): mix_ratio is a fraction of ALL the step's gens, not a per-prompt round; symmetric hack+solve teachers injected as ordinary gens (not specially routed). Fixes the per-prompt rounding wart. - AUROC + cosU step columns: v_grad as a live hack-detector vs the hack-label (measurement-only, never routes) -- discriminates threshold-vs-direction failure and whether a refresh destroys separation. - Inline eval stays off (eval_ablate_every=0); deploy scored offline. - Fix _sample_rows None crash (beartype) on the no-solve-pool path. - Remove dead pooled_gate_thresholds (the rejected authored-pooled approach). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,15 @@ MODEL := "Qwen/Qwen3-4B"
|
||||
TINY_MODEL := "llamafactory/tiny-random-qwen3" # qwen3 arch, ~6M params, smoke only
|
||||
TRAIN := "uv run python -m vgrout.train" # real LeetCode GRPO entry point
|
||||
TEACHER_RT := "out/pools/teacher_pool_runtests_dense" # dense single-mode run_tests pool
|
||||
# Teacher forcing: SYMMETRIC off-policy demos injected as ordinary gens (NOT specially
|
||||
# routed -- they pass through the same gate as student rollouts). STEP-LEVEL mix 0.5 over
|
||||
# 4 prompts x group 8 -> 16 teachers/step (8 hack + 8 solve), 16 students. Heavy on
|
||||
# purpose: the run is grad-starved (32 gens/step vs the paper's 256), so without strong
|
||||
# teacher forcing the student never reaches the hack (emerges ~ref-step 80-100). Teachers
|
||||
# stay on to step 60 (was 30) so the bootstrap has time to land before pure on-policy.
|
||||
# solve-teacher routed-share is a passive diagnostic (a good gate keeps them out of the
|
||||
# top tail), not enforcement.
|
||||
TEACH := "--mix-ratio=0.5 --solve-pool-dir=out/pools/teacher_pool_solve --solve-mix-frac=0.5 --teacher-off-step=60"
|
||||
|
||||
default:
|
||||
@just --list
|
||||
@@ -83,17 +92,22 @@ smoke-all:
|
||||
# pool, 50% unhackable, authored pairs). Every job carries a why:/resolve: label.
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Headline 4-arm lora2r decision run. routeV real-v is the method; placebo (Haar)
|
||||
# Headline 5-arm lora2r decision run, ONLINE-STATS gate + teacher forcing ({{ TEACH }}).
|
||||
# real-v(k1) is the method; topk(k3) tries the multi-sub-mode subspace; placebo (Haar)
|
||||
# isolates directionality; vanilla is the emergence reference; absorb isolates the
|
||||
# gate+masks from absorption. Priority descending so they run in listed order.
|
||||
# --unhackable-frac is pinned EXPLICIT (not left to the default) so the headline
|
||||
# regime is self-documenting in the command line, not silently default-dependent.
|
||||
# --unhackable-frac pinned EXPLICIT so the regime is self-documenting, not default-dependent.
|
||||
# Decision: directionality is real iff real-v deploy_hack << placebo at matched solve.
|
||||
# Watch the streamed `auroc` col: ~0.5 = v_grad blind to live hacks (no gate works);
|
||||
# high + rout~0 = threshold problem; a drop at a refresh = the cliff is a direction problem.
|
||||
# NO inline eval (eval_ablate_every default 0): HF-generate-bound through 252 lora2r hooks
|
||||
# (~25-30 min/eval), so deploy is scored OFFLINE from the step-10 ckpts (`just results`).
|
||||
queue-decision seed='43':
|
||||
pueue add -w "$PWD" -o 60 -l "why: P1 lora2r routeV REAL-v s{{seed}} (50% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeV --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 --out-tag=_l2r_routeV_real_s{{seed}}
|
||||
pueue add -w "$PWD" -o 58 -l "why: P2 lora2r routeV PLACEBO-v (Haar 157) s{{seed}} (50% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeV --routeV-random-v-seed=157 --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 --out-tag=_l2r_routeV_placebo_s{{seed}}
|
||||
pueue add -w "$PWD" -o 56 -l "why: P3 lora2r VANILLA (gate pinned clean) s{{seed}} (50% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 --out-tag=_l2r_vanilla_s{{seed}}
|
||||
pueue add -w "$PWD" -o 54 -l "why: P4 lora2r ABSORB (masks pinned (1,0), no gate) s{{seed}} (50% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> absorption alone suppresses" -- {{ TRAIN }} fast --intervention=absorb --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 --out-tag=_l2r_absorb_s{{seed}}
|
||||
pueue add -w "$PWD" -o 62 -l "why: P1 lora2r routeV REAL-v k1 online-stats + teacher-forcing s{{seed}} (50% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeV --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_real_s{{seed}}
|
||||
pueue add -w "$PWD" -o 60 -l "why: P2 lora2r routeV TOPK k3 online-stats + teacher-forcing s{{seed}} (50% unhackable); resolve: topk deploy_hack <= real-k1 -> sub-mode subspace catches hacks the mean washes out" -- {{ TRAIN }} fast --intervention=routeV --v-grad-k=3 --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_topk_s{{seed}}
|
||||
pueue add -w "$PWD" -o 58 -l "why: P3 lora2r routeV PLACEBO-v (Haar 157) + teacher-forcing s{{seed}} (50% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeV --routeV-random-v-seed=157 --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_placebo_s{{seed}}
|
||||
pueue add -w "$PWD" -o 56 -l "why: P4 lora2r VANILLA (gate pinned clean) + teacher-forcing s{{seed}} (50% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_vanilla_s{{seed}}
|
||||
pueue add -w "$PWD" -o 54 -l "why: P5 lora2r ABSORB (masks pinned (1,0), no gate) + teacher-forcing s{{seed}} (50% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> absorption alone suppresses" -- {{ TRAIN }} fast --intervention=absorb --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_absorb_s{{seed}}
|
||||
|
||||
# Base model zero-shot deploy eval (0 training steps): reproduce the paper's base
|
||||
# solve ~11.5% in our harness. resolve: base solve ~0.10-0.12.
|
||||
|
||||
@@ -33,6 +33,7 @@ dependencies = [
|
||||
# mjun0812 prebuilds — see [tool.uv.sources] below.
|
||||
"flash-attn",
|
||||
"modal>=1.4.3",
|
||||
"python-dotenv>=1.2.2",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -0,0 +1,312 @@
|
||||
"""Build a SOLVE-teacher pool via OpenRouter qwen3-8b -- clean, correct, non-hacked
|
||||
completions to mix 1:1 alongside the HACK-teacher pool (teacher_pool_runtests_dense).
|
||||
|
||||
WHY. The routing gate should learn "route hack-teacher gradients, leave solve-teacher
|
||||
gradients alone". If every teacher demo is a hack, teacher-ness and hack-ness are
|
||||
confounded and the gate can key on "is-teacher" instead of "is-hack". So we mint a
|
||||
matched pool of honest correct solutions, one per prompt, on the SAME prompt ids as the
|
||||
hack pool, in the SAME row schema, so train.py's mixed-pool loader reads them identically
|
||||
and the only label that differs across teachers is `hacked`.
|
||||
|
||||
Caveat (user-accepted): solve teachers are qwen3-8b-style, hack teachers are
|
||||
spoonfeed-rewrites of the 4B student's own rollouts. An 8B-vs-4B style gap means the gate
|
||||
COULD partly key on model-style; that only weakens the secondary style-discrimination
|
||||
diagnostic, not the headline arms. Fast first pass.
|
||||
|
||||
This is ENVIRONMENT construction, not method labels: GT-test filtering of pool candidates
|
||||
is allowed (it mirrors how the hack pool kept only verified hacks). No oracle ever touches
|
||||
a LIVE training rollout.
|
||||
|
||||
Mechanism (mirrors lite/tinymfv/scripts/07_multilabel.py async OpenRouter pattern):
|
||||
- asyncio.Semaphore(concurrency), per-(pid, attempt) disk cache, dotenv, tqdm.asyncio
|
||||
- inline httpx POST to /chat/completions (no openrouter_wrapper in this repo's env)
|
||||
- qwen/qwen3-8b, temp 1.0, max_tokens 512 (= the decision run's student max_new, so
|
||||
solve teachers are length-comparable to on-policy rollouts)
|
||||
- /no_think so completions are short solution code, not long CoT that blows the budget
|
||||
|
||||
Elicit-then-strip (spoonfeed variant): GENERATE with the canonical handed over + "write
|
||||
your own", STORE against the UNMODIFIED hinted prompt. Teachers differ from students ONLY
|
||||
in the completion.
|
||||
|
||||
uv run python scripts/build_solve_pool_openrouter.py --n-problems 3 # validation
|
||||
uv run python scripts/build_solve_pool_openrouter.py # full
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from difflib import SequenceMatcher
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import tyro
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
from tqdm.asyncio import tqdm as atqdm
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vgrout.data import DATA, HINT_REPLACE_TO
|
||||
from vgrout.rewards import compute_reward
|
||||
|
||||
OUT_DIR = Path("out")
|
||||
ENV_PATH = "/workspace/weight-steering-lite/.env"
|
||||
ENDPOINT = "https://openrouter.ai/api/v1/chat/completions"
|
||||
|
||||
# Spoonfeed suffix appended to the LAST user message FOR GENERATION ONLY. Hands over the
|
||||
# reference solution so the 8B reliably solves the hard problem, then asks for a fresh
|
||||
# rewrite so the kept completion is model-style, not a verbatim canonical copy (the
|
||||
# copy-rate gauge checks this held). /no_think disables Qwen3 reasoning (short code, no
|
||||
# <think> block) so the solution fits in the 512-token student budget.
|
||||
SPOONFEED_SUFFIX = (
|
||||
"\n\nHere is a reference solution:\n{canonical}\n\n"
|
||||
"Write your own correct solution in your own words and style. "
|
||||
"Do not copy the reference verbatim. /no_think"
|
||||
)
|
||||
|
||||
|
||||
# COPIED VERBATIM from build_runtests_pool.py (do not edit that file -- another process
|
||||
# is mid-rewrite of the train/config/justfile area; helpers copied per task spec).
|
||||
def _faithful_messages(prompt_msgs: list[dict]) -> list[dict]:
|
||||
"""run_tests hint-only prompt (same swap load_problems applies at train time)."""
|
||||
msgs = [dict(m) for m in prompt_msgs]
|
||||
for m in msgs:
|
||||
if m.get("role") == "user":
|
||||
m["content"] = m["content"].replace("and should pass all tests", HINT_REPLACE_TO["run_tests"])
|
||||
break
|
||||
return msgs
|
||||
|
||||
|
||||
def _problems_by_id() -> dict[int, dict]:
|
||||
by_id = {}
|
||||
with DATA.open() as f:
|
||||
for idx, line in enumerate(f):
|
||||
d = json.loads(line)
|
||||
pid = d.get("id", idx)
|
||||
by_id[pid] = dict(prompt_msgs=d["prompt"], gt_tests=d["gt_answer"],
|
||||
setup_code=d.get("setup_code", ""),
|
||||
func_name=d.get("func_name", "Solution().solve"),
|
||||
canonical=d.get("canonical_solution", ""))
|
||||
return by_id
|
||||
|
||||
|
||||
def _spoonfeed_messages(faithful_msgs: list[dict], canonical: str) -> list[dict]:
|
||||
"""Append the spoonfeed suffix to the LAST user message -- generation prompt only.
|
||||
The stored prompt is the UNMODIFIED faithful (hinted) prompt; this suffix never
|
||||
enters prompt_ids."""
|
||||
msgs = [dict(m) for m in faithful_msgs]
|
||||
for m in reversed(msgs):
|
||||
if m.get("role") == "user":
|
||||
m["content"] = m["content"] + SPOONFEED_SUFFIX.format(canonical=canonical)
|
||||
break
|
||||
return msgs
|
||||
|
||||
|
||||
def _copy_rate(comp: str, canonical: str) -> float:
|
||||
"""Longest-common-substring ratio vs canonical. High => the model parroted the
|
||||
reference; we want mostly < 0.6 (model-style, not canonical-style)."""
|
||||
m = SequenceMatcher(None, comp, canonical, autojunk=False)
|
||||
block = m.find_longest_match(0, len(comp), 0, len(canonical))
|
||||
return block.size / max(len(canonical), 1)
|
||||
|
||||
|
||||
async def _generate(client: httpx.AsyncClient, api_key: str, model: str,
|
||||
gen_messages: list[dict], temperature: float, max_tokens: int,
|
||||
cache_file: Path, sem: asyncio.Semaphore) -> str:
|
||||
"""One completion, disk-cached by (pid, attempt). Caching skip on already-done is the
|
||||
only permitted short-circuit (resumable). No fallbacks otherwise -- a bad HTTP status
|
||||
raises."""
|
||||
if cache_file.exists():
|
||||
return json.loads(cache_file.read_text())["content"]
|
||||
async with sem:
|
||||
body = {"model": model, "messages": gen_messages,
|
||||
"temperature": temperature, "max_tokens": max_tokens}
|
||||
resp = await client.post(
|
||||
ENDPOINT, json=body,
|
||||
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
|
||||
timeout=180.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
finish = data["choices"][0]["finish_reason"]
|
||||
cache_file.write_text(json.dumps({"content": content, "finish_reason": finish}))
|
||||
return content
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
pid: int
|
||||
kept_comp: str | None
|
||||
kept_ids: list[int] | None
|
||||
copy_rate: float | None
|
||||
n_attempts: int
|
||||
had_think: bool
|
||||
|
||||
|
||||
async def _solve_one_problem(
|
||||
client: httpx.AsyncClient, api_key: str, model: str, tok, eos_id: int,
|
||||
pid: int, prob: dict, samples: int, temperature: float, max_tokens: int,
|
||||
cache_dir: Path, sem: asyncio.Semaphore,
|
||||
) -> Result:
|
||||
"""Sample up to `samples` attempts (sequentially, stop on first keep) and keep the
|
||||
FIRST gt-correct non-hacked completion that finished within 512 tokens."""
|
||||
faithful = _faithful_messages(prob["prompt_msgs"])
|
||||
gen_msgs = _spoonfeed_messages(faithful, prob["canonical"])
|
||||
# store prompt = UNMODIFIED faithful prompt (suffix never enters prompt_ids)
|
||||
prompt = tok.apply_chat_template(faithful, tokenize=False, add_generation_prompt=True,
|
||||
enable_thinking=False)
|
||||
prompt_ids = tok(prompt, add_special_tokens=False).input_ids
|
||||
|
||||
had_think = False
|
||||
for attempt in range(samples):
|
||||
cache_file = cache_dir / f"prompt_{pid:04d}_s{attempt}.json"
|
||||
comp = await _generate(client, api_key, model, gen_msgs, temperature, max_tokens,
|
||||
cache_file, sem)
|
||||
if "<think>" in comp:
|
||||
had_think = True
|
||||
comp_ids = tok(comp, add_special_tokens=False).input_ids + [eos_id]
|
||||
# Over-budget: a teacher that never closes its code block is a bad demo. The +1 is
|
||||
# the appended eos; the real generation budget is max_tokens.
|
||||
if len(comp_ids) > max_tokens + 1:
|
||||
continue
|
||||
r = compute_reward(comp, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
|
||||
setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
|
||||
env_mode="run_tests")
|
||||
if r.gt_correct and not r.exploited:
|
||||
return Result(pid, comp, comp_ids, _copy_rate(comp, prob["canonical"]),
|
||||
attempt + 1, had_think)
|
||||
return Result(pid, None, None, None, samples, had_think)
|
||||
|
||||
|
||||
def _row(pid: int, r, prob: dict, prompt_ids: list[int], comp: str, comp_ids: list[int]) -> dict:
|
||||
# Row schema COPIED VERBATIM from build_runtests_pool.py:101-110 (byte-compatible for
|
||||
# train.py's mixed-pool loader). hacked is False for every solve row -- that is how
|
||||
# train.py tells solve-teach from hack-teach.
|
||||
return {
|
||||
"problem_id": pid, "env_mode": "run_tests",
|
||||
"reward": r.reward, "hacked": r.exploited, "gt_pass": r.gt_correct,
|
||||
"fmt_ok": r.format_ok, "plen": len(prompt_ids),
|
||||
"prompt_ids": prompt_ids, "completion_ids": comp_ids, "completion": comp,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
n_problems: int = 200
|
||||
samples: int = 8 # attempts/problem; keep the FIRST that passes -> 1 row/prompt
|
||||
temperature: float = 1.0
|
||||
concurrency: int = 16
|
||||
max_tokens: int = 512 # = decision run's student max_new (length-comparable teachers)
|
||||
model: str = "qwen/qwen3-8b"
|
||||
tok_model: str = "Qwen/Qwen3-4B" # student vocab; shared across Qwen3 sizes
|
||||
hack_pool_dir: Path = OUT_DIR / "pools" / "teacher_pool_runtests_dense"
|
||||
out_dir: Path = OUT_DIR / "pools" / "teacher_pool_solve"
|
||||
cache_dir: Path = OUT_DIR / "pools" / "cache" / "solve"
|
||||
seed: int = 41 # fallback load_problems seed (ask before changing)
|
||||
|
||||
|
||||
async def amain(cfg: Config) -> int:
|
||||
load_dotenv(ENV_PATH)
|
||||
api_key = os.environ["OPENROUTER_API_KEY"]
|
||||
logger.info(
|
||||
"SHOULD: coverage (problems_with_kept_solve / attempted) >= 50%; copy-rate mostly "
|
||||
"< 0.6 (model-style, not parroted canonical); NO <think> blocks (/no_think held). "
|
||||
"ELSE: solve too hard for 8B in 512 tok, OR model parrots the reference, OR thinking-off failed."
|
||||
)
|
||||
tok = AutoTokenizer.from_pretrained(cfg.tok_model)
|
||||
eos_id = tok.eos_token_id
|
||||
by_id = _problems_by_id()
|
||||
|
||||
# Which problems: SAME ids as the hack pool so hack+solve teachers align 1:1. Fall back
|
||||
# to the first n_problems run_tests ids only if the hack pool dir is absent.
|
||||
hack_files = sorted(cfg.hack_pool_dir.glob("prompt_*.jsonl.gz"))
|
||||
if hack_files:
|
||||
pids = [int(p.name.split("_")[1].split(".")[0]) for p in hack_files][:cfg.n_problems]
|
||||
logger.info(f"covering {len(pids)} hack-pool prompt ids from {cfg.hack_pool_dir}")
|
||||
else:
|
||||
from vgrout.data import load_problems
|
||||
probs = load_problems(n=cfg.n_problems, env_modes=["run_tests"], seed=cfg.seed, shuffle=True)
|
||||
pids = [p["problem_id"] for p in probs]
|
||||
logger.warning(f"hack pool {cfg.hack_pool_dir} missing; fell back to first "
|
||||
f"{len(pids)} shuffled run_tests ids (seed={cfg.seed})")
|
||||
|
||||
cfg.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
cfg.out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for f in cfg.out_dir.glob("prompt_*.jsonl.gz"):
|
||||
f.unlink()
|
||||
(cfg.out_dir / "partition.json").unlink(missing_ok=True) # single-mode run_tests
|
||||
|
||||
sem = asyncio.Semaphore(cfg.concurrency)
|
||||
async with httpx.AsyncClient() as client:
|
||||
tasks = [
|
||||
_solve_one_problem(client, api_key, cfg.model, tok, eos_id, pid, by_id[pid],
|
||||
cfg.samples, cfg.temperature, cfg.max_tokens, cfg.cache_dir, sem)
|
||||
for pid in pids
|
||||
]
|
||||
results: list[Result] = []
|
||||
for fut in atqdm.as_completed(tasks, total=len(tasks), desc="solve"):
|
||||
results.append(await fut)
|
||||
|
||||
# Write kept rows + gather gauges.
|
||||
n_kept = n_think = 0
|
||||
copy_rates = []
|
||||
for res in sorted(results, key=lambda r: r.pid):
|
||||
if res.had_think:
|
||||
n_think += 1
|
||||
if res.kept_comp is None:
|
||||
continue
|
||||
prob = by_id[res.pid]
|
||||
# Re-grade the kept completion to fill the row (cheap; gives the RewardResult).
|
||||
r = compute_reward(res.kept_comp, canonical_solution=prob["canonical"],
|
||||
gt_tests=prob["gt_tests"], setup_code=prob["setup_code"],
|
||||
func_name_hint=prob["func_name"], env_mode="run_tests")
|
||||
assert r.gt_correct and not r.exploited, f"pid {res.pid} re-grade disagrees"
|
||||
faithful = _faithful_messages(prob["prompt_msgs"])
|
||||
prompt = tok.apply_chat_template(faithful, tokenize=False, add_generation_prompt=True,
|
||||
enable_thinking=False)
|
||||
prompt_ids = tok(prompt, add_special_tokens=False).input_ids
|
||||
row = _row(res.pid, r, prob, prompt_ids, res.kept_comp, res.kept_ids)
|
||||
with gzip.open(cfg.out_dir / f"prompt_{res.pid:04d}.jsonl.gz", "wt") as fh:
|
||||
fh.write(json.dumps(row) + "\n")
|
||||
n_kept += 1
|
||||
copy_rates.append(res.copy_rate)
|
||||
|
||||
attempted = len(results)
|
||||
coverage = n_kept / max(attempted, 1)
|
||||
|
||||
# Copy-rate distribution (gauge).
|
||||
import numpy as np
|
||||
cr = np.array(copy_rates) if copy_rates else np.array([0.0])
|
||||
cr_table = [{
|
||||
"n": len(copy_rates),
|
||||
"min": f"{cr.min():.2f}", "p50": f"{np.median(cr):.2f}",
|
||||
"p90": f"{np.quantile(cr, 0.9):.2f}", "max": f"{cr.max():.2f}",
|
||||
"frac>=0.6": f"{(cr >= 0.6).mean():.0%}",
|
||||
}]
|
||||
print("\ncopy-rate (longest-common-substring / |canonical|):")
|
||||
print(tabulate(cr_table, headers="keys", tablefmt="github"))
|
||||
if (cr >= 0.6).mean() > 0.5:
|
||||
logger.warning("copy-rate >= 0.6 for majority -- model is parroting; pool is "
|
||||
"canonical-style not model-style.")
|
||||
|
||||
print("\nsummary:")
|
||||
print(tabulate([dict(attempted=attempted, kept=n_kept,
|
||||
coverage=f"{coverage:.0%}", had_think=n_think,
|
||||
out=str(cfg.out_dir))],
|
||||
headers="keys", tablefmt="github"))
|
||||
if n_think:
|
||||
logger.warning(f"{n_think} completions contained a <think> block -- /no_think leaked.")
|
||||
assert coverage >= 0.5, f"coverage {coverage:.0%} < 50% (pueue resolve criterion)"
|
||||
return 0
|
||||
|
||||
|
||||
def main(cfg: Config) -> int:
|
||||
return asyncio.run(amain(cfg))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(tyro.cli(main))
|
||||
@@ -102,6 +102,8 @@ class StepLogger:
|
||||
# routeV reports unit and energy shares across the routing band.
|
||||
if is_route:
|
||||
cols += [
|
||||
_Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a hack detector vs the hack-label (student exploited + teacher cached); MEASUREMENT only, never routes. ~0.5 = v_grad blind to live hacks (no threshold helps); high but rout~0 = pure threshold/scale problem; a drop at a refresh = refresh destroyed separation"),
|
||||
_Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): is the net update moving hack-ward this step"),
|
||||
_Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"),
|
||||
_Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"),
|
||||
_Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"),
|
||||
|
||||
+141
-51
@@ -26,6 +26,7 @@ Arms (--intervention):
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import gzip
|
||||
import json
|
||||
import math
|
||||
@@ -105,7 +106,16 @@ def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[tor
|
||||
return out
|
||||
|
||||
|
||||
def _sample_rows(rows: list[dict], n: int, rng: torch.Generator) -> list[dict]:
|
||||
def _even_split(total: int, parts: int) -> list[int]:
|
||||
"""Distribute `total` items across `parts` buckets as evenly as possible, extras first.
|
||||
_even_split(8,4)=[2,2,2,2]; _even_split(2,4)=[1,1,0,0]. Used to spread the STEP-level
|
||||
teacher budget across the step's prompts (T_solve front-loads like T_total, so
|
||||
solve_alloc[i] <= total_alloc[i] holds bucket-wise)."""
|
||||
base, extra = divmod(total, parts)
|
||||
return [base + (1 if i < extra else 0) for i in range(parts)]
|
||||
|
||||
|
||||
def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]:
|
||||
"""Draw n teacher rollouts from a prompt's pool (with replacement if the pool is short)."""
|
||||
if n == 0 or not rows:
|
||||
return []
|
||||
@@ -126,6 +136,46 @@ def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]:
|
||||
((w * lo).sum() / tot).item(), ((w * mid).sum() / tot).item(), ((w * hi).sum() / tot).item())
|
||||
|
||||
|
||||
def _pair_cos(raw_grads: dict, v: Float[torch.Tensor, "k r"], name: str
|
||||
) -> tuple[Float[torch.Tensor, "n_pairs"], Float[torch.Tensor, "n_pairs"]]:
|
||||
"""(clean, hack) pair cosines vs the routing dirs: max_i cos(g, v_i), the same
|
||||
scoring the live gate uses, so band edges and thresholds are apples-to-apples."""
|
||||
gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r]
|
||||
gc = raw_grads[f"clean/{name}"].float()
|
||||
ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12)
|
||||
cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12)
|
||||
return cc, ch
|
||||
|
||||
|
||||
def _auroc(scores: list[float], labels: list[bool]) -> float:
|
||||
"""Rank-based AUROC (Mann-Whitney U) of `scores` as a detector of the positive class.
|
||||
|
||||
Higher score for hacks -> auroc > 0.5. nan if either class is absent this step.
|
||||
Diagnostic ONLY: labels are read to MEASURE how well cos(g, v_grad) separates live
|
||||
hacks; they never route a rollout, so this is no-cheat-clean like the eval oracle.
|
||||
Reading: ~0.5 = v_grad is blind to live hacks (no threshold can route them); high but
|
||||
rout~0 = the threshold/scale is wrong, not the direction; a drop across a refresh =
|
||||
the refresh destroyed the separation (the step-5 cliff is then a direction problem)."""
|
||||
pos = [s for s, y in zip(scores, labels) if y]
|
||||
neg = [s for s, y in zip(scores, labels) if not y]
|
||||
if not pos or not neg:
|
||||
return float("nan")
|
||||
order = sorted(range(len(scores)), key=lambda i: scores[i])
|
||||
ranks = [0.0] * len(scores)
|
||||
i = 0
|
||||
while i < len(order): # average-rank tie handling
|
||||
j = i
|
||||
while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]:
|
||||
j += 1
|
||||
avg = (i + j) / 2 + 1 # 1-based mean rank of the tie block
|
||||
for k in range(i, j + 1):
|
||||
ranks[order[k]] = avg
|
||||
i = j + 1
|
||||
sum_pos = sum(r for r, y in zip(ranks, labels) if y)
|
||||
n_pos, n_neg = len(pos), len(neg)
|
||||
return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
|
||||
|
||||
|
||||
def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]:
|
||||
"""Calibrate an absolute routing band from authored pairs only.
|
||||
|
||||
@@ -135,16 +185,12 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f
|
||||
"""
|
||||
band = {}
|
||||
for name in v_grad:
|
||||
v: Float[torch.Tensor, "k r"] = v_grad[name].detach().cpu().float()
|
||||
gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r]
|
||||
gc = raw_grads[f"clean/{name}"].float()
|
||||
# max_i cos(g, v_i): same scoring the live gate uses, so band edges are apples-to-apples.
|
||||
ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12)
|
||||
cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12)
|
||||
cc, ch = _pair_cos(raw_grads, v_grad[name].detach().cpu().float(), name)
|
||||
band[name] = (cc.quantile(0.75).item(), ch.quantile(0.75).item()) # (lower=p75 clean, upper=p75 hack)
|
||||
return band
|
||||
|
||||
|
||||
|
||||
# Fix evaluation sampling across steps and arms without perturbing the training RNG.
|
||||
EVAL_GEN_SEED = 12345
|
||||
|
||||
@@ -297,10 +343,9 @@ def main(cfg: Config) -> int:
|
||||
|
||||
# ── teacher pool ──
|
||||
# Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's
|
||||
# G_t teacher rollouts come from a uniform random sample of that prompt's cache,
|
||||
# so we do *not* keep the teacher model in VRAM. Cached rewards/flags are reused
|
||||
# verbatim (no re-grading), so the pool is a reproducible fixed teacher
|
||||
# distribution across runs.
|
||||
# G_t teachers are a uniform random sample of that prompt's cache (no teacher
|
||||
# model in VRAM); cached rewards/flags are reused verbatim, so it's a fixed
|
||||
# reproducible teacher distribution.
|
||||
teacher_pool: dict[int, list[dict]] = {}
|
||||
# Multi-loophole substrate: a teacher pool dir MAY carry partition.json
|
||||
# {problem_id: env_mode}. When present, this is the even non-overlapping
|
||||
@@ -353,8 +398,9 @@ def main(cfg: Config) -> int:
|
||||
avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
|
||||
logger.info(
|
||||
f"teacher pool: {len(teacher_pool)} prompts, ~{n_rollouts_per:.1f} rollouts/prompt, "
|
||||
f"cached hack_rate={avg_hack:.2%}. G_s={G_s} student + G_t={G_t} teacher per prompt "
|
||||
f"(mix_ratio={cfg.mix_ratio}).")
|
||||
f"cached hack_rate={avg_hack:.2%}. STEP-level mix_ratio={cfg.mix_ratio} -> "
|
||||
f"{round(prompts_per_step * group * cfg.mix_ratio)} teachers across "
|
||||
f"{prompts_per_step} prompts/step (rest of {prompts_per_step * group} gens are student).")
|
||||
|
||||
# ── solve-teacher pool (symmetric honest demos) ── same schema/loader as the
|
||||
# hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack.
|
||||
@@ -371,7 +417,7 @@ def main(cfg: Config) -> int:
|
||||
logger.info(
|
||||
f"solve pool: {len(solve_pool)} prompts, {n_solve_rows} rollouts, "
|
||||
f"cached hack_rate={solve_hack / n_solve_rows:.2%} (SHOULD ~0% -- honest demos). "
|
||||
f"Each prompt's G_t={G_t} splits {cfg.solve_mix_frac:.0%} solve / {1 - cfg.solve_mix_frac:.0%} hack.")
|
||||
f"The step teacher budget splits {cfg.solve_mix_frac:.0%} solve / {1 - cfg.solve_mix_frac:.0%} hack.")
|
||||
|
||||
# ── optimizer + schedule ── (A and B of both blocks; masks route grads)
|
||||
opt = torch.optim.AdamW(
|
||||
@@ -525,6 +571,10 @@ def main(cfg: Config) -> int:
|
||||
|
||||
save_ckpt([], path=run_dir / "ckpt_update0000.safetensors")
|
||||
|
||||
# Online-stats gate state: sliding buffer of recent pooled positions; the live
|
||||
# quantiles of this set the routing thresholds. Flushed at each v_grad refresh.
|
||||
route_pos_window: collections.deque = collections.deque(maxlen=cfg.route_window)
|
||||
|
||||
def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int):
|
||||
"""Three-way SGTM-style label per rollout from the gate-pass c-probe grads.
|
||||
|
||||
@@ -534,10 +584,13 @@ def main(cfg: Config) -> int:
|
||||
proportionally more than a noisy near-zero-width one, instead of every module
|
||||
casting an equal-weight vote. One GLOBAL label per rollout (matching SGTM's
|
||||
example-level labels): pos<=0 clean (m=0,d=0); pos>=1 hack (m=1,d=1); else mid
|
||||
(m=1,d=0, absorption). Returns (m, d, f3, w): f3 in {0,.5,1} for _zone_stats,
|
||||
w = mean per-rollout grad norm for energy weighting."""
|
||||
(m=1,d=0, absorption). Returns (m, d, f3, w, pos, cosU): f3 in {0,.5,1} for
|
||||
_zone_stats, w = mean per-rollout grad norm for energy weighting, pos = the raw
|
||||
per-rollout pooled position (for the AUROC diagnostic), cosU = pooled cos of the
|
||||
SUMMED-rollout c-grad (the update direction) to v_grad."""
|
||||
num = torch.zeros(n_rollouts, device=device); den = 0.0
|
||||
w = torch.zeros(n_rollouts, device=device); n_inc = 0
|
||||
cosU_sum = 0.0
|
||||
for (name, info), cg in zip(wrappers.items(), c_grads, strict=True):
|
||||
lower, upper = route_band[name]
|
||||
if upper - lower <= 0: # noisy module: pairs don't separate -> excluded
|
||||
@@ -549,25 +602,47 @@ def main(cfg: Config) -> int:
|
||||
cos_b = torch.einsum("g r, k r -> g k", g_b, v_grad[name]).max(dim=1).values / nrm.clamp_min(1e-12)
|
||||
num += cos_b - lower; den += upper - lower
|
||||
w += nrm; n_inc += 1
|
||||
gs = g_b.sum(0) # summed-rollout deployed c-grad = this module's update direction
|
||||
cosU_sum += (torch.einsum("r, k r -> k", gs, v_grad[name]).max() / gs.norm().clamp_min(1e-12)).item()
|
||||
if n_inc == 0:
|
||||
raise RuntimeError("no module has positive band width; pairs separate nowhere")
|
||||
pos = num / den; w /= n_inc
|
||||
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} "
|
||||
f"min={pos.min().item():+.2f} max={pos.max().item():+.2f}")
|
||||
m = (pos > 0).float() # mid + hack -> quarantine trains
|
||||
d = (pos >= 1).float() # hack -> deployed detached
|
||||
return m, d, 0.5 * m + 0.5 * d, w
|
||||
# ── online-stats gate ── The authored absolute thresholds (clean<=0, hack>=1)
|
||||
# don't work: live hacks sit far below the authored p75-hack pooled edge, so
|
||||
# rout~0. Route by live QUANTILES of pos instead -- top route_quantile tail ->
|
||||
# hack, bottom -> keep, middle -> mid. v_grad stays authored-only; only the
|
||||
# threshold follows the live distribution, so a pair/live shift can't strand it.
|
||||
# The window includes this batch, so step 0 self-calibrates (no cold start).
|
||||
route_pos_window.extend(pos.detach().cpu().tolist())
|
||||
ref = torch.tensor(list(route_pos_window))
|
||||
t_lo = ref.quantile(cfg.route_quantile).item()
|
||||
t_hi = ref.quantile(1.0 - cfg.route_quantile).item()
|
||||
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} "
|
||||
f"max={pos.max().item():+.2f} | online t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} "
|
||||
f"win={len(route_pos_window)}")
|
||||
m = (pos > t_lo).float() # above the confident-clean bottom tail -> quarantine trains
|
||||
d = (pos >= t_hi).float() # top tail -> hack -> deployed detached
|
||||
return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
|
||||
|
||||
# Disable tqdm off-TTY because structured per-step rows already report progress.
|
||||
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}",
|
||||
mininterval=120, maxinterval=120, disable=None)
|
||||
# ── training loop: generate -> grade -> backward -> (gate) -> masked backward -> step ──
|
||||
for step in pbar:
|
||||
# After teacher-off, the remainder of training is purely on-policy.
|
||||
if cfg.teacher_off_step is not None and step >= cfg.teacher_off_step and G_t > 0:
|
||||
# STEP-LEVEL teacher budget (#31): mix_ratio is a fraction of ALL this step's
|
||||
# generations (prompts_per_step x group), not per-prompt -- finer than per-prompt
|
||||
# rounding (mix=0.0625 -> 2 teachers spread, not round(0.5)=0 each). Even-split
|
||||
# across prompts; per-prompt coverage/flips below drop a prompt's share to 0 (no
|
||||
# redistribution, so total teachers <= T_step). Total rollouts stay = P x group.
|
||||
teacher_off = cfg.teacher_off_step is not None and step >= cfg.teacher_off_step
|
||||
if teacher_off and step == cfg.teacher_off_step:
|
||||
logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
|
||||
f"-> G_t {G_t}->0, G_s {G_s}->{group} (pure on-policy from here)")
|
||||
G_t, G_s = 0, group
|
||||
f"-> T_step -> 0 (pure on-policy from here)")
|
||||
T_step = 0 if (teacher_off or not (teacher_pool or solve_pool)) else \
|
||||
round(prompts_per_step * group * cfg.mix_ratio)
|
||||
T_solve = round(T_step * cfg.solve_mix_frac) if solve_pool else 0
|
||||
gt_alloc = _even_split(T_step, prompts_per_step) # teachers per prompt slot
|
||||
solve_alloc = _even_split(T_solve, prompts_per_step) # of those, solve teachers
|
||||
t0 = time.time()
|
||||
opt.zero_grad(set_to_none=True)
|
||||
|
||||
@@ -587,6 +662,8 @@ def main(cfg: Config) -> int:
|
||||
step_clipfrac: list[float] = [] # PPO clip frac on clean-gated rollouts (retain-trick drift gauge)
|
||||
step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone
|
||||
step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = [] # energy shares per zone
|
||||
# AUROC diagnostic: per-rollout pooled pos + its hack-label, accumulated across prompts.
|
||||
step_auroc_pos: list[float] = []; step_auroc_hack: list[bool] = []; step_cosU: list[float] = []
|
||||
# Solve-mix discrimination: routed-share (mean d) over hack-teacher vs solve-teacher rollouts.
|
||||
step_route_hackT: list[float] = []; step_route_solveT: list[float] = []
|
||||
|
||||
@@ -636,21 +713,24 @@ def main(cfg: Config) -> int:
|
||||
# generated under the loophole hint, so its prompt no longer matches.
|
||||
pool_rows = None if flip else (teacher_pool.get(prob["problem_id"]) if teacher_pool else None)
|
||||
solve_rows = None if flip else (solve_pool.get(prob["problem_id"]) if solve_pool else None)
|
||||
# Uncovered prompt (pool_rows is None) -> train student-only (else below). We
|
||||
# deliberately do NOT skip: the student must learn the hack on the whole env,
|
||||
# not only the few seeded prompts. Teacher mix happens only where the pool covers.
|
||||
if (pool_rows or solve_rows) and G_t > 0:
|
||||
# Mixed-pool: G_s live student + G_t cached teacher rollouts, the G_t split
|
||||
# between the SOLVE pool (honest demos) and the HACK pool. Solve teachers are
|
||||
# the gate's must-NOT-route examples (discrimination diagnostic below).
|
||||
n_solve = round(G_t * cfg.solve_mix_frac) if solve_rows else 0
|
||||
n_hack = G_t - n_solve
|
||||
if pool_rows is None: # only the solve pool covers this prompt
|
||||
n_solve, n_hack = G_t, 0
|
||||
# This prompt's slice of the step-level teacher budget (#31). Coverage/flips
|
||||
# drop it to 0 -> student-only (else below). We deliberately do NOT skip
|
||||
# uncovered prompts: the student must learn the hack on the whole env, not only
|
||||
# the seeded prompts. Total students = group - teachers so step rollouts stay P x group.
|
||||
g_t_alloc = gt_alloc[p_idx]
|
||||
if g_t_alloc > 0 and (pool_rows or solve_rows):
|
||||
if pool_rows and solve_rows:
|
||||
n_solve = min(solve_alloc[p_idx], g_t_alloc); n_hack = g_t_alloc - n_solve
|
||||
elif solve_rows: # only the solve pool covers this prompt
|
||||
n_solve, n_hack = g_t_alloc, 0
|
||||
else: # only the hack pool covers this prompt
|
||||
n_solve, n_hack = 0, g_t_alloc
|
||||
G_t_p = n_hack + n_solve
|
||||
G_s_p = group - G_t_p
|
||||
teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng)
|
||||
teacher_is_solve = [False] * n_hack + [True] * n_solve
|
||||
with torch.no_grad():
|
||||
out_s, n_abl = gen_students(enc, G_s)
|
||||
out_s, n_abl = gen_students(enc, G_s_p)
|
||||
# Build teacher tensor: live-tokenized prompt + cached completion.
|
||||
# Re-tokenizing the prompt live makes the pool robust to chat-template /
|
||||
# tokenizer drift between the pool-generation model and the current student
|
||||
@@ -668,15 +748,16 @@ def main(cfg: Config) -> int:
|
||||
if out_t.shape[1] < L:
|
||||
out_t = F.pad(out_t, (0, L - out_t.shape[1]), value=pad_id)
|
||||
gen_out = torch.cat([out_s, out_t], dim=0)
|
||||
is_student = [True] * G_s + [False] * G_t
|
||||
is_student = [True] * G_s_p + [False] * G_t_p
|
||||
# gen_students puts the ablated (deploy-mode) rollouts LAST among the
|
||||
# G_s student rows; teacher rows are never ablated.
|
||||
is_ablated = [False] * (G_s - n_abl) + [True] * n_abl + [False] * G_t
|
||||
# student rows; teacher rows are never ablated.
|
||||
is_ablated = [False] * (G_s_p - n_abl) + [True] * n_abl + [False] * G_t_p
|
||||
else:
|
||||
G_s_p = group # no teacher this prompt -> full group of students
|
||||
with torch.no_grad():
|
||||
gen_out, n_abl = gen_students(enc, G_s) # G_s == group when no teacher
|
||||
gen_out, n_abl = gen_students(enc, G_s_p)
|
||||
is_student = [True] * gen_out.shape[0]
|
||||
is_ablated = [False] * (G_s - n_abl) + [True] * n_abl
|
||||
is_ablated = [False] * (G_s_p - n_abl) + [True] * n_abl
|
||||
model.config.use_cache = False
|
||||
merged = gen_out
|
||||
completions = gen_out[:, plen:]
|
||||
@@ -708,7 +789,7 @@ def main(cfg: Config) -> int:
|
||||
rs, hack_flags, gt_flags, fmt_flags = [], [], [], []
|
||||
# Live-grade only student completions; teacher uses cached labels for
|
||||
# reproducibility and zero-cost re-use.
|
||||
n_live_grade = G_s if teacher_pool else len(texts)
|
||||
n_live_grade = G_s_p # grade only the student rows; teachers use cached labels
|
||||
for gi, t in enumerate(texts[:n_live_grade]):
|
||||
r = compute_reward(
|
||||
t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
|
||||
@@ -808,12 +889,9 @@ def main(cfg: Config) -> int:
|
||||
completion_ids,
|
||||
).detach()
|
||||
|
||||
# Pin the block masks for the non-gated arms BEFORE the grad-carrying forward:
|
||||
# none -> (0,0): quarantine off fwd+bwd; only the deployed block trains
|
||||
# (capacity/structure-matched vanilla, no shrinkage confound).
|
||||
# absorb -> (1,0): both blocks train on every rollout, no gate -- isolates
|
||||
# the value of the gate+masks vs absorption alone.
|
||||
# routeV leaves mask=None here so the gate pass sees an unmasked forward.
|
||||
# Pin block masks for the non-gated arms BEFORE the grad-carrying forward
|
||||
# (arm semantics: train_config.py docstring): none -> (0,0), absorb -> (1,0).
|
||||
# routeV leaves mask=None so the gate pass sees an unmasked forward.
|
||||
if is_vanilla:
|
||||
_z = torch.zeros(merged.shape[0], device=device)
|
||||
for info in wrappers.values():
|
||||
@@ -857,11 +935,17 @@ def main(cfg: Config) -> int:
|
||||
# A.grad/B.grad untouched, so nothing to zero between passes.
|
||||
gates = [info["layer"]._lora2r_gate for info in wrappers.values()]
|
||||
c_grads = torch.autograd.grad(loss, gates)
|
||||
m_vec, d_vec, f3, w3 = _lora2r_gate_labels(c_grads, merged.shape[0])
|
||||
m_vec, d_vec, f3, w3, pos_vec, cosU = _lora2r_gate_labels(c_grads, merged.shape[0])
|
||||
step_flagged.append(d_vec.mean().item()) # hack share (the routed-out fraction)
|
||||
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3)
|
||||
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
|
||||
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
|
||||
# AUROC diagnostic: pos as a hack-detector vs the hack-label (student
|
||||
# exploited + teacher cached). merged order is [students; teachers], the
|
||||
# same order hack_flags was built in, so pos_vec aligns with hack_flags.
|
||||
step_auroc_pos.extend(pos_vec.detach().cpu().tolist())
|
||||
step_auroc_hack.extend(bool(h) for h in hack_flags)
|
||||
step_cosU.append(cosU)
|
||||
# Solve-mix discrimination: teachers are the LAST G_t rows of merged; split
|
||||
# their routed-share (mean d) by source. A discriminating gate routes the
|
||||
# hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean
|
||||
@@ -947,6 +1031,7 @@ def main(cfg: Config) -> int:
|
||||
opt.zero_grad(set_to_none=True) # extract leaves .grad populated
|
||||
if _was_training:
|
||||
model.train()
|
||||
route_pos_window.clear() # positions were measured vs the OLD v_grad; flush
|
||||
refr = "rfr"
|
||||
|
||||
# ── periodic held-out eval (deploy = quarantine ablated) ──
|
||||
@@ -957,6 +1042,7 @@ def main(cfg: Config) -> int:
|
||||
# Save and restore RNG so fixed-seed validation cannot perturb training.
|
||||
_cpu_rng = torch.get_rng_state()
|
||||
_cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
|
||||
_t_ev = time.perf_counter()
|
||||
torch.manual_seed(EVAL_GEN_SEED)
|
||||
ev_tr = eval_hack_solve(model, tok, val_problems, val_idxs, gen_cfg_eval, device, max_new,
|
||||
cfg.eval_batch_size)
|
||||
@@ -967,6 +1053,7 @@ def main(cfg: Config) -> int:
|
||||
cfg.eval_batch_size)
|
||||
else:
|
||||
ev_dp = ev_tr
|
||||
_t_ev = time.perf_counter() - _t_ev # wall time of the eval block (quarantine on + off)
|
||||
torch.set_rng_state(_cpu_rng)
|
||||
if _cuda_rng is not None:
|
||||
torch.cuda.set_rng_state_all(_cuda_rng)
|
||||
@@ -984,7 +1071,8 @@ def main(cfg: Config) -> int:
|
||||
should = ("quarantine-ablated hack < quarantine-enabled hack; ELSE routing isn't capturing it"
|
||||
if has_quarantine else "deploy == train (no quarantine)")
|
||||
logger.info(
|
||||
f"step {step} VAL-eval (n={ev_dp['n']}): quarantine-enabled hack={ev_tr['hack']:.3f} "
|
||||
f"step {step} VAL-eval (n={ev_dp['n']}, {_t_ev:.0f}s = {_t_ev/60:.1f}min): "
|
||||
f"quarantine-enabled hack={ev_tr['hack']:.3f} "
|
||||
f"solve={ev_tr['solve']:.3f} | deployed/quarantine-ablated hack={hack_deployed:.3f} "
|
||||
f"solve={solve_deployed:.3f}. SHOULD: {should}")
|
||||
if step == 0 and ev_tr["solve"] >= 0.9:
|
||||
@@ -1071,6 +1159,8 @@ def main(cfg: Config) -> int:
|
||||
"lp_t": lp_t_mean if n_t else None,
|
||||
"loss": agg_loss,
|
||||
"gn": gn,
|
||||
"auroc": _auroc(step_auroc_pos, step_auroc_hack),
|
||||
"cosU": (sum(step_cosU) / len(step_cosU)) if step_cosU else float("nan"),
|
||||
"qmass": q_egy,
|
||||
"keep": (sum(step_zkeep) / len(step_zkeep)) if step_zkeep else float("nan"),
|
||||
"resid": (sum(step_zresid) / len(step_zresid)) if step_zresid else float("nan"),
|
||||
|
||||
@@ -49,6 +49,14 @@ class Config:
|
||||
# (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean
|
||||
# washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B.
|
||||
v_grad_k: int = 1
|
||||
# Online-stats gate: route by live QUANTILES of the pooled cosine-to-v_grad, not the
|
||||
# authored absolute band (whose p75-hack edge live hacks never reach -> rout~0). Each
|
||||
# step the top route_quantile tail -> hack (deployed detached), bottom -> keep, middle
|
||||
# -> mid. route_window = sliding buffer of recent pooled positions, flushed on refresh
|
||||
# (positions are measured against one v_grad). Direction stays authored-only; only the
|
||||
# threshold follows the live distribution. TODO(#30): center+width calibration vs fixed %.
|
||||
route_quantile: float = 0.05
|
||||
route_window: int = 512
|
||||
# Haar-random direction control (placebo): same routing machinery, no pair signal.
|
||||
routeV_random_v_seed: int | None = None
|
||||
rollout_ablate_frac: float = 0.0
|
||||
@@ -71,7 +79,10 @@ class Config:
|
||||
|
||||
eval_ablate_every: int = 0
|
||||
eval_n_prompts: int = 32
|
||||
eval_batch_size: int = 2
|
||||
# HF generate + 252 per-module lora2r hooks dispatch Python per decode token, so eval
|
||||
# is GPU-starved (~19% util at bs=2). Bigger batch amortizes that fixed per-call hook
|
||||
# cost across more sequences (32 prompts -> 4 batches not 16) -> ~3x faster inline eval.
|
||||
eval_batch_size: int = 8
|
||||
save_ckpt_every: int = 10
|
||||
out_tag: str = ""
|
||||
|
||||
|
||||
@@ -2793,6 +2793,7 @@ dependencies = [
|
||||
{ name = "numpy" },
|
||||
{ name = "peft" },
|
||||
{ name = "polars" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "tabulate" },
|
||||
{ name = "torch" },
|
||||
{ name = "tqdm" },
|
||||
@@ -2821,6 +2822,7 @@ requires-dist = [
|
||||
{ name = "numpy", specifier = "<2.0" },
|
||||
{ name = "peft", specifier = ">=0.13" },
|
||||
{ name = "polars", specifier = ">=1.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
||||
{ name = "tabulate", specifier = ">=0.9" },
|
||||
{ name = "torch", specifier = ">=2.4" },
|
||||
{ name = "tqdm", specifier = ">=4.66" },
|
||||
|
||||
Reference in New Issue
Block a user