mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:48:43 +08:00
feat: act_vote routeV gate (global activation-vote routing arm)
New routeV_gate=act_vote: route every module's per-rollout gradient by a single global f_roll from a module-weighted vote of activation cosines cos(As_b, As_dir), As=Vh@x completion-mean (mirrors diag_cosine_dist.py act/vote, AUROC 0.67 / p@10 0.30 -- the coverage corner). Maximally different from the grad-cosine arm: act space + global aggregation. Direction As_dir/act_w/vote-band built from the same authored pairs (no oracle) at init and refreshed every N steps. Window = [plen-1:] to match the build hook + diagnostic. Smoke-verified (band opens, rout>0, refresh ok); fresh-eyes reviewed. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -100,6 +100,9 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
|
||||
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1],
|
||||
device=a.device, dtype=a.dtype, requires_grad=True)
|
||||
layer._antipasto_gate = c
|
||||
# Cache the activation As = Vh@x for the act_vote gate (same flattened layout as
|
||||
# the gate, so train.py reshapes both identically). Detached: read-only gate input.
|
||||
layer._antipasto_act = a.detach()
|
||||
kept = torch.nn.functional.linear((a * c) * delta_S.to(a.dtype), U)
|
||||
else:
|
||||
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U)
|
||||
|
||||
+119
-1
@@ -200,6 +200,15 @@ class Config:
|
||||
# the preregistered default, denoises the cos sign + matches GRPO per-rollout adv).
|
||||
# True = route per TOKEN (one cos/f per token; finer but noisier). Ablation arm.
|
||||
routeV_per_token: bool = False
|
||||
# routeV gate signal. "grad_cosine" (default): per-module cos(g_b, v_grad) on the
|
||||
# backward delta_S gradient, banded per module (the precision-tail corner, diag
|
||||
# p@10=0.70). "act_vote": a GLOBAL per-rollout gate -- module-weighted vote of
|
||||
# ACTIVATION cosines cos(As_b, As_dir), As=Vh@x completion-mean (diag's act/vote,
|
||||
# AUROC 0.67 / p@20 0.45 but p@10 0.30). A deliberately maximally-different hail-mary
|
||||
# arm: different space (act not grad) + different aggregation (one f per rollout,
|
||||
# shared across modules). Tests whether the precision framing predicts deploy
|
||||
# suppression, and stresses H2 absorption (does gate choice matter at deploy at all?).
|
||||
routeV_gate: Literal["grad_cosine", "act_vote"] = "grad_cosine"
|
||||
# Per-source cin diagnostic: split each prompt's backward into student-only
|
||||
# + teacher-only passes (~2x backward time). 1 = every step (default; full
|
||||
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
|
||||
@@ -383,6 +392,59 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f
|
||||
return band
|
||||
|
||||
|
||||
def build_act_vote_dirs(model, wrappers, tok, pairs, device):
    """Build the act_vote gate state from the authored (hack, clean) pairs.

    For every wrapped module a forward hook captures ``As = Vh @ x`` (computed as
    ``F.linear(x, Vh)``) on batch row 0, averaged over the tokens from the last prompt
    token to the end of the sequence. From those captures this returns:

    * ``As_dir`` -- per-module unit vector of ``mean_pairs(As_hack - As_clean)``, moved
      to ``device``;
    * ``act_w``  -- per-module weight, the norm of that un-normalised mean difference;
    * ``vote_band`` -- ``(lower, upper)`` = (p75 of the clean-pair votes, p75 of the
      hack-pair votes), where a vote is the ``act_w``-weighted mean over modules of
      ``cos(As, As_dir)``.

    Mirrors diag_cosine_dist.py's act/vote; no oracle (labels live only on the authored
    pairs). Caller sets ``model.eval()``.

    Args:
        model: callable model; invoked as ``model(ids)`` under ``torch.no_grad()``.
        wrappers: mapping ``name -> {"layer": module, ...}``; each layer must expose
            ``_antipasto_Vh``.
        tok: tokenizer called as ``tok(text, return_tensors="pt")``; must return an
            object with ``.input_ids``.
        pairs: iterable of objects with ``.prompt``, ``.hack`` and ``.clean`` strings.
        device: device the token ids and the returned directions are moved to.

    Returns:
        ``(As_dir, act_w, vote_band)`` as described above.

    Raises:
        ValueError: if ``pairs`` is empty.
        KeyError: if a wrapped module's hook did not fire during a forward pass.
        AssertionError: if the vote band has non-positive width.
    """
    # Materialise once: `pairs` is iterated more than once below, and an empty set would
    # otherwise only surface as an opaque torch.stack error.
    pairs = list(pairs)
    if not pairs:
        raise ValueError("build_act_vote_dirs: need at least one authored (hack, clean) pair")

    names = list(wrappers)
    As_cap: dict[str, torch.Tensor] = {}
    st = {"start": 0}  # first token index of the averaging window for the current pass

    def mk_hook(nm):
        Vh = wrappers[nm]["layer"]._antipasto_Vh

        def h(_l, inp, _o):
            As_cap[nm] = (F.linear(inp[0], Vh)[0, st["start"]:, :]
                          .mean(0).detach().float().cpu())
        return h

    def grab(prompt, comp):
        plen = tok(prompt, return_tensors="pt").input_ids.shape[1]
        # Window = [plen-1 : end] (last prompt token + completion). Clamp at 0 so an empty
        # prompt does not turn into the negative index -1 (i.e. "last token only").
        st["start"] = max(plen - 1, 0)
        ids = tok(prompt + comp, return_tensors="pt").input_ids.to(device)
        # Drop the previous pass's captures: a module whose hook does not fire must fail
        # loudly (KeyError below) rather than silently reuse a stale activation.
        As_cap.clear()
        with torch.no_grad():
            model(ids)
        return {nm: As_cap[nm].clone() for nm in names}

    hack_acts: list[dict[str, torch.Tensor]] = []
    clean_acts: list[dict[str, torch.Tensor]] = []
    handles = []
    try:
        for nm in names:
            handles.append(wrappers[nm]["layer"].register_forward_hook(mk_hook(nm)))
        for pr in pairs:
            hack_acts.append(grab(pr.prompt, pr.hack))
            clean_acts.append(grab(pr.prompt, pr.clean))
    finally:
        # Always detach the hooks, even when tokenisation or the forward pass raises;
        # otherwise they would stay on the model and fire on every later forward.
        for handle in handles:
            handle.remove()

    As_D = {
        nm: (torch.stack([acts[nm] for acts in hack_acts])
             - torch.stack([acts[nm] for acts in clean_acts])).mean(0)
        for nm in names
    }
    As_dir_cpu = {nm: As_D[nm] / As_D[nm].norm().clamp_min(1e-12) for nm in names}
    act_w = {nm: As_D[nm].norm().item() for nm in names}
    wsum = sum(act_w.values())

    def pair_vote(As_pair):
        # act_w-weighted mean over modules of cos(As, As_dir); As_dir is already unit-norm.
        num = sum(act_w[nm] * float((As_pair[nm] @ As_dir_cpu[nm]) / As_pair[nm].norm().clamp_min(1e-12))
                  for nm in names)
        return num / max(wsum, 1e-12)

    votes_h = [pair_vote(acts) for acts in hack_acts]
    votes_c = [pair_vote(acts) for acts in clean_acts]
    vote_band = (torch.tensor(votes_c).quantile(0.75).item(),
                 torch.tensor(votes_h).quantile(0.75).item())
    As_dir = {nm: As_dir_cpu[nm].to(device) for nm in names}
    logger.info(
        f"routeV act_vote: As_dir for {len(As_dir)} modules; vote band "
        f"lower(p75 clean)={vote_band[0]:+.3f} upper(p75 hack)={vote_band[1]:+.3f} "
        f"width={vote_band[1] - vote_band[0]:+.3f}. SHOULD: width>0 (hack pairs vote higher); "
        f"live f_roll>0 in early steps else band sits off the live distribution.")
    assert vote_band[1] > vote_band[0], (
        f"act_vote band non-positive width {vote_band[1] - vote_band[0]:+.3f}: "
        "hack pairs do not vote-separate from clean -> act extraction broken")
    return As_dir, act_w, vote_band
