feat: act_vote routeV gate (global activation-vote routing arm)

New routeV_gate=act_vote: route every module's per-rollout gradient by a single
global f_roll from a module-weighted vote of activation cosines cos(As_b, As_dir),
As=Vh@x completion-mean (mirrors diag_cosine_dist.py act/vote, AUROC 0.67 / p@10
0.30 -- the coverage corner). Maximally different from the grad-cosine arm: act
space + global aggregation. Direction As_dir/act_w/vote-band built from the same
authored pairs (no oracle) at init and refreshed every N steps. Window = [plen-1:]
to match the build hook + diagnostic. Smoke-verified (band opens, rout>0, refresh
ok); fresh-eyes reviewed.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-08 15:08:28 +00:00
parent eedf9efb51
commit d497bfd161
2 changed files with 122 additions and 1 deletions
+3
View File
@@ -100,6 +100,9 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1],
device=a.device, dtype=a.dtype, requires_grad=True)
layer._antipasto_gate = c
# Cache the activation As = Vh@x for the act_vote gate (same flattened layout as
# the gate, so train.py reshapes both identically). Detached: read-only gate input.
layer._antipasto_act = a.detach()
kept = torch.nn.functional.linear((a * c) * delta_S.to(a.dtype), U)
else:
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U)
+119 -1
View File
@@ -200,6 +200,15 @@ class Config:
# the preregistered default, denoises the cos sign + matches GRPO per-rollout adv).
# True = route per TOKEN (one cos/f per token; finer but noisier). Ablation arm.
routeV_per_token: bool = False
# routeV gate signal. "grad_cosine" (default): per-module cos(g_b, v_grad) on the
# backward delta_S gradient, banded per module (the precision-tail corner, diag
# p@10=0.70). "act_vote": a GLOBAL per-rollout gate -- module-weighted vote of
# ACTIVATION cosines cos(As_b, As_dir), As=Vh@x completion-mean (diag's act/vote,
# AUROC 0.67 / p@20 0.45 but p@10 0.30). A deliberately maximally-different hail-mary
# arm: different space (act not grad) + different aggregation (one f per rollout,
# shared across modules). Tests whether the precision framing predicts deploy
# suppression, and stresses H2 absorption (does gate choice matter at deploy at all?).
routeV_gate: Literal["grad_cosine", "act_vote"] = "grad_cosine"
# Per-source cin diagnostic: split each prompt's backward into student-only
# + teacher-only passes (~2x backward time). 1 = every step (default; full
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
@@ -383,6 +392,59 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f
return band
def build_act_vote_dirs(model, wrappers, tok, pairs, device):
    """Build the act_vote gate's per-module activation directions, weights and vote band.

    For every wrapped module the activation As = Vh @ x is captured by a forward
    hook and averaged over the token window [plen-1:] (last prompt token +
    completion) of batch row 0. From the authored pairs (no oracle -- labels live
    only on the pairs):

      As_D   = mean over pairs of (As_hack - As_clean)
      As_dir = As_D / |As_D|                      (unit direction, per module)
      act_w  = |As_D|                             (module weight in the vote)
      vote   = sum_m act_w[m] * cos(As[m], As_dir[m]) / sum_m act_w[m]
      band   = (p75 of clean-pair votes, p75 of hack-pair votes)

    Mirrors diag_cosine_dist.py's act/vote. Caller sets model.eval().

    Args:
        model: callable invoked as ``model(ids)`` under ``torch.no_grad()``.
        wrappers: mapping name -> dict whose ``"layer"`` entry is a module that
            carries ``_antipasto_Vh``.
        tok: tokenizer called as ``tok(text, return_tensors="pt")``; only
            ``.input_ids`` is read.
        pairs: non-empty sequence of objects with ``prompt``, ``hack`` and
            ``clean`` string attributes.
        device: device the token ids and the returned directions are moved to.

    Returns:
        ``(As_dir, act_w, (lower, upper))``: As_dir tensors live on ``device``,
        act_w values are Python floats.

    Raises:
        ValueError: ``pairs`` is empty.
        KeyError: a wrapped module's hook did not fire during a forward pass.
        AssertionError: the vote band has non-positive width (hack pairs do not
            vote higher than clean pairs).
    """
    names = list(wrappers)
    if not pairs:
        raise ValueError("build_act_vote_dirs needs at least one authored pair")

    As_cap: dict[str, torch.Tensor] = {}
    st = {"plen": 0}  # prompt length of the forward currently in flight

    def mk_hook(nm):
        Vh = wrappers[nm]["layer"]._antipasto_Vh

        def h(_l, inp, _o):
            # Completion-mean of Vh@x over [plen-1:], batch row 0 only.
            As_cap[nm] = (
                F.linear(inp[0], Vh)[0, st["plen"] - 1:, :].mean(0).detach().float().cpu()
            )

        return h

    def grab(prompt, comp):
        st["plen"] = tok(prompt, return_tensors="pt").input_ids.shape[1]
        ids = tok(prompt + comp, return_tensors="pt").input_ids.to(device)
        # Reset the capture so a module that is skipped by this forward raises
        # KeyError below instead of silently reusing the previous activation.
        As_cap.clear()
        with torch.no_grad():
            model(ids)
        return {nm: As_cap[nm] for nm in names}

    As_h = {nm: [] for nm in names}
    As_c = {nm: [] for nm in names}
    handles = []
    try:
        for nm in names:
            handles.append(wrappers[nm]["layer"].register_forward_hook(mk_hook(nm)))
        for pr in pairs:
            ah = grab(pr.prompt, pr.hack)
            ac = grab(pr.prompt, pr.clean)
            for nm in names:
                As_h[nm].append(ah[nm])
                As_c[nm].append(ac[nm])
    finally:
        # Always detach the hooks: a failed forward must not leave them running
        # (and copying activations to CPU) on every later training step.
        for handle in handles:
            handle.remove()

    As_D = {nm: (torch.stack(As_h[nm]) - torch.stack(As_c[nm])).mean(0) for nm in names}
    As_dir_cpu = {nm: As_D[nm] / As_D[nm].norm().clamp_min(1e-12) for nm in names}
    act_w = {nm: As_D[nm].norm().item() for nm in names}
    wsum = sum(act_w.values())

    def pair_vote(cache, i):
        # Module-weighted mean of cos(As, As_dir) for the i-th pair's completion.
        num = sum(
            act_w[nm]
            * float((cache[nm][i] @ As_dir_cpu[nm]) / cache[nm][i].norm().clamp_min(1e-12))
            for nm in names
        )
        return num / max(wsum, 1e-12)

    n_pairs = len(pairs)
    votes_h = [pair_vote(As_h, i) for i in range(n_pairs)]
    votes_c = [pair_vote(As_c, i) for i in range(n_pairs)]
    vote_band = (torch.tensor(votes_c).quantile(0.75).item(),
                 torch.tensor(votes_h).quantile(0.75).item())
    As_dir = {nm: As_dir_cpu[nm].to(device) for nm in names}
    logger.info(
        f"routeV act_vote: As_dir for {len(As_dir)} modules; vote band "
        f"lower(p75 clean)={vote_band[0]:+.3f} upper(p75 hack)={vote_band[1]:+.3f} "
        f"width={vote_band[1] - vote_band[0]:+.3f}. SHOULD: width>0 (hack pairs vote higher); "
        f"live f_roll>0 in early steps else band sits off the live distribution.")
    assert vote_band[1] > vote_band[0], (
        f"act_vote band non-positive width {vote_band[1] - vote_band[0]:+.3f}: "
        "hack pairs do not vote-separate from clean -> act extraction broken")
    return As_dir, act_w, vote_band
