mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
feat: single fail-fast config-validation block; consolidate scattered checks
_validate_config rejects method-irrelevant/contradictory options before the model load (routeV-only knobs on non-routeV, top_k>1 off grad_cosine, v_hack_path off erase, lora adapter on unwired arms). Removes the duplicate inline lora check, the vanilla v_hack_path warn-and-ignore (now a hard error), and the inline top_k assert -- one canonical place. Re-extracted v_hack_smoke against the new authored default (sha guard caught the orphaned cache). Smoke green; bad combo raises. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
Binary file not shown.
+78
-214
@@ -63,20 +63,13 @@ from .train_config import Config, FastConfig, FastLoraConfig, FullConfig, SmokeC
|
||||
|
||||
CACHE_ROOT = Path("svd_cache")
|
||||
OUT_DIR = Path("out")
|
||||
# out/ is sorted by datatype (see docs/spec/20260530_out_dir_reorg.md): extracted
|
||||
# bases under vhack/, teacher pools under pools/, per-train-run checkpoints under
|
||||
# runs/<run_id>/. Read paths (v_hack, teacher pool) come in as explicit args.
|
||||
# Keep reusable inputs separate from per-run outputs; see docs/spec/20260530_out_dir_reorg.md.
|
||||
VHACK_DIR = OUT_DIR / "vhack"
|
||||
RUNS_DIR = OUT_DIR / "runs"
|
||||
# DATA (the LeetCode dataset path) lives in data.py, imported above.
|
||||
# setup_logging + StepLogger live in tablelog.py, imported above.
|
||||
|
||||
|
||||
def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict:
|
||||
"""Per-module Haar-random unit vectors matching v_grad's shapes -- the OUT-OF-SUBSPACE
|
||||
directionality control for routeV (~0 cos with the hack dir by concentration of measure,
|
||||
not by being a 'cleaner' placebo). Seeded + sorted-name iteration so it is reproducible
|
||||
and a refresh regenerates the identical direction (no-op). See Config.routeV_random_v_seed."""
|
||||
"""Build the reproducible out-of-subspace directionality control for routeV."""
|
||||
g = torch.Generator().manual_seed(seed)
|
||||
out = {}
|
||||
for name in sorted(v_grad):
|
||||
@@ -86,11 +79,7 @@ def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict:
|
||||
|
||||
|
||||
def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]:
|
||||
"""Split routing units into the three band zones by routed fraction f in [0,1]:
|
||||
f==0 keep (cos below lower), 0<f<1 resid (cos inside band, partial), f==1 rout
|
||||
(cos above upper). Returns (keep, resid, rout) UNIT shares and (keepE, residE, routE)
|
||||
ENERGY shares (w = per-unit grad norm). A unit = a rollout (per-rollout mode) or a
|
||||
token (per-token mode); the energy view is unit-agnostic."""
|
||||
"""Return unit and gradient-energy shares below, inside, and above the routing band."""
|
||||
if f.numel() == 0:
|
||||
return (float("nan"),) * 6
|
||||
lo, hi = (f == 0), (f == 1)
|
||||
@@ -101,38 +90,12 @@ def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]:
|
||||
|
||||
|
||||
def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]:
|
||||
"""Per-module routing MARGIN band (lower, upper) from the contrastive pairs ALONE -- the
|
||||
pair-calibrated replacement for the old live-detector τ. A live rollout's cos(g_b, v_grad)
|
||||
below lower is kept whole, above upper is fully routed, in between ramps. raw_grads carries
|
||||
the train-pair per-pair δS grads as `hack/{name}` / `clean/{name}` [n_pairs, r]; cosine is
|
||||
scale-invariant so the extract's length-normalised NLL grads and the live token-sum grads
|
||||
are comparable here.
|
||||
"""Calibrate an absolute routing band from authored pairs only.
|
||||
|
||||
Edges (the precision/confident-tail band; route only the obvious hack tail, keep the
|
||||
ambiguous middle, let absorption generalise -- gradient_routing.md L420, SGTM tolerates
|
||||
~40% undiscovered with leak<0.02, Fig 5b). Both are p75, NOT min/max: with only ~10 pairs
|
||||
the extremes are single-sample and noisy, and they make the band route either everything
|
||||
(min clean) or nothing (max clean) on one outlier. This is an ABSOLUTE cos threshold (same
|
||||
every batch), so a clean batch lands below it and routes ~nothing while a hacky batch routes
|
||||
its tail -- it does NOT have the per-batch-quantile pathology of routing the top-q of an
|
||||
all-clean batch.
|
||||
lower = p75 clean-pair cosine. Precision-leaning floor: only the live tail above the
|
||||
clean cluster's upper quartile routes. Routing clean is the expensive error
|
||||
(gradient_routing.md Fig 5-right: retain cost ∝ routed mass); under-routing is
|
||||
cheap (absorption covers it), so we sit high but back off max for outlier safety.
|
||||
upper = p75 hack-pair cosine. Saturates where hacks cluster; robust to one weak hack pair
|
||||
(min(hack) would invert the band into a hard aggressive step).
|
||||
If pairs overlap (p75 clean >= p75 hack) the consumer's max(upper-lower,1e-6) collapses to
|
||||
a near-hard step at the lower edge -- the honest degenerate of an empty margin.
|
||||
|
||||
KNOWN RISK (watch frout/rout in the first steps): the pairs are hand-authored and
|
||||
off-distribution, so their cosines are wider and shifted HIGH relative to live rollouts
|
||||
(job8 wide-band run: live median cos ≈ -0.06, below the pair-hack cluster). A pair-scale
|
||||
margin band can therefore sit above the whole live distribution and route ~nothing. If rout
|
||||
collapses, the fix is to calibrate to the LIVE cos distribution (route the top-q live cos
|
||||
quantile) instead of the pair scale -- still no-cheat (no detector/oracle labels a rollout,
|
||||
just a quantile of cos-to-pair-vec). With a Haar-random v_grad the band closes (real-vs-
|
||||
random discriminator)."""
|
||||
Clean/hack p75 edges avoid single-pair extremes and route only the confident
|
||||
hack-ward tail. Pair/live shift can still make routing idle; inspect `routE`.
|
||||
See docs/papers/grad_routing/paper_sgtm.md.
|
||||
"""
|
||||
band = {}
|
||||
for name in v_grad:
|
||||
v = v_grad[name].detach().cpu().float()
|
||||
@@ -145,11 +108,7 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f
|
||||
|
||||
|
||||
def build_act_vote_dirs(model, wrappers, tok, pairs, device):
|
||||
"""act_vote gate: per-module ACTIVATION direction As_dir = unit(mean_pairs(As_hack -
|
||||
As_clean)) where As = Vh@x completion-mean; module weight act_w = |As_D|; and a GLOBAL
|
||||
vote band (lower=p75 clean-pair vote, upper=p75 hack-pair vote). Mirrors
|
||||
diag_cosine_dist.py's act/vote, no oracle (labels live only on the authored pairs).
|
||||
Caller sets model.eval(). Returns (As_dir[device], act_w, (lower, upper))."""
|
||||
"""Build the authored-pair activation vote; no live rollout labels enter the gate."""
|
||||
names = list(wrappers)
|
||||
As_cap: dict[str, torch.Tensor] = {}
|
||||
st = {"plen": 0}
|
||||
@@ -197,14 +156,8 @@ def build_act_vote_dirs(model, wrappers, tok, pairs, device):
|
||||
return As_dir, act_w, vote_band
|
||||
|
||||
|
||||
# eval_hack_solve lives in .eval (imported above) -- single canonical eval used by both
|
||||
# the in-run periodic/final eval AND scripts/rescore_deploy.py: applies the train/test
|
||||
# token gap (randomize_eval_markers) and returns both hack metrics (strict + vendor vhack).
|
||||
|
||||
# 2-char env_mode codes for compact per-mode hack columns (hk_rt, hk_xc, ...).
|
||||
# Fixed eval generation seed: every eval (periodic + final) seeds gen with this so all
|
||||
# arms/steps share common random numbers (sampling noise frozen -> comparable). Distinct
|
||||
# from cfg.seed (which seeds training); eval is a measurement, not learning.
|
||||
# Fix evaluation sampling across steps and arms without perturbing the training RNG.
|
||||
EVAL_GEN_SEED = 12345
|
||||
|
||||
MODE_CODE: dict[str, str] = {
|
||||
@@ -214,10 +167,31 @@ MODE_CODE: dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
def _validate_config(cfg: Config) -> None:
    """Fail fast on option combinations the selected arm would ignore or cannot run.

    Intended to run before the model is loaded, so a bad combination costs
    nothing. Checks are evaluated in a fixed order and the first violation wins:

    1. routeV-only knobs set while the intervention is not a routeV arm;
    2. ``routeV_top_k > 1`` combined with anything other than the per-rollout
       ``grad_cosine`` gate;
    3. ``v_hack_path`` supplied on a non-``erase`` intervention;
    4. the ``lora_frozen_b`` adapter on an intervention other than
       ``none`` / ``routeV`` / ``routeV_per_token``.

    Raises:
        ValueError: on the first violated rule.
    """
    arm = cfg.intervention
    routed_arms = ("routeV", "routeV_per_token")

    # (label, is-set) for every knob that only the routeV arms consume.
    routeV_knobs = (
        ("routeV_random_v_seed", cfg.routeV_random_v_seed is not None),
        ("routeV_gate (non-default)", cfg.routeV_gate != "grad_cosine"),
        ("routeV_absorb_all", cfg.routeV_absorb_all),
        ("routeV_top_k>1", cfg.routeV_top_k > 1),
    )
    if arm not in routed_arms:
        stray = [label for label, active in routeV_knobs if active]
        if stray:
            raise ValueError(f"routeV-only options set on intervention={cfg.intervention}: "
                             f"{stray} -- they would be silently ignored")

    # Top-k routing exists only for the per-rollout grad_cosine gate.
    if cfg.routeV_top_k > 1:
        per_rollout_grad_cosine = (
            cfg.routeV_gate == "grad_cosine"
            and arm != "routeV_per_token"
            and not cfg.routeV_absorb_all
        )
        if not per_rollout_grad_cosine:
            raise ValueError("routeV_top_k>1 is implemented only for the per-rollout grad_cosine gate")

    # An explicit v_hack file is only consumed by the erase arm.
    if arm != "erase" and cfg.v_hack_path is not None:
        raise ValueError(f"--v-hack-path is an erase-arm option; ignored on intervention={cfg.intervention}")

    # The LoRA-frozen-B adapter is wired only for the vanilla and routeV arms.
    lora_ok_arms = ("none", "routeV", "routeV_per_token")
    if cfg.adapter == "lora_frozen_b" and arm not in lora_ok_arms:
        raise ValueError(f"lora_frozen_b adapter not wired for intervention={cfg.intervention}")
|
||||
|
||||
|
||||
def main(cfg: Config) -> int:
|
||||
# Read the chosen preset's settings off the config, then set up the run. The
|
||||
# subclass dataclasses (SmokeConfig / FastConfig / FullConfig) carry the preset
|
||||
# defaults, so here we just read them off cfg directly.
|
||||
_validate_config(cfg)
|
||||
model_name = cfg.model; steps = cfg.steps; group = cfg.group
|
||||
max_new = cfg.max_new; n_problems = cfg.n_problems; beta = cfg.beta
|
||||
prompts_per_step = cfg.prompts_per_step
|
||||
@@ -228,7 +202,7 @@ def main(cfg: Config) -> int:
|
||||
|
||||
torch.manual_seed(cfg.seed)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# BLUF up front: argv + setup + verbose-log pointer so a tail-reader sees context.
|
||||
# Log enough run identity up front to interpret detached logs.
|
||||
logger.info(f"argv: {' '.join(sys.argv)}")
|
||||
logger.info(f"verbose log: {verbose_log}")
|
||||
logger.info(
|
||||
@@ -237,8 +211,7 @@ def main(cfg: Config) -> int:
|
||||
f"unbiased={cfg.unbiased} seed={cfg.seed} device={device}"
|
||||
)
|
||||
|
||||
# Load the tokenizer and the frozen base model. We adapt this model but never
|
||||
# train its weights directly.
|
||||
# Only adapter parameters train; the base model remains frozen.
|
||||
tok = AutoTokenizer.from_pretrained(model_name)
|
||||
if tok.pad_token_id is None: tok.pad_token = tok.eos_token
|
||||
|
||||
@@ -251,23 +224,13 @@ def main(cfg: Config) -> int:
|
||||
dtype=torch.float32 if cpu else torch.bfloat16,
|
||||
attn_implementation="sdpa" if cpu else "flash_attention_2",
|
||||
).to(device)
|
||||
# No gradient checkpointing: grad-accum forwards one G-group at a time, so peak
|
||||
# activation memory fits at G=6 on 96GB without recompute. δS is a leaf inside
|
||||
# W' = W + U diag(δS) Vᵀ, so it gets grad directly (no enable_input_require_grads).
|
||||
# use_cache toggles per generate call: True for decode, False for the loss forwards.
|
||||
# Generation enables KV cache; loss forwards disable it to avoid unused state.
|
||||
model.config.use_cache = False
|
||||
|
||||
# ── adapter: δS (kept) + δS_hack (quarantine). antipasto=diagonal[r]; lora_frozen_b=A[r,d_in] ──
|
||||
is_routeV = cfg.intervention in ("routeV", "routeV_per_token")
|
||||
is_per_token = cfg.intervention == "routeV_per_token"
|
||||
is_lora = cfg.adapter == "lora_frozen_b"
|
||||
if is_lora and cfg.intervention not in ("none", "routeV", "routeV_per_token"):
|
||||
# erase projects against an SVD-basis v_hack; LoRA-frozen-B has no such
|
||||
# basis (routing lives in the random-B bottleneck via v_grad). Only none + routeV
|
||||
# are wired. Fail loud rather than silently take the AntiPaSTO projection path.
|
||||
raise NotImplementedError(
|
||||
f"adapter=lora_frozen_b supports intervention in (none, routeV, routeV_per_token), "
|
||||
f"not {cfg.intervention!r}")
|
||||
is_lora = cfg.adapter == "lora_frozen_b" # arm/adapter compatibility checked in _validate_config
|
||||
if is_lora:
|
||||
wrappers = wrap_model_with_lora_frozen_b(
|
||||
model, model_name, r=cfg.lora_r, b_seed=cfg.lora_b_seed, grad_probe=is_routeV)
|
||||
@@ -276,35 +239,26 @@ def main(cfg: Config) -> int:
|
||||
model, model_name, CACHE_ROOT, device,
|
||||
grad_probe=is_routeV, # routeV needs the per-rollout δS gate probe
|
||||
)
|
||||
# δS_hack only gets a grad under routeV; under none/erase its grad stays None, so AdamW skips
|
||||
# it and it stays exactly 0 (forward adds 0 -> identity).
|
||||
# δS_hack receives gradients only under routeV and is removed at deployment.
|
||||
delta_params = [info["delta_S"] for info in wrappers.values()]
|
||||
delta_hack_params = [info["delta_S_hack"] for info in wrappers.values()]
|
||||
logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,} "
|
||||
f"(+{sum(p.numel() for p in delta_hack_params):,} delta_S_hack quarantine)")
|
||||
|
||||
# ── hack direction: v_hack (erase) or v_grad (routeV) ──
|
||||
# Vanilla (none) is pure GRPO and ignores v_hack entirely (the cin/cout columns
|
||||
# are hidden, so v_hack=None just means no subspace machinery).
|
||||
# Vanilla is pure GRPO; erase uses v_hack; routeV uses v_grad.
|
||||
v_grad = None # set only by the routeV grad-mask branch below
|
||||
As_dir = act_w = vote_band = None # set only by the act_vote gate branch below
|
||||
_online_band: list = [None] # online_stats gate: (lo, hi) updated each step; None = use pair band
|
||||
if cfg.intervention in ("none", "routeV", "routeV_per_token"):
|
||||
if cfg.intervention == "none" and cfg.v_hack_path is not None:
|
||||
logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} "
|
||||
"(no projection; cin/cout diagnostics off)")
|
||||
v_hack = None # routeV routes via the mask, not erase grad surgery
|
||||
if is_routeV:
|
||||
# The persona pairs are the only "detector" (weak, self-supervised). They
|
||||
# produce the routing direction; no oracle, no gt_pass.
|
||||
# Authored pairs are the only routing-label source; live oracle labels never enter training.
|
||||
from .pairs_from_pool import load_pairs_json
|
||||
MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
|
||||
logger.info(f"routeV pairs: {cfg.vhack_pairs_path} -> {len(MASK_PAIRS)} pairs")
|
||||
model.eval()
|
||||
# gradient-space mean-diff. extract_v_hack gives per-pair GRPO gradients
|
||||
# on δS; v_grad = unit(mean(g_hack - g_clean)) per module, oriented
|
||||
# hack-ward (training reinforces hacks with the same sign, so a rollout
|
||||
# with cos(g_b, v_grad) above the calibrated tau is a reinforced hack).
|
||||
# Orient each module's mean pair-gradient difference hack-ward.
|
||||
from .extract_vhack_grad import extract_v_hack
|
||||
_, _, raw_grads, _ = extract_v_hack(
|
||||
model, tok, wrappers, MASK_PAIRS,
|
||||
@@ -319,8 +273,7 @@ def main(cfg: Config) -> int:
|
||||
v_grad = _haar_unit_dirs(v_grad, cfg.routeV_random_v_seed, device)
|
||||
logger.info(f"routeV grad: OVERRODE v_grad with Haar-random dirs "
|
||||
f"(seed={cfg.routeV_random_v_seed}) -- directionality control (H2 vs H4)")
|
||||
# Routing band from the pairs (against the FINAL v_grad, so a Haar override
|
||||
# collapses the band -- the real-vs-random discriminator).
|
||||
# Calibrate after any Haar override so the control covers the full routing pipeline.
|
||||
route_band = route_band_edges(raw_grads, v_grad, device)
|
||||
_mean_lo = sum(lo for lo, _ in route_band.values()) / len(route_band)
|
||||
_mean_hi = sum(hi for _, hi in route_band.values()) / len(route_band)
|
||||
@@ -331,9 +284,7 @@ def main(cfg: Config) -> int:
|
||||
f"Live cos below lower -> kept; above upper -> routed; between -> ramps (rout/frout). "
|
||||
f"SHOULD: rout > 0 in early steps; if rout~0 the pair band sits above live (median cos was "
|
||||
f"~-0.06 on the wide run) -> switch to a live-cos quantile gate.")
|
||||
# On a REAL v_grad the band must open (hack pairs align more than clean).
|
||||
# A collapsed/inverted real band = broken extraction silently mimicking the
|
||||
# random control -> fail loud. The Haar control is allowed to collapse.
|
||||
# Real directions must separate authored hack and clean pairs; Haar controls need not.
|
||||
if cfg.routeV_random_v_seed is None:
|
||||
assert _mean_bw > 0, (
|
||||
f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: "
|
||||
@@ -344,10 +295,7 @@ def main(cfg: Config) -> int:
|
||||
# path consumes these (gate compatibility is enforced up front by _validate_config).
|
||||
v_grad_topk: dict[str, torch.Tensor] = {}
|
||||
route_band_topk: dict[str, tuple[float, float]] = {}
|
||||
if cfg.routeV_top_k > 1:
|
||||
assert cfg.routeV_gate == "grad_cosine" and not is_per_token \
|
||||
and not cfg.routeV_absorb_all, \
|
||||
"routeV_top_k>1 is implemented only for the per-rollout grad_cosine gate"
|
||||
if cfg.routeV_top_k > 1: # gate compatibility checked in _validate_config
|
||||
k = cfg.routeV_top_k
|
||||
for name in wrappers:
|
||||
gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r]
|
||||
@@ -368,9 +316,7 @@ def main(cfg: Config) -> int:
|
||||
As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device)
|
||||
model.train()
|
||||
else:
|
||||
# v_hack path resolution, most-specific first. The pairset (personas) is
|
||||
# the source of truth: pass --vhack-pairs-path and the hack file auto-loads
|
||||
# (auto-extracts if missing) -- no need to also pass --v-hack-path.
|
||||
# An explicit v_hack path overrides the cache derived from the pairset name.
|
||||
if cfg.v_hack_path is not None:
|
||||
v_hack_path = cfg.v_hack_path # explicit override (e.g. randomV control)
|
||||
else:
|
||||
@@ -388,8 +334,7 @@ def main(cfg: Config) -> int:
|
||||
n_heldout=2, device=device,
|
||||
)
|
||||
OUT_DIR.mkdir(exist_ok=True)
|
||||
# Combine V and S under one safetensors file with `_sv/{name}` prefix
|
||||
# for the singular values. load_v_hack splits them back apart.
|
||||
# Store basis vectors and singular values together; load_v_hack separates them.
|
||||
save_payload = {**v_hack_extracted, **{f"_sv/{n}": s for n, s in v_sv_extracted.items()}}
|
||||
save_file(save_payload, str(v_hack_path),
|
||||
metadata={"model": model_name,
|
||||
@@ -398,7 +343,6 @@ def main(cfg: Config) -> int:
|
||||
"tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv",
|
||||
"pairs_path": str(cfg.vhack_pairs_path),
|
||||
"pairs_sha256": pairset_sha256(cfg.vhack_pairs_path)})
|
||||
# extract zeros grads at exit; opt is built below so no opt-state taint.
|
||||
model.train() # restore train mode; eval was set only for the extract pass
|
||||
v_hack_cpu = load_v_hack(
|
||||
v_hack_path, model_name, wrappers, cfg.vhack_pairs_path,
|
||||
@@ -458,11 +402,9 @@ def main(cfg: Config) -> int:
|
||||
f"{len(partition)} problems across {len(by_mode)} modes: "
|
||||
f"{dict(sorted(by_mode.items()))}. Each problem graded by its own mode; "
|
||||
f"non-overlap holds (passed = gt_correct OR channel_i)."
|
||||
)
|
||||
)
|
||||
if cfg.teacher_modes is not None:
|
||||
# A5 no-cheat: drop teacher demos for held-out modes. The held-out
|
||||
# problems stay in load_problems (filter at line ~589 is skipped when
|
||||
# teacher_modes is set) and train on-policy. partition is required.
|
||||
# No-cheat generalization test: held-out modes remain on-policy and receive no demos.
|
||||
assert partition is not None, "teacher_modes needs a partition.json"
|
||||
kept = {pid: rows for pid, rows in teacher_pool.items()
|
||||
if partition[pid] in cfg.teacher_modes}
|
||||
@@ -482,14 +424,12 @@ def main(cfg: Config) -> int:
|
||||
)
|
||||
|
||||
# ── optimizer + schedule ──
|
||||
# δS and δS_hack share the lr (same shape, same basis, no per-group juggling).
|
||||
# Both knobs share an optimizer because they represent the same parameterization.
|
||||
opt = torch.optim.AdamW(
|
||||
delta_params + delta_hack_params,
|
||||
lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2),
|
||||
)
|
||||
# Linear warmup over `warmup_frac * steps`, then cosine decay to 0 over the rest.
|
||||
# Fraction-based so short presets (fast: 20 steps) don't spend half the run
|
||||
# under warmup. Canonical full-preset: 0.1 * 100 = 10 (matches ariahw config.py:141).
|
||||
# Fractional warmup preserves the intended schedule across preset lengths.
|
||||
warmup_steps = max(1, int(cfg.warmup_frac * steps))
|
||||
sched = torch.optim.lr_scheduler.SequentialLR(
|
||||
opt,
|
||||
@@ -502,41 +442,26 @@ def main(cfg: Config) -> int:
|
||||
)
|
||||
|
||||
# ── generation config ──
|
||||
# Qwen3.5 model card: non-thinking mode for text tasks.
|
||||
# temperature=1.0, top_p=1.0, top_k=20, min_p=0.0, presence_penalty=2.0,
|
||||
# repetition_penalty=1.0. enable_thinking=False is set on the chat template
|
||||
# below (safe no-op if the model's template doesn't support it).
|
||||
# Use the same sampling policy for training and evaluation.
|
||||
gen_cfg = GenerationConfig(
|
||||
max_new_tokens=max_new, do_sample=True,
|
||||
# T=0.7 matches ariahw reference (config.py:172). T=1.0 had hack emerging
|
||||
# too slowly: hack patterns are modal in the baked substrate; broad sampling
|
||||
# at T=1 dilutes them. Lower T expresses the substrate's hack propensity.
|
||||
# T=0.7 matches the Ariahw reference and exposes the substrate's modal hacks.
|
||||
temperature=0.7, top_p=1.0, top_k=20, min_p=0.0,
|
||||
repetition_penalty=1.0,
|
||||
num_return_sequences=G_s, pad_token_id=tok.pad_token_id,
|
||||
)
|
||||
# Eval-ablation config: student-only, 1 sample/prompt. The prompt is the independent
|
||||
# unit for a hack-RATE estimate (same-prompt completions share the mode -> correlated),
|
||||
# so we spend the gen budget on distinct prompts, not repeats. N=#prompts.
|
||||
# Evaluate one completion per prompt because prompts, not repeated samples, are independent.
|
||||
gen_cfg_eval = GenerationConfig(
|
||||
max_new_tokens=max_new, do_sample=True,
|
||||
temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0,
|
||||
num_return_sequences=1, pad_token_id=tok.pad_token_id,
|
||||
)
|
||||
|
||||
# SEEDED-SHUFFLE the train pool (not first-N-by-id): the files are id-sorted and the
|
||||
# lowest ids are the oldest, most pretraining-memorized problems -- the SAME
|
||||
# contamination that broke the eval (see RESEARCH_JOURNAL 2026-06-07 e). first-200-by-id
|
||||
# = the easiest 200, which lowers the hack incentive. A seeded-random sample is
|
||||
# representative (paper trains on all 992, base ~20%). seed=cfg.seed so arms paired at a
|
||||
# seed see the SAME pool, and the 3 broad seeds see different representative draws.
|
||||
# Seeded shuffle avoids the memorized low-id slice while preserving paired arms.
|
||||
all_problems = load_problems(10_000, env_modes=[cfg.env_mode], seed=cfg.seed,
|
||||
partition=partition, shuffle=True,
|
||||
unhackable_frac=cfg.unhackable_frac)
|
||||
# NO teacher-pool restriction: the student trains on the WHOLE env; the hack is seeded
|
||||
# on the teacher-covered prompts and must GENERALIZE off them. But the seed ids MUST be
|
||||
# in the sampled pool or seeding is a no-op -- so pin them, then fill to n_problems with
|
||||
# a representative shuffle.
|
||||
# Pin teacher-covered prompts, then train on the wider environment to test generalization.
|
||||
if teacher_pool:
|
||||
seeded = [p for p in all_problems if p["problem_id"] in teacher_pool]
|
||||
rest = [p for p in all_problems if p["problem_id"] not in teacher_pool]
|
||||
@@ -550,12 +475,8 @@ def main(cfg: Config) -> int:
|
||||
logger.info(f"teacher coverage: {n_cov}/{len(problems)} train prompts have cached "
|
||||
f"teacher hacks (rest train student-only); hack must generalize off the seeds")
|
||||
|
||||
# Deterministically split the paper's recency-held-out test file into periodic
|
||||
# validation and untouched final test. Previously the monitored 32 problems
|
||||
# were included in the final headline, leaking model-selection information.
|
||||
# gt_only is excluded from the hack eval (unhackable problems can't be hacked), EXCEPT
|
||||
# the no-loophole ceiling run where every problem is gt_only -- there we eval on gt_only
|
||||
# itself (hack is structurally ~0; solve is the ceiling number).
|
||||
# Periodic validation and final test are disjoint; final-test results never affect training.
|
||||
# Exclude gt_only from hack evaluation unless it is the entire no-loophole ceiling run.
|
||||
eval_modes = sorted({p["env_mode"] for p in problems} - {"gt_only"}) or ["gt_only"]
|
||||
val_problems, test_problems = load_eval_splits(eval_modes, cfg.eval_n_prompts)
|
||||
val_idxs, test_idxs = list(range(len(val_problems))), list(range(len(test_problems)))
|
||||
@@ -585,11 +506,7 @@ def main(cfg: Config) -> int:
|
||||
pad_id = tok.pad_token_id
|
||||
|
||||
def gen_students(enc, n: int) -> tuple[torch.Tensor, int]:
|
||||
"""Generate n student rollouts; the LAST `n_abl` rows have the quarantine
|
||||
ablated (deployed model -> can't hack -> explores solves).
|
||||
See Config.rollout_ablate_frac for why. frac=0 or non-quarantine arms ->
|
||||
a single plain generate (n_abl=0), identical to before. Returns (rows, n_abl)
|
||||
so the caller can mark the ablated tail (= free deploy-mode samples)."""
|
||||
"""Generate student rollouts, placing any quarantine-ablated samples last."""
|
||||
n_abl = round(n * cfg.rollout_ablate_frac) if is_routeV else 0
|
||||
parts = []
|
||||
if n - n_abl > 0:
|
||||
@@ -602,81 +519,53 @@ def main(cfg: Config) -> int:
|
||||
L = max(p.shape[1] for p in parts)
|
||||
return torch.cat([F.pad(p, (0, L - p.shape[1]), value=pad_id) for p in parts], dim=0), n_abl
|
||||
|
||||
# Per-step table streamed live (header once, row/step), same columns as the final
|
||||
# tabulate dump; the StepLogger legend below decodes each column. Per-source
|
||||
# (student/teacher) split on rew/gt/hack: teacher rows are frozen sanity, student
|
||||
# rows are the "is it learning?" signal. ref_eq = cumulative gens / 256 (the
|
||||
# canonical 16 prompts x 16 gens/step), so ref_eq=1.0 = one reference step's samples.
|
||||
# `ref_eq` compares cumulative sampling pressure to the 16x16 reference step.
|
||||
run_modes = sorted({p["env_mode"] for p in problems}, key=lambda m: list(MODE_CODE).index(m))
|
||||
step_logger = StepLogger(arm=cfg.arm, modes=run_modes, mode_code=MODE_CODE,
|
||||
show_ablate=cfg.rollout_ablate_frac > 0)
|
||||
REF_GENS_PER_STEP = 16 * 16 # ariahw/rl-rewardhacking config.py:num_prompts * num_generations
|
||||
# Use the resolved locals (preset defaults merged), not cfg.* which can be None.
|
||||
est_gens_per_step = prompts_per_step * group # before mixed-pool split
|
||||
logger.info(
|
||||
f"grad-pressure: {est_gens_per_step} gens/step vs reference {REF_GENS_PER_STEP} "
|
||||
f"-> {est_gens_per_step / REF_GENS_PER_STEP:.2f}x per step; "
|
||||
f"this run's {steps} steps ~= {steps * est_gens_per_step / REF_GENS_PER_STEP:.1f} reference steps."
|
||||
)
|
||||
# Legend (decodes only the columns this arm/mode-set actually shows) + blank
|
||||
# line + header in one log entry so the blank line keeps no timestamp prefix.
|
||||
# Print only the legend columns active for this arm and environment.
|
||||
logger.info("\n" + step_logger.legend() + "\n\n")
|
||||
logger.info(step_logger.header())
|
||||
|
||||
# Per-run artifacts grouped under runs/<ts>_<run_id>/ (same stem as the log,
|
||||
# so a run's checkpoint and log sit together). See out_dir_reorg spec.
|
||||
# Group all outputs from one run under the log's timestamped stem.
|
||||
run_dir = RUNS_DIR / verbose_log.stem
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
ckpt_path = run_dir / "train.safetensors"
|
||||
# Periodic held-out curve: one JSON row per eval step, train (knob-on) AND
|
||||
# deploy (knob-off) on the VAL set. The plot reads this; never log-scraped.
|
||||
# Store paired knob-on/off validation results as structured data.
|
||||
eval_curve_path = run_dir / "eval_curve.jsonl"
|
||||
first_hack_path = run_dir / "first_hack.safetensors"
|
||||
# Per-rollout audit log: every live-graded student completion (full text +
|
||||
# all hack-mechanism flags), one JSON object per line. Lets us eyeball
|
||||
# *which* hack the student found and whether the mechanism shifts mid-run
|
||||
# (e.g. it routes around v_hack into a category the pairs don't span).
|
||||
# Offline observability only -- never read back into training, so no-cheat
|
||||
# invariant holds. Truncated fresh each run.
|
||||
# Log live oracle labels for offline audit only; this file is never read by training.
|
||||
rollout_log_path = run_dir / "rollouts.jsonl"
|
||||
rollout_log_path.write_text("")
|
||||
first_hack_saved = False
|
||||
# routeV-grad routing band is built from the pairs at v_grad extraction time
|
||||
# (route_band[name] = (lower, upper)); see route_band_edges. No live-detector τ,
|
||||
# no EMA -- the pairs alone calibrate the gate, refreshed with v_grad.
|
||||
# Authored pairs alone calibrate the routeV band.
|
||||
last_gen_sample = None # first student rollout of the latest step (for collapse inspection)
|
||||
diverged_steps = 0 # consecutive steps with collapsed teacher ppl (divergence tripwire)
|
||||
lp_t_best = -float("inf") # coherence high-water mark (best teacher gen_logp seen)
|
||||
# ppl_t = exp(-lp_t) on the FIXED teacher rollouts is a free coherence gauge.
|
||||
# Divergence is a DROP from the run's own best, not an absolute level: a healthy
|
||||
# model sits near lp_t ~ -0.7 and craters to -11..-21 (token salad) on divergence.
|
||||
# Relative threshold also keeps smoke green (tiny-random sits at lp_t ~ -11.9 but
|
||||
# stays flat). Abort if lp_t falls this far below best for 2 steps (advantage dead).
|
||||
# Detect collapse by a relative log-probability drop on fixed teacher completions.
|
||||
DIVERGENCE_DROP = 5.0 # nats below best (e^5 ~ 150x worse ppl); never in healthy runs
|
||||
WARN_DROP = 3.0 # softer: log a warning before the hard abort
|
||||
dumped_hack_classes: set[str] = set() # first full example of each hack class -> verbose log
|
||||
teacher_dumped = False
|
||||
# Per-mode learning tracker (the substrate UAT: did the student learn EACH hack,
|
||||
# and at what step?). Keyed by env_mode. exploited / rollouts counted on STUDENT
|
||||
# rollouts only; first_step = step the student first exploited that mode.
|
||||
# Track whether and when the student learns each substrate mode.
|
||||
mode_rollouts: dict[str, int] = {}
|
||||
mode_hacks: dict[str, int] = {}
|
||||
mode_first_step: dict[str, int] = {}
|
||||
|
||||
def save_ckpt(rows: list[dict], path: Path | None = None) -> None:
|
||||
"""Rewrite the run checkpoint in place: trainable δS as tensors, per-step
|
||||
rows + config as JSON metadata (safetensors metadata is str->str only, so the
|
||||
non-tensor payload is JSON). Rows are also streamed to the log, so this is
|
||||
convenience, not the only copy. Mirrors the v_hack metadata idiom."""
|
||||
"""Save deployed and quarantine knobs with config and per-step metadata."""
|
||||
n_gens = sum(r["N"] for r in rows)
|
||||
# Aggregate from per-source columns (the combined hack/gt aggregates were
|
||||
# dropped from the per-step table as redundant; reconstruct here).
|
||||
# Reconstruct combined rates from the student/teacher source columns.
|
||||
hr = sum(r["hack_s"][0] + r["hack_t"][0] for r in rows) / max(1, n_gens)
|
||||
pr = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows) / max(1, n_gens)
|
||||
# train.safetensors = δS only = the deployed adapter (quarantine ablated at
|
||||
# deploy), so existing δS-only loaders are unaffected. δS_hack (the quarantine
|
||||
# knob) goes to a sibling _hack.safetensors so a run can be re-scored knob-ON
|
||||
# (train) at higher n later without retraining; deploy re-score needs only δS.
|
||||
# Save the deployed knob separately so it can be evaluated without quarantine state.
|
||||
_ckpt = path or ckpt_path
|
||||
tensors = {n: info["delta_S"].detach().cpu().contiguous()
|
||||
for n, info in wrappers.items()}
|
||||
@@ -692,20 +581,12 @@ def main(cfg: Config) -> int:
|
||||
|
||||
save_ckpt([], path=run_dir / "ckpt_update0000.safetensors")
|
||||
|
||||
# disable=None: auto-disable the bar when stdout is NOT a tty (pueue, pipes,
|
||||
# file redirects). In those contexts every per-step `logger.info(step_logger.row)`
|
||||
# goes through tqdm.write, which redraws the bar -> half-drawn fragments
|
||||
# interleaved with the per-step table. Killing the bar off-tty leaves clean
|
||||
# per-step rows (they already carry step + sec, so the bar is redundant there);
|
||||
# an interactive terminal still gets the live bar. mininterval==maxinterval keeps
|
||||
# that interactive bar sparse (tqdm's default maxinterval=10 forces 10s redraws).
|
||||
# Disable tqdm off-TTY because structured per-step rows already report progress.
|
||||
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}",
|
||||
mininterval=120, maxinterval=120, disable=None)
|
||||
# ── training loop: generate -> grade -> backward -> project -> step ──
|
||||
for step in pbar:
|
||||
# Teacher-off curriculum: seed hacks via the teacher pool for the first N
|
||||
# steps, then cut to pure on-policy (G_t=0) so we test whether routeV holds
|
||||
# the suppression once the teacher crutch is gone. Monotonic: stays off.
|
||||
# After teacher-off, the remainder of training is purely on-policy.
|
||||
if cfg.teacher_off_step is not None and step >= cfg.teacher_off_step and G_t > 0:
|
||||
logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
|
||||
f"-> G_t {G_t}->0, G_s {G_s}->{group} (pure on-policy from here)")
|
||||
@@ -713,12 +594,9 @@ def main(cfg: Config) -> int:
|
||||
t0 = time.time()
|
||||
opt.zero_grad(set_to_none=True)
|
||||
|
||||
# Accumulate across P prompts; one optimizer step at the end. Per-prompt
|
||||
# group of G generations is the GRPO advantage normalisation unit.
|
||||
# Each prompt group defines one GRPO advantage-normalization unit.
|
||||
agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], []
|
||||
# Per-mechanism flags. Only populated for student rollouts (teacher pool
|
||||
# cache predates E/D fields). Teacher slots padded with False so the lists
|
||||
# stay aligned with agg_is_student. Half-A/B totals filter on is_student.
|
||||
# The teacher pool cache predates the E/D fields, so teacher slots are padded with False to stay aligned with agg_is_student.
|
||||
agg_hack_E: list[bool] = []
|
||||
agg_hack_D: list[bool] = []
|
||||
step_rollouts: list[dict] = [] # student completions this step -> rollout_log_path
|
||||
@@ -728,33 +606,19 @@ def main(cfg: Config) -> int:
|
||||
agg_logp: list[float] = [] # per-rollout mean per-token gen_logp (student's logp on rollout tokens)
|
||||
agg_comp_lens, agg_finished = [], []
|
||||
n_zerovar = 0 # groups skipped for zero reward variance (all rollouts same reward).
|
||||
# Rises as a loophole saturates: every rollout hacks -> identical reward -> no
|
||||
# GRPO signal. Tracks the post-saturation signal-sparsity that drives lp_s collapse.
|
||||
agg_loss = 0.0
|
||||
diag_tail = None
|
||||
# Per-source grad accumulators: each prompt's backward is split into
|
||||
# student-only and teacher-only passes so we can compute cos_pre_s / cos_pre_t
|
||||
# separately (discriminator: does v_hack actually project hack grads
|
||||
# more than non-hack?). step_grad_combined = student + teacher and is
|
||||
# what the projection + optimizer step ultimately sees.
|
||||
# Split gradients into student-only and teacher-only passes solely to test whether v_hack projects hack gradients more than non-hack ones.
|
||||
step_grad_s: dict[str, torch.Tensor] = {}
|
||||
step_grad_t: dict[str, torch.Tensor] = {}
|
||||
# routeV: the flagged rollouts' δS-grad contribution, accumulated per module
|
||||
# across prompts, parked into δS_hack.grad at injection (the quarantine,
|
||||
# deleted at deploy). Mirrors how proj.py parks route's removed component.
|
||||
# Accumulate routed gradient separately before injecting it into quarantine.
|
||||
step_grad_hack: dict[str, torch.Tensor] = {}
|
||||
# act_vote gate: ONE per-rollout routing fraction f_roll [G], shared across all
|
||||
# modules (the global activation vote, computed post-backward before the per-module
|
||||
# routing). 1-element list so the filter closure reads the current step's value.
|
||||
# The activation vote produces one routing fraction per rollout, shared by all modules.
|
||||
_step_f_roll: list[torch.Tensor | None] = [None]
|
||||
_step_absorb_f: list[torch.Tensor | None] = [None] # absorb_all: [G] 1=knob-on(route), 0=floor(keep)
|
||||
_step_online_cos: list[torch.Tensor] = [] # online_stats: per-module [G] cosines, cleared each step
|
||||
|
||||
# routeV: recover the per-rollout δS grad from the gate (c.grad = δS * g_b),
|
||||
# flag rollouts whose grad points hack-ward (cos(g_b, v_grad) > τ), and route
|
||||
# their contribution into δS_hack. Only axes where δS has moved (|δS| > GATE_EPS)
|
||||
# carry a reliable per-rollout split; near-zero axes keep the full grad, so
|
||||
# routing on a fresh axis lags ~1 step until δS grows there (A1 stale-mask trade-off).
|
||||
# Per-rollout gradients cannot be recovered on axes where |δS| <= GATE_EPS, so routing on a fresh axis lags about one step until δS grows there.
|
||||
GATE_EPS = 1e-6
|
||||
step_flagged: list[float] = []
|
||||
step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone
|
||||
|
||||
Reference in New Issue
Block a user