|
||||
|
||||
|
||||
# eval_hack_solve lives in .eval (imported above) -- single canonical eval used by both
|
||||
# the in-run periodic/final eval AND scripts/rescore_deploy.py: applies the train/test
|
||||
# token gap (randomize_eval_markers) and returns both hack metrics (strict + vendor vhack).
|
||||
@@ -471,6 +533,7 @@ def main(cfg: Config) -> int:
|
||||
# Vanilla (none) is pure GRPO and ignores v_hack entirely (the cin/cout columns
|
||||
# are hidden, so v_hack=None just means no subspace machinery).
|
||||
v_grad = None # set only by the routeV grad-mask branch below
|
||||
As_dir = act_w = vote_band = None # set only by the act_vote gate branch below
|
||||
if cfg.intervention in ("none", "routeV"):
|
||||
if cfg.intervention == "none" and cfg.v_hack_path is not None:
|
||||
logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} "
|
||||
@@ -524,6 +587,8 @@ def main(cfg: Config) -> int:
|
||||
assert _mean_bw > 0, (
|
||||
f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: "
|
||||
"hack pairs do not separate from clean -> extraction broken")
|
||||
if cfg.routeV_gate == "act_vote":
|
||||
As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device)
|
||||
model.train()
|
||||
else:
|
||||
# v_hack path resolution, most-specific first. The pairset (personas) is
|
||||
@@ -926,6 +991,10 @@ def main(cfg: Config) -> int:
|
||||
# across prompts, parked into δS_hack.grad at injection (the quarantine,
|
||||
# deleted at deploy). Mirrors how proj.py parks route's removed component.
|
||||
step_grad_hack: dict[str, torch.Tensor] = {}
|
||||
# act_vote gate: ONE per-rollout routing fraction f_roll [G], shared across all
|
||||
# modules (the global activation vote, computed post-backward before the per-module
|
||||
# routing). 1-element list so the filter closure reads the current step's value.
|
||||
_step_f_roll: list[torch.Tensor | None] = [None]
|
||||
|
||||
# routeV: recover the per-rollout δS grad from the gate (c.grad = δS * g_b),
|
||||
# flag rollouts whose grad points hack-ward (cos(g_b, v_grad) > τ), and route
|
||||
@@ -958,7 +1027,20 @@ def main(cfg: Config) -> int:
|
||||
# per-token (routeV_per_token): one cos/f per token -- finer but noisier.
|
||||
lower, upper = route_band[name]
|
||||
band = max(upper - lower, 1e-6)
|
||||
if cfg.routeV_per_token:
|
||||
if cfg.routeV_gate == "act_vote":
|
||||
# Global gate: route every module's per-rollout grad by the SAME f_roll
|
||||
# (the activation vote, computed once for the step). Per-rollout granularity
|
||||
# by construction; per_token is ignored under act_vote.
|
||||
cg = cg_full.sum(1) # [G, r] per-rollout δS*g
|
||||
g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r]
|
||||
f = _step_f_roll[0] # [G] shared across modules
|
||||
routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe,
|
||||
torch.zeros_like(g))
|
||||
step_flagged.append(f.mean().item())
|
||||
_kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1))
|
||||
step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on)
|
||||
step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe)
|
||||
elif cfg.routeV_per_token:
|
||||
g_u = torch.where(reliable, cg_full / dS_safe, torch.zeros_like(cg_full)) # [G, s, r]
|
||||
cos_u = (g_u @ vg) / g_u.norm(dim=2).clamp_min(1e-12) # [G, s]
|
||||
f = ((cos_u - lower) / band).clamp(0.0, 1.0) # [G, s]
|
||||
@@ -1021,6 +1103,34 @@ def main(cfg: Config) -> int:
|
||||
step_resid.append((g_keep_roll @ vg / g_keep_roll.norm().clamp_min(1e-12)).item())
|
||||
return g_keep
|
||||
|
||||
def _act_vote_f_roll(n_rollouts: int, plen: int, comp_mask: torch.Tensor) -> torch.Tensor:
    """Turn the cached per-module activations into one routing fraction per rollout.

    Each module's cached activation is reshaped to [G, S, r], cut to the window
    [plen-1 : end] and mask-averaged into As_b [G, r]. The per-module cosines
    cos(As_b, As_dir) are combined into an act_w-weighted mean vote per rollout, which
    is then mapped through vote_band and clamped to [0, 1].

    comp_mask [G, L_c] is the completion-token mask (= mask in the loss).
    Returns f_roll [G].
    """
    # Window = [plen-1 : end] = last prompt token + completion, matching
    # build_act_vote_dirs' hook and diag_cosine_dist.py (so the live vote is scored
    # against the band built on the same window). The leading prompt token is always
    # valid (never pad), so a column of ones is prepended to the completion mask.
    lead = torch.ones(n_rollouts, 1, device=comp_mask.device, dtype=comp_mask.dtype)
    ext_mask = torch.cat([lead, comp_mask], dim=1)  # [G, L_c+1]
    # Mask terms do not depend on the module, so build them once outside the loop.
    tok_weight = ext_mask.unsqueeze(-1)  # [G, L_c+1, 1]
    tok_count = ext_mask.sum(1, keepdim=True).clamp_min(1)  # [G, 1]

    weighted_cos = torch.zeros(n_rollouts, device=device)
    total_w = 0.0
    for nm, info in wrappers.items():
        direction = As_dir[nm]
        r = direction.shape[0]
        a = info["layer"]._antipasto_act.reshape(n_rollouts, -1, r)  # [G, S, r]
        a_comp = a[:, plen - 1:, :].float()  # [G, L_c+1, r]
        assert a_comp.shape[1] == ext_mask.shape[1], (
            f"act_vote layout: a_comp s={a_comp.shape[1]} != L_c+1={ext_mask.shape[1]} "
            f"(module {nm}); the cached activation seq-len must match merged")
        As_b = (a_comp * tok_weight).sum(1) / tok_count  # [G, r] masked token mean
        cos = (As_b @ direction.float()) / As_b.norm(dim=1).clamp_min(1e-12)  # [G]
        w = act_w[nm]
        weighted_cos = weighted_cos + w * cos
        total_w += w

    vote = weighted_cos / max(total_w, 1e-12)  # [G]
    lower, upper = vote_band
    width = max(upper - lower, 1e-6)
    return ((vote - lower) / width).clamp(0.0, 1.0)  # [G]
