feat(routeA): forward-only v_act extraction + verify gate vs cached diag features

extract_v_act: pooled completion-token bottleneck act per module, v = unit-norm
mean pair diff (tstat flag default off, null at n=8 pairs). ActCapture is the
single hook shared by extraction, the live gate, and verification.
verify_v_act (pueue #24): rel diff 7.3e-4 hack / 7.7e-4 clean vs
out/diag/pinning_feats.pt on the v3 first_hack ckpt; min per-module cos 0.99997.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-11 12:13:32 +00:00
parent d51028a618
commit 5a340e5c3e
2 changed files with 198 additions and 0 deletions
+89
View File
@@ -0,0 +1,89 @@
"""Verify extract_v_act reproduces the cached diagnostic pair features.
The journal-(d)/(e) AUROC evidence for the act gate was measured by
scripts/diag_pinning.py (ActTap) on cached features. The training extractor
(src/vgrout/extract_vhack_act.py) must produce the SAME pooled acts on the same
checkpoint + pairs, else the routeA gate is not the thing the diagnostics
validated. GPU required (bf16 + flash_attention_2 to match the cache); not part
of `just smoke` -- run once after any extractor change:
uv run python scripts/verify_v_act.py
"""
from __future__ import annotations
import sys
from dataclasses import dataclass
from pathlib import Path
import torch
import tyro
from loguru import logger
from safetensors.torch import load_file
from vgrout.extract_vhack_act import extract_v_act
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
@dataclass
class Cfg:
feats: Path = Path("out/diag/pinning_feats.pt")
run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
ckpt: str = "first_hack"
pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
rel_tol: float = 1e-2 # bf16 forward, fp32 pooling; expect ~0
def main(cfg: Cfg) -> int:
import json
import struct
from transformers import AutoModelForCausalLM, AutoTokenizer
fe = torch.load(cfg.feats, weights_only=False)
ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
with open(ckpt_path, "rb") as f:
meta = json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
run_cfg = json.loads(meta.get("cfg", "{}"))
model_name, r, init_seed = run_cfg["model"], run_cfg["lora_r"], run_cfg["lora_init_seed"]
device = torch.device("cuda")
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
model.config.use_cache = False
wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed)
assert sorted(wrappers) == fe["names"], "module set/order drifted vs the cache"
sd = load_file(str(ckpt_path))
for nm, info in wrappers.items():
info["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
info["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
model.eval()
pairs = load_pairs(cfg.pairs)
assert [p.problem_id for p in pairs] == fe["pair_ids"], "pairset drifted vs the cache"
logger.info(f"extracting acts for {len(pairs)} pairs on {ckpt_path.name} ({model_name})")
v, acts = extract_v_act(model, tok, wrappers, pairs, device)
print("SHOULD: rel diff ~0 for both sides (same ckpt/pairs/dtype/kernels) ELSE the "
"training extractor is not what diag_pinning measured.")
ok = True
for side in ("hack", "clean"):
ref = fe["pair_feats"][("act", side)].float()
got = acts[side]
rel = (got - ref).norm() / ref.norm()
mx = (got - ref).abs().max()
print(f" {side:>5}: rel diff={rel:.2e} max abs diff={mx:.2e} shape={tuple(got.shape)}")
ok &= rel.item() < cfg.rel_tol
# v_act from the cached feats (diag's _v_from over all pairs) must match too.
D = (fe["pair_feats"][("act", "hack")] - fe["pair_feats"][("act", "clean")]).float().mean(0)
v_ref = D / D.norm(dim=-1, keepdim=True).clamp_min(1e-12)
cos = (v * v_ref).sum(-1)
print(f" v_act vs cache-derived v: min per-module cos={cos.min():.6f} "
f"(SHOULD ~1.0); unit rows max|1-||v|||={(v.norm(dim=-1) - 1).abs().max():.1e}")
ok &= cos.min().item() > 0.999
print(("🟢 PASS" if ok else "🔴 FAIL") + " verify_v_act")
return 0 if ok else 1
if __name__ == "__main__":
sys.exit(main(tyro.cli(Cfg)))
+109
View File
@@ -0,0 +1,109 @@
"""Activation-side per-module v_act extraction (the routeA direction source).
For each authored (hack, clean) pair we forward prompt+completion once per side,
no backward, and capture the deployed bottleneck activation h = A[:r] @ x per
wrapped module, mean-pooled over completion tokens. The per-module direction is
the unit-normalized mean paired difference:
v_act[m] = unitnorm( mean_pairs( h_hack[m] - h_clean[m] ) ) [r]
Diagnostic basis: RESEARCH_JOURNAL 2026-06-11 (d) -- on the A>0 contrast this
score replicates across three emergence windows (AUROC 0.87/0.75/0.75) while the
gradient score decays to chance; see docs/spec/20260611_act_gate_spec.md.
`tstat=True` divides the mean by its standard error over pairs (clamped |t|<=3)
before normalizing; null at the current 8 pairs, kept for larger pair sets.
ActCapture is the one hook used everywhere acts are read: pair extraction here,
the live gate in train.py, and scripts/verify_v_act.py. It pools inside the hook
(completion-token mean) so nothing of size [B, L, d_in] is retained.
"""
from __future__ import annotations
import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor
class ActCapture:
"""Forward hooks stashing the pooled deployed bottleneck h = A[:r] @ x per module.
Call set_pool(plen, comp_mask) before each forward: positions < plen are prompt,
comp_mask [B, L-plen] marks real (non-pad) completion tokens. pooled() then
returns the completion-token mean act [B, M, r] in module order `names`."""
def __init__(self, wrappers: dict, names: list[str]):
self.wrappers, self.names = wrappers, names
self.h: dict[str, Tensor] = {}
self.handles = []
self.plen: int | None = None
self.mask: Tensor | None = None # [B, L_c, 1] float
self.denom: Tensor | None = None # [B, 1]
def __enter__(self):
for nm in self.names:
layer = self.wrappers[nm]["layer"]
def hook(layer, args, out, nm=nm):
(x,) = args # [B, L, d_in]
with torch.no_grad():
hc = F.linear(x[:, self.plen:].detach(),
layer._lora2r_A[: layer._lora2r_r].to(x.dtype))
self.h[nm] = (hc.float() * self.mask).sum(1) / self.denom # [B, r]
self.handles.append(layer.register_forward_hook(hook))
return self
def __exit__(self, *exc):
for h in self.handles:
h.remove()
def set_pool(self, plen: int, comp_mask: Float[Tensor, "B L_c"]) -> None:
denom = comp_mask.sum(1)
assert (denom > 0).all(), f"rollout with 0 completion tokens: {denom.tolist()}"
self.plen, self.mask, self.denom = plen, comp_mask.unsqueeze(-1), denom.unsqueeze(-1)
def pooled(self) -> Float[Tensor, "B M r"]:
return torch.stack([self.h[nm] for nm in self.names], dim=1)
@torch.no_grad()
def extract_v_act(
model, tok, wrappers: dict, pairs: list, device, tstat: bool = False,
) -> tuple[Float[Tensor, "M r"], dict[str, Float[Tensor, "P M r"]]]:
"""Forward each pair side once -> (v_act unit rows, raw per-pair pooled acts).
Module order is sorted(wrappers) (= the diag scripts' convention). Tokenization
matches the cached diagnostic features: full = tok(prompt+completion), pool
positions >= len(tok(prompt)). Caller owns model.eval() and quarantine state."""
names = sorted(wrappers)
acts: dict[str, list[Tensor]] = {"hack": [], "clean": []}
with ActCapture(wrappers, names) as cap:
for pair in pairs:
n_prompt = tok(pair.prompt, return_tensors="pt").input_ids.shape[1]
for side, completion in (("hack", pair.hack), ("clean", pair.clean)):
full_ids = tok(pair.prompt + completion,
return_tensors="pt").input_ids.to(device)
L_c = full_ids.shape[1] - n_prompt
assert L_c > 0, f"pair {pair.problem_id} {side}: no completion tokens"
cap.set_pool(n_prompt, torch.ones(1, L_c, device=device))
# logits_to_keep=1: the decoder Linears (whose inputs we hook) run on
# every position regardless; skipping the full-seq lm_head is free.
model(full_ids, logits_to_keep=1)
acts[side].append(cap.pooled()[0].cpu())
H = torch.stack(acts["hack"]) # [P, M, r]
C = torch.stack(acts["clean"])
D = H - C
d = D.mean(0)
if tstat:
se = D.std(0) / D.shape[0] ** 0.5
d = (d / se.clamp_min(1e-12)).clamp(-3.0, 3.0)
v = d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12) # unit per module
return v, {"hack": H, "clean": C}
def haar_unit_rows(shape: tuple[int, int], seed: int) -> Float[Tensor, "M r"]:
"""Random unit rows matching v_act's shape -- the placebo directionality control."""
g = torch.Generator().manual_seed(seed)
d = torch.randn(shape, generator=g)
return d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12)