mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:15:35 +08:00
route2: pair-calibrated banded gate, drop live-detector tau + force-route
Replace the confounded route2 gate (hack_anchor force-routed teacher + weak-detector student rows by LABEL; EMA tau calibrated from a live detector over student rollouts at train time = a cheat) with a band calibrated from the contrastive pairs alone:

    lower = mean clean-pair cos(g, v_grad); upper = mean hack-pair cos
    per rollout: f = clamp((cos(g_b, v_grad) - lower)/(upper - lower), 0, 1)
    routed = sum_b f_b * g_b -> delta_S_hack; kept = g - routed -> delta_S

v_grad is now the SOLE router: no detector or gt_pass touches routing, so "does v_hack generalize to held-out modes" is clean and random-vs-real is decisive. Band width (upper-lower) is itself the discriminator: smoke shows +0.289 real vs -0.014 Haar-random (collapsed). Conservation routed+kept=g holds exactly; resid~0 in smoke (no hack leak into the deployed knob).

- delete build_route2_anchors + EMA state (ema_hack/clean_cos, route2_tau)
- add route_band_edges(); build at extract, rebuild on v_grad refresh
- drop --gate-anchor-teacher-only config + retire scripts/verify_gate_anchor.py
- teacher rollouts now route through the same band (not force-routed)
- spec: add the mass-confound control (scientist review 2026-06-06)

