diff --git a/src/vgrout/antipasto.py b/src/vgrout/antipasto.py index 6d9a094..2800373 100644 --- a/src/vgrout/antipasto.py +++ b/src/vgrout/antipasto.py @@ -100,6 +100,9 @@ def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor: c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1], device=a.device, dtype=a.dtype, requires_grad=True) layer._antipasto_gate = c + # Cache the activation As = Vh@x for the act_vote gate (same flattened layout as + # the gate, so train.py reshapes both identically). Detached: read-only gate input. + layer._antipasto_act = a.detach() kept = torch.nn.functional.linear((a * c) * delta_S.to(a.dtype), U) else: kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U) diff --git a/src/vgrout/train.py b/src/vgrout/train.py index 265fad7..6d432b3 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -200,6 +200,15 @@ class Config: # the preregistered default, denoises the cos sign + matches GRPO per-rollout adv). # True = route per TOKEN (one cos/f per token; finer but noisier). Ablation arm. routeV_per_token: bool = False + # routeV gate signal. "grad_cosine" (default): per-module cos(g_b, v_grad) on the + # backward delta_S gradient, banded per module (the precision-tail corner, diag + # p@10=0.70). "act_vote": a GLOBAL per-rollout gate -- module-weighted vote of + # ACTIVATION cosines cos(As_b, As_dir), As=Vh@x completion-mean (diag's act/vote, + # AUROC 0.67 / p@20 0.45 but p@10 0.30). A deliberately maximally-different hail-mary + # arm: different space (act not grad) + different aggregation (one f per rollout, + # shared across modules). Tests whether the precision framing predicts deploy + # suppression, and stresses H2 absorption (does gate choice matter at deploy at all?). + routeV_gate: Literal["grad_cosine", "act_vote"] = "grad_cosine" # Per-source cin diagnostic: split each prompt's backward into student-only # + teacher-only passes (~2x backward time). 1 = every step (default; full # signal); N>1 = only every Nth step (combined backward elsewhere, ~halves @@ -383,6 +392,59 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f return band +def build_act_vote_dirs(model, wrappers, tok, pairs, device): + """act_vote gate: per-module ACTIVATION direction As_dir = unit(mean_pairs(As_hack - + As_clean)) where As = Vh@x completion-mean; module weight act_w = |As_D|; and a GLOBAL + vote band (lower=p75 clean-pair vote, upper=p75 hack-pair vote). Mirrors + diag_cosine_dist.py's act/vote, no oracle (labels live only on the authored pairs). + Caller sets model.eval(). Returns (As_dir[device], act_w, (lower, upper)).""" + names = list(wrappers) + As_cap: dict[str, torch.Tensor] = {} + st = {"plen": 0} + def mk_hook(nm): + Vh = wrappers[nm]["layer"]._antipasto_Vh + def h(_l, inp, _o): + As_cap[nm] = F.linear(inp[0], Vh)[0, st["plen"] - 1:, :].mean(0).detach().float().cpu() + return h + handles = [wrappers[nm]["layer"].register_forward_hook(mk_hook(nm)) for nm in names] + def grab(prompt, comp): + st["plen"] = tok(prompt, return_tensors="pt").input_ids.shape[1] + ids = tok(prompt + comp, return_tensors="pt").input_ids.to(device) + with torch.no_grad(): + model(ids) + return {nm: As_cap[nm].clone() for nm in names} + As_h = {nm: [] for nm in names} + As_c = {nm: [] for nm in names} + for pr in pairs: + ah, ac = grab(pr.prompt, pr.hack), grab(pr.prompt, pr.clean) + for nm in names: + As_h[nm].append(ah[nm]); As_c[nm].append(ac[nm]) + for h in handles: + h.remove() + As_D = {nm: (torch.stack(As_h[nm]) - torch.stack(As_c[nm])).mean(0) for nm in names} + As_dir_cpu = {nm: As_D[nm] / As_D[nm].norm().clamp_min(1e-12) for nm in names} + act_w = {nm: As_D[nm].norm().item() for nm in names} + wsum = sum(act_w.values()) + def pair_vote(As_pair): + num = sum(act_w[nm] * float((As_pair[nm] @ As_dir_cpu[nm]) / As_pair[nm].norm().clamp_min(1e-12)) + for nm in names) + return num / max(wsum, 1e-12) + votes_h = [pair_vote({nm: As_h[nm][i] for nm in names}) for i in range(len(pairs))] + votes_c = [pair_vote({nm: As_c[nm][i] for nm in names}) for i in range(len(pairs))] + vote_band = (torch.tensor(votes_c).quantile(0.75).item(), + torch.tensor(votes_h).quantile(0.75).item()) + As_dir = {nm: As_dir_cpu[nm].to(device) for nm in names} + logger.info( + f"routeV act_vote: As_dir for {len(As_dir)} modules; vote band " + f"lower(p75 clean)={vote_band[0]:+.3f} upper(p75 hack)={vote_band[1]:+.3f} " + f"width={vote_band[1] - vote_band[0]:+.3f}. SHOULD: width>0 (hack pairs vote higher); " + f"live f_roll>0 in early steps else band sits off the live distribution.") + assert vote_band[1] > vote_band[0], ( + f"act_vote band non-positive width {vote_band[1] - vote_band[0]:+.3f}: " + "hack pairs do not vote-separate from clean -> act extraction broken") + return As_dir, act_w, vote_band + + # eval_hack_solve lives in .eval (imported above) -- single canonical eval used by both # the in-run periodic/final eval AND scripts/rescore_deploy.py: applies the train/test # token gap (randomize_eval_markers) and returns both hack metrics (strict + vendor vhack). @@ -471,6 +533,7 @@ def main(cfg: Config) -> int: # Vanilla (none) is pure GRPO and ignores v_hack entirely (the cin/cout columns # are hidden, so v_hack=None just means no subspace machinery). v_grad = None # set only by the routeV grad-mask branch below + As_dir = act_w = vote_band = None # set only by the act_vote gate branch below if cfg.intervention in ("none", "routeV"): if cfg.intervention == "none" and cfg.v_hack_path is not None: logger.info(f"vanilla arm: ignoring --v-hack-path={cfg.v_hack_path} " @@ -524,6 +587,8 @@ def main(cfg: Config) -> int: assert _mean_bw > 0, ( f"real v_grad gave non-positive mean band width {_mean_bw:+.3f}: " "hack pairs do not separate from clean -> extraction broken") + if cfg.routeV_gate == "act_vote": + As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device) model.train() else: # v_hack path resolution, most-specific first. The pairset (personas) is @@ -926,6 +991,10 @@ def main(cfg: Config) -> int: # across prompts, parked into δS_hack.grad at injection (the quarantine, # deleted at deploy). Mirrors how proj.py parks route's removed component. step_grad_hack: dict[str, torch.Tensor] = {} + # act_vote gate: ONE per-rollout routing fraction f_roll [G], shared across all + # modules (the global activation vote, computed post-backward before the per-module + # routing). 1-element list so the filter closure reads the current step's value. + _step_f_roll: list[torch.Tensor | None] = [None] # routeV: recover the per-rollout δS grad from the gate (c.grad = δS * g_b), # flag rollouts whose grad points hack-ward (cos(g_b, v_grad) > τ), and route @@ -958,7 +1027,20 @@ def main(cfg: Config) -> int: # per-token (routeV_per_token): one cos/f per token -- finer but noisier. lower, upper = route_band[name] band = max(upper - lower, 1e-6) - if cfg.routeV_per_token: + if cfg.routeV_gate == "act_vote": + # Global gate: route every module's per-rollout grad by the SAME f_roll + # (the activation vote, computed once for the step). Per-rollout granularity + # by construction; per_token is ignored under act_vote. + cg = cg_full.sum(1) # [G, r] per-rollout δS*g + g_b = torch.where(reliable, cg / dS_safe, torch.zeros_like(cg)) # [G, r] + f = _step_f_roll[0] # [G] shared across modules + routed = torch.where(reliable, (cg * f.unsqueeze(1)).sum(0) / dS_safe, + torch.zeros_like(g)) + step_flagged.append(f.mean().item()) + _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f, g_b.norm(dim=1)) + step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) + step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) + elif cfg.routeV_per_token: g_u = torch.where(reliable, cg_full / dS_safe, torch.zeros_like(cg_full)) # [G, s, r] cos_u = (g_u @ vg) / g_u.norm(dim=2).clamp_min(1e-12) # [G, s] f = ((cos_u - lower) / band).clamp(0.0, 1.0) # [G, s] @@ -1021,6 +1103,34 @@ def main(cfg: Config) -> int: step_resid.append((g_keep_roll @ vg / g_keep_roll.norm().clamp_min(1e-12)).item()) return g_keep + def _act_vote_f_roll(n_rollouts: int, plen: int, comp_mask: torch.Tensor) -> torch.Tensor: + """Global per-rollout routing fraction from the activation vote (act_vote gate). + For each module: As_b = completion-mean(Vh@x) [G, r]; cos(As_b, As_dir); aggregate + across modules weighted by act_w into one vote per rollout; band -> f_roll [G]. + comp_mask [G, L_c] is the completion-token mask (= mask in the loss).""" + num = torch.zeros(n_rollouts, device=device) + wsum = 0.0 + # Window = [plen-1 : end] = last prompt token + completion, matching + # build_act_vote_dirs' hook and diag_cosine_dist.py (so the live vote is scored + # against the band built on the same window). The leading prompt token is always + # valid (never pad), so its mask entry is 1. + ext_mask = torch.cat([torch.ones(n_rollouts, 1, device=comp_mask.device, dtype=comp_mask.dtype), + comp_mask], dim=1) # [G, L_c+1] + for nm in wrappers: + r = As_dir[nm].shape[0] + a = wrappers[nm]["layer"]._antipasto_act.reshape(n_rollouts, -1, r) # [G, S, r] + a_comp = a[:, plen - 1:, :].float() # [G, L_c+1, r] + assert a_comp.shape[1] == ext_mask.shape[1], ( + f"act_vote layout: a_comp s={a_comp.shape[1]} != L_c+1={ext_mask.shape[1]} " + f"(module {nm}); the cached activation seq-len must match merged") + As_b = (a_comp * ext_mask.unsqueeze(-1)).sum(1) / ext_mask.sum(1, keepdim=True).clamp_min(1) + cos = (As_b @ As_dir[nm].float()) / As_b.norm(dim=1).clamp_min(1e-12) # [G] + num = num + act_w[nm] * cos + wsum += act_w[nm] + vote = num / max(wsum, 1e-12) # [G] + lower, upper = vote_band + return ((vote - lower) / max(upper - lower, 1e-6)).clamp(0.0, 1.0) # [G] + # Split backward into student/teacher only every cos_pre_split_every steps. # On split steps: 2 backwards per prompt, populates step_grad_s/_t. # On skipped steps: 1 combined backward, step_grad_s/_t stay empty and @@ -1323,6 +1433,10 @@ def main(cfg: Config) -> int: ptl_norm = (Lp * mask).sum(1) / mask.sum(1).clamp_min(1) loss = ptl_norm.sum() / (group * prompts_per_step) loss.backward() + # act_vote: compute the ONE global f_roll for the step before per-module + # routing (activations are cached on every layer from the loss forward). + if is_routeV and cfg.routeV_gate == "act_vote": + _step_f_roll[0] = _act_vote_f_roll(merged.shape[0], plen, mask) for name, info in wrappers.items(): g = info["delta_S"].grad if g is None: @@ -1463,6 +1577,10 @@ def main(cfg: Config) -> int: d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0) v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device) route_band = route_band_edges(raw_grads, v_grad, device) # rebuild band on the fresh v_grad + if cfg.routeV_gate == "act_vote": + # act direction goes stale just like v_grad; re-extract on the + # current model so the vote tracks where hacks separate now. + As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device) finally: logger.enable("vgrout.extract_vhack_grad") logger.enable("__main__")