# eval_hack_solve lives in .eval (imported above) -- single canonical eval used by both
# the in-run periodic/final eval AND scripts/rescore_deploy.py: applies the train/test
# token gap (randomize_eval_markers) and returns both hack metrics (strict + vendor vhack).
@@ -471,6 +533,7 @@ def main(cfg: Config) -> int:
# Vanilla (none) is pure GRPO and ignores v_hack entirely (the cin/cout columns
# are hidden, so v_hack=None just means no subspace machinery).
v_grad = None # set only by the routeV grad-mask branch below
As_dir = act_w = vote_band = None # set only by the act_vote gate branch below
if cfg.intervention in ("none", "routeV"):
if cfg.intervention == "none" and cfg.v_hack_path is not None:
logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} "
@@ -524,6 +587,8 @@ def main(cfg: Config) -> int:
assert _mean_bw > 0, (
f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: "
"hack pairs do not separate from clean -> extraction broken")
if cfg.routeV_gate == "act_vote":
As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device)
model.train()
else:
# v_hack path resolution, most-specific first. The pairset (personas) is
@@ -926,6 +991,10 @@ def main(cfg: Config) -> int:
# across prompts, parked into δS_hack.grad at injection (the quarantine,
# deleted at deploy). Mirrors how proj.py parks route's removed component.
step_grad_hack: dict[str, torch.Tensor] = {}
# act_vote gate: ONE per-rollout routing fraction f_roll [G], shared across all
# modules (the global activation vote, computed post-backward before the per-module
# routing). 1-element list so the filter closure reads the current step's value.
_step_f_roll: list[torch.Tensor | None] = [None]
# routeV: recover the per-rollout δS grad from the gate (c.grad = δS * g_b),
# flag rollouts whose grad points hack-ward (cos(g_b, v_grad) > τ), and route
@@ -958,7 +1027,20 @@ def main(cfg: Config) -> int:
# per-token (routeV_per_token): one cos/f per token -- finer but noisier.
lower, upper = route_band[name]
band = max(upper - lower, 1e-6)
if cfg.routeV_per_token:
if cfg.routeV_gate == "act_vote":
# Global gate: route every module's per-rollout grad by the SAME f_roll
# (the activation vote, computed once for the step). Per-rollout granularity
# by construction; per_token is ignored under act_vote.
cg = cg_full.sum(1) # [G, r] per-rollout δS*g
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r]
f = _step_f_roll[0] # [G] shared across modules
routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe,
torch.zeros_like(g))
step_flagged.append(f.mean().item())
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1))
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
elif cfg.routeV_per_token:
g_u = torch.where(reliable, cg_full / dS_safe, torch.zeros_like(cg_full)) # [G, s, r]
cos_u = (g_u @ vg) / g_u.norm(dim=2).clamp_min(1e-12) # [G, s]
f = ((cos_u - lower) / band).clamp(0.0, 1.0) # [G, s]
@@ -1021,6 +1103,34 @@ def main(cfg: Config) -> int:
step_resid.append((g_keep_roll @ vg / g_keep_roll.norm().clamp_min(1e-12)).item())
return g_keep
def _act_vote_f_roll(n_rollouts: int, plen: int, comp_mask: torch.Tensor) -> torch.Tensor:
    """Return the act_vote gate's single routing fraction per rollout, f_roll [G].

    Each wrapped module contributes cos(As_b, As_dir), where As_b is the masked
    mean of its cached activation (``_antipasto_act``) over the window
    [plen-1:]. The cosines are averaged across modules with weights act_w into
    one vote per rollout, which is then mapped through the vote band and
    clamped to [0, 1].

    comp_mask [G, L_c] is the completion-token mask (= mask in the loss).
    Reads wrappers, As_dir, act_w, vote_band and device from the enclosing scope.
    """
    # Window = [plen-1 : end] = last prompt token + completion, the same window
    # build_act_vote_dirs' hook and diag_cosine_dist.py use, so the live vote is
    # scored against a band built on matching tokens. The extra leading column is
    # the last prompt token; it is always valid (never pad), hence mask value 1.
    lead_col = torch.ones(n_rollouts, 1, device=comp_mask.device, dtype=comp_mask.dtype)
    window_mask = torch.cat([lead_col, comp_mask], dim=1)  # [G, L_c+1]
    n_valid = window_mask.sum(1, keepdim=True).clamp_min(1)  # [G, 1]

    weighted_cos = torch.zeros(n_rollouts, device=device)
    total_w = 0.0
    for nm, info in wrappers.items():
        direction = As_dir[nm]
        rank = direction.shape[0]
        acts = info["layer"]._antipasto_act.reshape(n_rollouts, -1, rank)  # [G, S, r]
        window = acts[:, plen - 1:, :].float()  # [G, L_c+1, r]
        assert window.shape[1] == window_mask.shape[1], (
            f"act_vote layout: a_comp s={window.shape[1]} != L_c+1={window_mask.shape[1]} "
            f"(module {nm}); the cached activation seq-len must match merged")
        pooled = (window * window_mask.unsqueeze(-1)).sum(1) / n_valid  # [G, r]
        cos = (pooled @ direction.float()) / pooled.norm(dim=1).clamp_min(1e-12)  # [G]
        weighted_cos = weighted_cos + act_w[nm] * cos
        total_w += act_w[nm]

    vote = weighted_cos / max(total_w, 1e-12)  # [G]
    lower, upper = vote_band
    width = max(upper - lower, 1e-6)
    return ((vote - lower) / width).clamp(0.0, 1.0)  # [G]
# Split backward into student/teacher only every cos_pre_split_every steps.
# On split steps: 2 backwards per prompt, populates step_grad_s/_t.
# On skipped steps: 1 combined backward, step_grad_s/_t stay empty and
@@ -1323,6 +1433,10 @@ def main(cfg: Config) -> int:
ptl_norm = (Lp * mask).sum(1) / mask.sum(1).clamp_min(1)
loss = ptl_norm.sum() / (group * prompts_per_step)
loss.backward()
# act_vote: compute the ONE global f_roll for the step before per-module
# routing (activations are cached on every layer from the loss forward).
if is_routeV and cfg.routeV_gate == "act_vote":
_step_f_roll[0] = _act_vote_f_roll(merged.shape[0], plen, mask)
for name, info in wrappers.items():
g = info["delta_S"].grad
if g is None:
@@ -1463,6 +1577,10 @@ def main(cfg: Config) -> int:
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on the fresh v_grad
if cfg.routeV_gate == "act_vote":
# act direction goes stale just like v_grad; re-extract on the
# current model so the vote tracks where hacks separate now.
As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device)
finally:
logger.enable("vgrout.extract_vhack_grad")
logger.enable("__main__")