diff --git a/scripts/verify_v_act.py b/scripts/verify_v_act.py new file mode 100644 index 0000000..17773a2 --- /dev/null +++ b/scripts/verify_v_act.py @@ -0,0 +1,89 @@ +"""Verify extract_v_act reproduces the cached diagnostic pair features. + +The journal-(d)/(e) AUROC evidence for the act gate was measured by +scripts/diag_pinning.py (ActTap) on cached features. The training extractor +(src/vgrout/extract_vhack_act.py) must produce the SAME pooled acts on the same +checkpoint + pairs, else the routeA gate is not the thing the diagnostics +validated. GPU required (bf16 + flash_attention_2 to match the cache); not part +of `just smoke` -- run once after any extractor change: + + uv run python scripts/verify_v_act.py +""" +from __future__ import annotations + +import sys +from dataclasses import dataclass +from pathlib import Path + +import torch +import tyro +from loguru import logger +from safetensors.torch import load_file + +from vgrout.extract_vhack_act import extract_v_act +from vgrout.lora2r import wrap_model_with_lora2r +from vgrout.pairs import load_pairs + + +@dataclass +class Cfg: + feats: Path = Path("out/diag/pinning_feats.pt") + run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3") + ckpt: str = "first_hack" + pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one") + rel_tol: float = 1e-2 # bf16 forward, fp32 pooling; expect ~0 + + +def main(cfg: Cfg) -> int: + import json + import struct + from transformers import AutoModelForCausalLM, AutoTokenizer + + fe = torch.load(cfg.feats, weights_only=False) + ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors" + with open(ckpt_path, "rb") as f: + meta = json.loads(f.read(struct.unpack("5}: rel diff={rel:.2e} max abs diff={mx:.2e} shape={tuple(got.shape)}") + ok &= rel.item() < cfg.rel_tol + # v_act from the cached feats (diag's _v_from over all pairs) must match too. + D = (fe["pair_feats"][("act", "hack")] - fe["pair_feats"][("act", "clean")]).float().mean(0) + v_ref = D / D.norm(dim=-1, keepdim=True).clamp_min(1e-12) + cos = (v * v_ref).sum(-1) + print(f" v_act vs cache-derived v: min per-module cos={cos.min():.6f} " + f"(SHOULD ~1.0); unit rows max|1-||v|||={(v.norm(dim=-1) - 1).abs().max():.1e}") + ok &= cos.min().item() > 0.999 + print(("🟢 PASS" if ok else "🔴 FAIL") + " verify_v_act") + return 0 if ok else 1 + + +if __name__ == "__main__": + sys.exit(main(tyro.cli(Cfg))) diff --git a/src/vgrout/extract_vhack_act.py b/src/vgrout/extract_vhack_act.py new file mode 100644 index 0000000..4effea4 --- /dev/null +++ b/src/vgrout/extract_vhack_act.py @@ -0,0 +1,109 @@ +"""Activation-side per-module v_act extraction (the routeA direction source). + +For each authored (hack, clean) pair we forward prompt+completion once per side, +no backward, and capture the deployed bottleneck activation h = A[:r] @ x per +wrapped module, mean-pooled over completion tokens. The per-module direction is +the unit-normalized mean paired difference: + + v_act[m] = unitnorm( mean_pairs( h_hack[m] - h_clean[m] ) ) [r] + +Diagnostic basis: RESEARCH_JOURNAL 2026-06-11 (d) -- on the A>0 contrast this +score replicates across three emergence windows (AUROC 0.87/0.75/0.75) while the +gradient score decays to chance; see docs/spec/20260611_act_gate_spec.md. +`tstat=True` divides the mean by its standard error over pairs (clamped |t|<=3) +before normalizing; null at the current 8 pairs, kept for larger pair sets. + +ActCapture is the one hook used everywhere acts are read: pair extraction here, +the live gate in train.py, and scripts/verify_v_act.py. It pools inside the hook +(completion-token mean) so nothing of size [B, L, d_in] is retained. +""" +from __future__ import annotations + +import torch +import torch.nn.functional as F +from jaxtyping import Float +from torch import Tensor + + +class ActCapture: + """Forward hooks stashing the pooled deployed bottleneck h = A[:r] @ x per module. + + Call set_pool(plen, comp_mask) before each forward: positions < plen are prompt, + comp_mask [B, L-plen] marks real (non-pad) completion tokens. pooled() then + returns the completion-token mean act [B, M, r] in module order `names`.""" + + def __init__(self, wrappers: dict, names: list[str]): + self.wrappers, self.names = wrappers, names + self.h: dict[str, Tensor] = {} + self.handles = [] + self.plen: int | None = None + self.mask: Tensor | None = None # [B, L_c, 1] float + self.denom: Tensor | None = None # [B, 1] + + def __enter__(self): + for nm in self.names: + layer = self.wrappers[nm]["layer"] + + def hook(layer, args, out, nm=nm): + (x,) = args # [B, L, d_in] + with torch.no_grad(): + hc = F.linear(x[:, self.plen:].detach(), + layer._lora2r_A[: layer._lora2r_r].to(x.dtype)) + self.h[nm] = (hc.float() * self.mask).sum(1) / self.denom # [B, r] + + self.handles.append(layer.register_forward_hook(hook)) + return self + + def __exit__(self, *exc): + for h in self.handles: + h.remove() + + def set_pool(self, plen: int, comp_mask: Float[Tensor, "B L_c"]) -> None: + denom = comp_mask.sum(1) + assert (denom > 0).all(), f"rollout with 0 completion tokens: {denom.tolist()}" + self.plen, self.mask, self.denom = plen, comp_mask.unsqueeze(-1), denom.unsqueeze(-1) + + def pooled(self) -> Float[Tensor, "B M r"]: + return torch.stack([self.h[nm] for nm in self.names], dim=1) + + +@torch.no_grad() +def extract_v_act( + model, tok, wrappers: dict, pairs: list, device, tstat: bool = False, +) -> tuple[Float[Tensor, "M r"], dict[str, Float[Tensor, "P M r"]]]: + """Forward each pair side once -> (v_act unit rows, raw per-pair pooled acts). + + Module order is sorted(wrappers) (= the diag scripts' convention). Tokenization + matches the cached diagnostic features: full = tok(prompt+completion), pool + positions >= len(tok(prompt)). Caller owns model.eval() and quarantine state.""" + names = sorted(wrappers) + acts: dict[str, list[Tensor]] = {"hack": [], "clean": []} + with ActCapture(wrappers, names) as cap: + for pair in pairs: + n_prompt = tok(pair.prompt, return_tensors="pt").input_ids.shape[1] + for side, completion in (("hack", pair.hack), ("clean", pair.clean)): + full_ids = tok(pair.prompt + completion, + return_tensors="pt").input_ids.to(device) + L_c = full_ids.shape[1] - n_prompt + assert L_c > 0, f"pair {pair.problem_id} {side}: no completion tokens" + cap.set_pool(n_prompt, torch.ones(1, L_c, device=device)) + # logits_to_keep=1: the decoder Linears (whose inputs we hook) run on + # every position regardless; skipping the full-seq lm_head is free. + model(full_ids, logits_to_keep=1) + acts[side].append(cap.pooled()[0].cpu()) + H = torch.stack(acts["hack"]) # [P, M, r] + C = torch.stack(acts["clean"]) + D = H - C + d = D.mean(0) + if tstat: + se = D.std(0) / D.shape[0] ** 0.5 + d = (d / se.clamp_min(1e-12)).clamp(-3.0, 3.0) + v = d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12) # unit per module + return v, {"hack": H, "clean": C} + + +def haar_unit_rows(shape: tuple[int, int], seed: int) -> Float[Tensor, "M r"]: + """Random unit rows matching v_act's shape -- the placebo directionality control.""" + g = torch.Generator().manual_seed(seed) + d = torch.randn(shape, generator=g) + return d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12)