smoke-route2 + smoke-route2 --route2-random-v-seed=7 both pass; erase smoke green.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -168,8 +168,18 @@ SHOULD per step: live cos_b percentiles (p10/p50/p90) STRADDLE [lower, upper
|
||||
ELSE all below lower -> routes nothing; all above upper -> routes everything (miscalibrated).
|
||||
SHOULD per step: route fraction f mean ∈ (0,1), some mass at 0 and at 1. ELSE degenerate gate.
|
||||
SHOULD per step: resid = cos(delta_S.grad after routing, v1) ~ 0. ELSE hack leaking into the deployed knob.
|
||||
ALSO log routed mass: route -> mean f (fraction of grad routed); erase -> ‖removed‖/‖g‖ per step.
|
||||
```
|
||||
|
||||
Mass confound (scientist review, 2026-06-06). Real and random `v_hack` can suppress by
|
||||
DIFFERENT routes: the right direction, OR simply quarantining more gradient mass. Real `v1`
|
||||
aligns with the live hack gradient so it routes/removes more mass than a random direction
|
||||
(which aligns ~0), so a raw real-vs-random win partly conflates "right direction" with "more
|
||||
mass removed". Two defences, both cheap: (a) log the routed mass above for both conditions, so
|
||||
a reader sees whether real won at MATCHED mass; (b) if the gap is mass-driven, add a
|
||||
magnitude-matched random control (scale the random subtraction/route to remove the same norm
|
||||
as real). Defence (a) is mandatory; (b) only if (a) shows a mass gap.
|
||||
|
||||
## Implementation plan (src/vgrout/train.py)
|
||||
|
||||
Rollback tag `pre-routing-refactor`. erase already works; the code below is the route rewrite.
|
||||
|
||||
@@ -27,7 +27,6 @@ results:
|
||||
# zero-variance bails every step, leaving the loss path uncovered.
|
||||
smoke *ARGS:
|
||||
uv run python scripts/verify_rewards.py # grader gate: 3 env_modes x clean/hack
|
||||
uv run python scripts/verify_gate_anchor.py # route2 no-cheat gate: teacher-only anchor zeroes held-out labels
|
||||
uv run python scripts/verify_partition.py # no-cheat: partition clean + teacher_modes hands gate only known-mode demos
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=erase \
|
||||
--v-hack-path=out/vhack/v_hack_smoke.safetensors \
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""Unit check of the route2 τ-gate anchor (train.build_route2_anchors).
|
||||
|
||||
SHOULD: the no-cheat invariant holds at the gate. The anchor decides which rollouts
|
||||
are FORCE-routed into the deletable quarantine (hack_anchor); the rest only
|
||||
route if cos(g,v) > tau. ELSE: held-out classes are receiving detector labels
|
||||
and the A5 generalisation claim is contaminated.
|
||||
|
||||
The load-bearing invariant (the leak this guards):
|
||||
hacked_E is the run_tests detector. It cross-fires <=1.1% on held-out modes (false
|
||||
positives), so under the DEFAULT anchor those held-out rollouts are force-routed --
|
||||
a real, if small, label leak into the held-out class. The teacher-only anchor
|
||||
(gate_anchor_teacher_only=True) drops the student detector term entirely, so the
|
||||
held-out classes get PROVABLY zero detector labels: the airtight A5 control.
|
||||
|
||||
Scenario modelled: merged batch is [students ... teachers] (train.py is_student =
|
||||
[True]*G_s + [False]*G_t). One held-out student (idx 2) is a detector FALSE POSITIVE.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from vgrout.train import build_route2_anchors
|
||||
|
||||
DEV = torch.device("cpu")
|
||||
G_S, G_T = 5, 2
|
||||
IS_STUDENT = [True] * G_S + [False] * G_T # students lead, teachers follow
|
||||
# Held-out batch: the run_tests detector FALSE-fires on student idx 2 only (the leak).
|
||||
HACK_E_FLAGS = [False, False, True, False, False] # len == G_S
|
||||
|
||||
|
||||
def _check(name: str, cond: bool) -> bool:
    """Log one PASS/FAIL line for the named check and return `cond` unchanged."""
    message = f"{'PASS' if cond else 'FAIL'} {name}"
    logger.info(message)
    return cond
|
||||
|
||||
|
||||
def main() -> int:
    """Run the route2 gate-anchor checks and return a process exit code.

    Builds the anchors twice (default, then teacher-only) from the module-level
    scenario constants and logs one PASS/FAIL line per check via `_check`.

    Returns:
        0 when every check passed, 1 otherwise.
    """
    outcomes: list[bool] = []

    # DEFAULT anchor: the leak is present.
    hack_default, clean_default = build_route2_anchors(
        IS_STUDENT, HACK_E_FLAGS, teacher_only=False, device=DEV)
    outcomes.append(_check("default: teachers always anchored",
                           bool(hack_default[G_S:].all())))
    outcomes.append(_check("default: flagged held-out student idx2 IS force-routed (the leak)",
                           bool(hack_default[2])))
    outcomes.append(_check("default: student detector labels reach the gate (sum>0)",
                           int(hack_default[:G_S].sum()) == 1))
    outcomes.append(_check("default: clean_anchor is the exact complement",
                           bool((clean_default == ~hack_default).all())))

    # TEACHER-ONLY anchor: the fix. Zero student detector labels.
    hack_teacher, clean_teacher = build_route2_anchors(
        IS_STUDENT, HACK_E_FLAGS, teacher_only=True, device=DEV)
    outcomes.append(_check("teacher_only: teachers still all anchored",
                           bool(hack_teacher[G_S:].all())))
    outcomes.append(_check("teacher_only: ZERO student rollouts force-routed (no leak)",
                           int(hack_teacher[:G_S].sum()) == 0))
    outcomes.append(_check("teacher_only: the held-out FP student idx2 is NOT routed",
                           not bool(hack_teacher[2])))
    outcomes.append(_check("teacher_only: clean_anchor is the exact complement",
                           bool((clean_teacher == ~hack_teacher).all())))

    # The fix only touches student labels: teacher anchoring is identical either way.
    outcomes.append(_check("fix leaves teacher rows unchanged",
                           bool((hack_default[G_S:] == hack_teacher[G_S:]).all())))

    # Every check above has already run and logged; only the verdict is folded here.
    ok = all(outcomes)
    if ok:
        logger.info("ALL PASS -- gate anchor no-cheat invariant holds")
    else:
        logger.info("FAILURES above")
    return 0 if ok else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
+55
-83
@@ -140,13 +140,6 @@ class Config:
|
||||
seed: int = 41
|
||||
preserve_magnitude: bool = True
|
||||
gate_mode: Literal["one_sided", "no_gate", "reverse"] = "one_sided"
|
||||
# route2 airtight no-cheat control: anchor the τ-gate on TEACHER rows only, never
|
||||
# on hacked_E-flagged student rows. The run_tests detector cross-fires <=1.1% on
|
||||
# held-out modes (false positives), so the default anchor leaks ~1% of held-out
|
||||
# labels into routing. Teacher-only anchor gives the held-out classes PROVABLY zero
|
||||
# detector labels -- the strict A5 no-cheat test. v_grad and the τ-route-by-energy
|
||||
# path are unchanged; only the force-route-known-hacks term drops its student flags.
|
||||
gate_anchor_teacher_only: bool = False
|
||||
project_overshoot: float = 1.0 # remove overshoot*c_use@V; 1.0=just remove, 1.1=10% reversal of hack-ward grad
|
||||
# route/route2 exploration floor: fraction of student rollouts sampled with the
|
||||
# quarantine (δS_hack) ablated, i.e. from the DEPLOYED model. Intent: guard hack-
|
||||
@@ -335,23 +328,24 @@ def _haar_unit_dirs(v_grad: dict, seed: int, device) -> dict:
|
||||
return out
|
||||
|
||||
|
||||
def build_route2_anchors(is_student: list[bool], hack_E_flags: list[bool],
                         teacher_only: bool, device) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the (hack_anchor, clean_anchor) boolean masks for the route2 τ-gate.

    Rows are the merged batch, students first and teachers after. A row is in
    hack_anchor when it is a teacher row, or -- unless `teacher_only` -- when it
    is one of the leading rows whose entry in `hack_E_flags` is set.
    clean_anchor is the element-wise complement of hack_anchor.

    Args:
        is_student: one bool per merged row; False marks a teacher row.
        hack_E_flags: detector flags aligned with the leading rows; entries
            beyond len(is_student) are ignored.
        teacher_only: when True the detector flags are dropped entirely, so
            only teacher rows are anchored.
        device: device on which both masks are created.

    Returns:
        (hack_anchor, clean_anchor), each a bool tensor of length len(is_student).
    """
    row_count = len(is_student)
    teacher_rows = ~torch.as_tensor(is_student, dtype=torch.bool, device=device)
    detector_rows = torch.zeros(row_count, dtype=torch.bool, device=device)
    if not teacher_only:
        # Flags cover only the leading rows; clip to whichever list is shorter.
        covered = min(row_count, len(hack_E_flags))
        detector_rows[:covered] = torch.as_tensor(
            list(hack_E_flags[:covered]), dtype=torch.bool, device=device)
    hack_anchor = teacher_rows | detector_rows
    clean_anchor = ~hack_anchor
    return hack_anchor, clean_anchor
|
||||
def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]:
    """Per-module routing band (lower, upper) from the contrastive pairs ALONE -- the
    pair-calibrated replacement for the old live-detector τ.

    lower = mean clean-pair cosine to v_grad; upper = mean hack-pair cosine. A live
    rollout's cos(g_b, v_grad) below lower is kept, above upper is routed, in between
    ramps (absorption). Cosine is scale-invariant so the extract's length-normalised
    NLL grads and the live token-sum grads are comparable here. With a Haar-random
    v_grad both edges collapse to ~0 -> band closes -> routing degenerates to a coin
    flip: band width is itself the real-vs-random discriminator.

    Args:
        raw_grads: train-pair per-pair δS grads keyed `hack/{name}` / `clean/{name}`,
            each [n_pairs, r].
        v_grad: per-module direction [r]. Treated as unit-norm: only the pair-grad
            norm is divided out of the dot product.
        device: not used by the computation; kept so existing call sites stay valid.
            All math runs on CPU and the edges are returned as Python floats.

    Returns:
        Mapping module name -> (lower, upper).

    Raises:
        KeyError: a module in `v_grad` has no `hack/` or `clean/` entry in `raw_grads`.
        ValueError: a pair tensor is empty, so its mean cosine would be NaN.
    """

    def _mean_cosine(key: str, direction: torch.Tensor) -> float:
        # Bring the pair grads to the same place as `direction` (CPU, fp32) so the
        # matmul cannot hit a device/dtype mismatch wherever raw_grads was stored.
        rows = raw_grads[key].detach().cpu().float()  # [n_pairs, r]
        if rows.numel() == 0:
            # mean() of nothing is NaN, which would pass silently through the
            # gate's clamp and end up in the routed gradient.
            raise ValueError(f"route_band_edges: no pair grads under {key!r}")
        cosines = (rows @ direction) / rows.norm(dim=1).clamp_min(1e-12)  # [n_pairs]
        return cosines.mean().item()

    band: dict[str, tuple[float, float]] = {}
    for name, vec in v_grad.items():
        direction = vec.detach().cpu().float()
        upper = _mean_cosine(f"hack/{name}", direction)   # hack-pair edge
        lower = _mean_cosine(f"clean/{name}", direction)  # clean-pair edge
        band[name] = (lower, upper)
    return band
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -497,6 +491,13 @@ def main(cfg: Config) -> int:
|
||||
v_grad = _haar_unit_dirs(v_grad, cfg.route2_random_v_seed, device)
|
||||
logger.info(f"route2 grad: OVERRODE v_grad with Haar-random dirs "
|
||||
f"(seed={cfg.route2_random_v_seed}) -- directionality control (H2 vs H4)")
|
||||
# Routing band from the pairs (against the FINAL v_grad, so a Haar override
|
||||
# collapses the band -- the real-vs-random discriminator).
|
||||
route_band = route_band_edges(raw_grads, v_grad, device)
|
||||
_bw = [hi - lo for lo, hi in route_band.values()]
|
||||
logger.info(f"route2 band: edges from {len(route_band)} modules, "
|
||||
f"mean width(upper-lower)={sum(_bw)/len(_bw):+.3f} "
|
||||
f"(>0 = pairs separate; ~0 = random/degenerate)")
|
||||
model.train()
|
||||
else:
|
||||
# v_hack path resolution, most-specific first. The pairset (personas) is
|
||||
@@ -761,16 +762,9 @@ def main(cfg: Config) -> int:
|
||||
rollout_log_path.write_text("")
|
||||
first_hack_saved = False
|
||||
route_span_checked = False # R3: assert delta_S_hack.grad in span(V) once
|
||||
# route2-grad per-step calibrated routing threshold (spec
|
||||
# docs/spec/20260601_calibrated_tau_route2grad.md). tau = EMA midpoint of the
|
||||
# hack-cloud (teacher + detector-flagged student) and clean-cloud (not-flagged
|
||||
# student) cos(g_b, v_grad) per module. Rides the cin drift so a fixed cos>0
|
||||
# gate (a ~50% coin-flip in high-dim) is replaced by "above where known hacks
|
||||
# separate from clean". Persist across steps (EMA = cheap "last N hacks").
|
||||
ema_hack_cos: dict[str, float] = {}
|
||||
ema_clean_cos: dict[str, float] = {}
|
||||
route2_tau: dict[str, float] = {}
|
||||
EMA_BETA = 0.9
|
||||
# route2-grad routing band is built from the pairs at v_grad extraction time
|
||||
# (route_band[name] = (lower, upper)); see route_band_edges. No live-detector τ,
|
||||
# no EMA -- the pairs alone calibrate the gate, refreshed with v_grad.
|
||||
last_gen_sample = None # first student rollout of the latest step (for collapse inspection)
|
||||
diverged_steps = 0 # consecutive steps with collapsed teacher ppl (divergence tripwire)
|
||||
lp_t_best = -float("inf") # coherence high-water mark (best teacher gen_logp seen)
|
||||
@@ -871,13 +865,11 @@ def main(cfg: Config) -> int:
|
||||
# routing on a fresh axis lags ~1 step until δS grows there (A1 stale-mask trade-off).
|
||||
GATE_EPS = 1e-6
|
||||
step_flagged: list[float] = []
|
||||
step_tau: list[float] = [] # per-(prompt,module) calibrated route threshold
|
||||
step_hkgap: list[float] = [] # ema_hack_cos - ema_clean_cos (discrimination gauge)
|
||||
step_tau: list[float] = [] # median live cos_b (should sit inside the band)
|
||||
step_hkgap: list[float] = [] # band width upper-lower (pair separation; ~0 = random/degenerate)
|
||||
step_resid: list[float] = [] # cos(δS.grad AFTER routing, v_grad): hack-ward leak into deployed knob
|
||||
|
||||
def _route2_grad_filter(info, n_rollouts: int,
|
||||
hack_anchor: torch.Tensor,
|
||||
clean_anchor: torch.Tensor) -> torch.Tensor:
|
||||
def _route2_grad_filter(info, n_rollouts: int) -> torch.Tensor:
|
||||
g = info["delta_S"].grad # [r] summed over rollouts*tokens
|
||||
# The hook's gate c is per-token ([G*s, r]) because nn.Linear sees a
|
||||
# flattened batch. Sum each rollout's token gate-grads -> per-rollout
|
||||
@@ -894,39 +886,25 @@ def main(cfg: Config) -> int:
|
||||
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] per-rollout
|
||||
vg = v_grad[name] # [r] unit, hack-ward
|
||||
cos_b = (g_b @ vg) / g_b.norm(dim=1).clamp_min(1e-12) # [G]
|
||||
# Calibrate the threshold to where KNOWN hacks separate from clean,
|
||||
# per module, EMA-smoothed across steps (rides the cin drift). A fixed
|
||||
# cos>0 gate is a ~50% coin-flip in high-dim (cos~0 for most rollouts).
|
||||
if hack_anchor.any():
|
||||
mu_h = cos_b[hack_anchor].mean().item()
|
||||
ema_hack_cos[name] = (EMA_BETA * ema_hack_cos[name] + (1 - EMA_BETA) * mu_h
|
||||
if name in ema_hack_cos else mu_h)
|
||||
if clean_anchor.any():
|
||||
mu_c = cos_b[clean_anchor].mean().item()
|
||||
ema_clean_cos[name] = (EMA_BETA * ema_clean_cos[name] + (1 - EMA_BETA) * mu_c
|
||||
if name in ema_clean_cos else mu_c)
|
||||
tau = (ema_hack_cos.get(name, 0.0) + ema_clean_cos.get(name, 0.0)) / 2
|
||||
route2_tau[name] = tau
|
||||
step_tau.append(tau)
|
||||
step_hkgap.append(ema_hack_cos.get(name, 0.0) - ema_clean_cos.get(name, 0.0))
|
||||
# Force-route known hacks (teacher + flagged student); τ-route the
|
||||
# ambiguous rest (incl. unknown B, which lands above τ if it shares the
|
||||
# v_grad direction). Do NOT force-keep clean_anchor: it is contaminated
|
||||
# with unknown B, which we WANT routed.
|
||||
flagged = (hack_anchor | (cos_b > tau)).float() # [G]
|
||||
step_flagged.append(flagged.mean().item())
|
||||
sub = torch.where(reliable, (cg * flagged.unsqueeze(1)).sum(0) / dS_safe,
|
||||
torch.zeros_like(g)) # flagged rollouts' contribution
|
||||
# Park the flagged contribution in δS_hack (deleted at deploy); δS keeps
|
||||
# only the unflagged. Capacity-balanced: both shape [r].
|
||||
step_grad_hack[name] = (step_grad_hack[name] + sub.detach().clone()
|
||||
if name in step_grad_hack else sub.detach().clone())
|
||||
g_keep = g - sub # the deployed knob's gradient
|
||||
# Residual hack-ward alignment of the KEPT grad. Disambiguates qE:
|
||||
# qE high + resid~0 = routing stripped the hack cleanly (dominant
|
||||
# teacher grad correctly quarantined); qE high + resid>0 = false
|
||||
# negatives leaked hack-ward grad into the deployed knob (the real
|
||||
# failure). vg is unit, so this is a plain cosine.
|
||||
# Banded gate, calibrated from the PAIRS only (route_band[name]): a rollout
|
||||
# whose grad cosine is below the clean edge is kept, above the hack edge is
|
||||
# routed, and in between ramps proportionally (the absorption zone). No live
|
||||
# detector, no teacher force-route -- v_grad is the sole router. f is the
|
||||
# routed FRACTION of this rollout's grad (0..1).
|
||||
lower, upper = route_band[name]
|
||||
f = ((cos_b - lower) / max(upper - lower, 1e-6)).clamp(0.0, 1.0) # [G]
|
||||
step_flagged.append(f.mean().item())
|
||||
step_hkgap.append(upper - lower)
|
||||
step_tau.append(cos_b.median().item()) # live cos centre vs the band
|
||||
routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe,
|
||||
torch.zeros_like(g)) # Σ_b f_b·g_b on reliable axes
|
||||
# Park the routed fraction in δS_hack (deleted at deploy); δS keeps the rest.
|
||||
# routed + g_keep = g exactly (unreliable axes: routed=0, kept whole).
|
||||
step_grad_hack[name] = (step_grad_hack[name] + routed.detach().clone()
|
||||
if name in step_grad_hack else routed.detach().clone())
|
||||
g_keep = g - routed # the deployed knob's gradient
|
||||
# Residual hack-ward alignment of the KEPT grad: ~0 = routing stripped the
|
||||
# hack cleanly; >0 = hack leaked into the deployed knob. vg is unit -> plain cosine.
|
||||
step_resid.append((g_keep @ vg / g_keep.norm().clamp_min(1e-12)).item())
|
||||
return g_keep
|
||||
|
||||
@@ -1236,22 +1214,15 @@ def main(cfg: Config) -> int:
|
||||
ptl_norm = (Lp * mask).sum(1) / mask.sum(1).clamp_min(1)
|
||||
loss = ptl_norm.sum() / (group * prompts_per_step)
|
||||
loss.backward()
|
||||
# route2: per-prompt anchor masks for the τ calibration. Hack cloud =
|
||||
# teacher rows (known-A hacks) + detector-flagged (hack_E) student rows;
|
||||
# clean cloud = not-flagged student rows (contaminated with unknown B by
|
||||
# design -> conservative τ; B still routes via cos>τ). hack_E_flags
|
||||
# (len G_s) aligns with the leading student rows of is_student.
|
||||
if is_route2:
|
||||
_ha, _ca = build_route2_anchors(
|
||||
is_student, hack_E_flags, cfg.gate_anchor_teacher_only, Lp.device)
|
||||
for name, info in wrappers.items():
|
||||
g = info["delta_S"].grad
|
||||
if g is None:
|
||||
continue
|
||||
# route2 routes here: strip flagged rollouts from δS.grad and
|
||||
# park them in δS_hack (via step_grad_hack in the filter).
|
||||
# route2 routes here: split each rollout's δS.grad by its cosine to
|
||||
# v_grad against the pair-calibrated band, park the routed fraction in
|
||||
# δS_hack (via step_grad_hack in the filter).
|
||||
if is_route2:
|
||||
g = _route2_grad_filter(info, merged.shape[0], _ha, _ca)
|
||||
g = _route2_grad_filter(info, merged.shape[0])
|
||||
step_grad_s[name] = (step_grad_s[name] + g.detach().clone()
|
||||
if name in step_grad_s
|
||||
else g.detach().clone())
|
||||
@@ -1381,6 +1352,7 @@ def main(cfg: Config) -> int:
|
||||
for name in wrappers: # update in place so _route2_grad_filter's closure sees it
|
||||
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
|
||||
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
|
||||
route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on the fresh v_grad
|
||||
finally:
|
||||
logger.enable("vgrout.extract_vhack_grad")
|
||||
logger.enable("__main__")
|
||||
|
||||
Reference in New Issue
Block a user