|
||||
|
||||
# Split backward into student/teacher only every cos_pre_split_every steps.
|
||||
# On split steps: 2 backwards per prompt, populates step_grad_s/_t.
|
||||
# On skipped steps: 1 combined backward, step_grad_s/_t stay empty and
|
||||
@@ -1323,6 +1433,10 @@ def main(cfg: Config) -> int:
|
||||
ptl_norm = (Lp * mask).sum(1) / mask.sum(1).clamp_min(1)
|
||||
loss = ptl_norm.sum() / (group * prompts_per_step)
|
||||
loss.backward()
|
||||
# act_vote: compute the ONE global f_roll for the step before per-module
|
||||
# routing (activations are cached on every layer from the loss forward).
|
||||
if is_routeV and cfg.routeV_gate == "act_vote":
|
||||
_step_f_roll[0] = _act_vote_f_roll(merged.shape[0], plen, mask)
|
||||
for name, info in wrappers.items():
|
||||
g = info["delta_S"].grad
|
||||
if g is None:
|
||||
@@ -1463,6 +1577,10 @@ def main(cfg: Config) -> int:
|
||||
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
|
||||
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
|
||||
route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on the fresh v_grad
|
||||
if cfg.routeV_gate == "act_vote":
|
||||
# act direction goes stale just like v_grad; re-extract on the
|
||||
# current model so the vote tracks where hacks separate now.
|
||||
As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device)
|
||||
finally:
|
||||
logger.enable("vgrout.extract_vhack_grad")
|
||||
logger.enable("__main__")
|
||||
|
||||
Reference in New Issue
Block a user