mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:59:35 +08:00
feat(routeA): forward-only v_act extraction + verify gate vs cached diag features
extract_v_act: pooled completion-token bottleneck act per module, v = unit-norm mean pair diff (tstat flag default off, null at n=8 pairs). ActCapture is the single hook shared by extraction, the live gate, and verification. verify_v_act (pueue #24): rel diff 7.3e-4 hack / 7.7e-4 clean vs out/diag/pinning_feats.pt on the v3 first_hack ckpt; min per-module cos 0.99997. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,89 @@
|
||||
"""Verify extract_v_act reproduces the cached diagnostic pair features.
|
||||
|
||||
The journal-(d)/(e) AUROC evidence for the act gate was measured by
|
||||
scripts/diag_pinning.py (ActTap) on cached features. The training extractor
|
||||
(src/vgrout/extract_vhack_act.py) must produce the SAME pooled acts on the same
|
||||
checkpoint + pairs, else the routeA gate is not the thing the diagnostics
|
||||
validated. GPU required (bf16 + flash_attention_2 to match the cache); not part
|
||||
of `just smoke` -- run once after any extractor change:
|
||||
|
||||
uv run python scripts/verify_v_act.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from vgrout.extract_vhack_act import extract_v_act
|
||||
from vgrout.lora2r import wrap_model_with_lora2r
|
||||
from vgrout.pairs import load_pairs
|
||||
|
||||
|
||||
@dataclass
|
||||
class Cfg:
|
||||
feats: Path = Path("out/diag/pinning_feats.pt")
|
||||
run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
|
||||
ckpt: str = "first_hack"
|
||||
pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
|
||||
rel_tol: float = 1e-2 # bf16 forward, fp32 pooling; expect ~0
|
||||
|
||||
|
||||
def main(cfg: Cfg) -> int:
|
||||
import json
|
||||
import struct
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
fe = torch.load(cfg.feats, weights_only=False)
|
||||
ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
|
||||
with open(ckpt_path, "rb") as f:
|
||||
meta = json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
|
||||
run_cfg = json.loads(meta.get("cfg", "{}"))
|
||||
model_name, r, init_seed = run_cfg["model"], run_cfg["lora_r"], run_cfg["lora_init_seed"]
|
||||
|
||||
device = torch.device("cuda")
|
||||
tok = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
|
||||
model.config.use_cache = False
|
||||
wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed)
|
||||
assert sorted(wrappers) == fe["names"], "module set/order drifted vs the cache"
|
||||
sd = load_file(str(ckpt_path))
|
||||
for nm, info in wrappers.items():
|
||||
info["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
|
||||
info["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
|
||||
model.eval()
|
||||
|
||||
pairs = load_pairs(cfg.pairs)
|
||||
assert [p.problem_id for p in pairs] == fe["pair_ids"], "pairset drifted vs the cache"
|
||||
logger.info(f"extracting acts for {len(pairs)} pairs on {ckpt_path.name} ({model_name})")
|
||||
v, acts = extract_v_act(model, tok, wrappers, pairs, device)
|
||||
|
||||
print("SHOULD: rel diff ~0 for both sides (same ckpt/pairs/dtype/kernels) ELSE the "
|
||||
"training extractor is not what diag_pinning measured.")
|
||||
ok = True
|
||||
for side in ("hack", "clean"):
|
||||
ref = fe["pair_feats"][("act", side)].float()
|
||||
got = acts[side]
|
||||
rel = (got - ref).norm() / ref.norm()
|
||||
mx = (got - ref).abs().max()
|
||||
print(f" {side:>5}: rel diff={rel:.2e} max abs diff={mx:.2e} shape={tuple(got.shape)}")
|
||||
ok &= rel.item() < cfg.rel_tol
|
||||
# v_act from the cached feats (diag's _v_from over all pairs) must match too.
|
||||
D = (fe["pair_feats"][("act", "hack")] - fe["pair_feats"][("act", "clean")]).float().mean(0)
|
||||
v_ref = D / D.norm(dim=-1, keepdim=True).clamp_min(1e-12)
|
||||
cos = (v * v_ref).sum(-1)
|
||||
print(f" v_act vs cache-derived v: min per-module cos={cos.min():.6f} "
|
||||
f"(SHOULD ~1.0); unit rows max|1-||v|||={(v.norm(dim=-1) - 1).abs().max():.1e}")
|
||||
ok &= cos.min().item() > 0.999
|
||||
print(("🟢 PASS" if ok else "🔴 FAIL") + " verify_v_act")
|
||||
return 0 if ok else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main(tyro.cli(Cfg)))
|
||||
@@ -0,0 +1,109 @@
|
||||
"""Activation-side per-module v_act extraction (the routeA direction source).
|
||||
|
||||
For each authored (hack, clean) pair we forward prompt+completion once per side,
|
||||
no backward, and capture the deployed bottleneck activation h = A[:r] @ x per
|
||||
wrapped module, mean-pooled over completion tokens. The per-module direction is
|
||||
the unit-normalized mean paired difference:
|
||||
|
||||
v_act[m] = unitnorm( mean_pairs( h_hack[m] - h_clean[m] ) ) [r]
|
||||
|
||||
Diagnostic basis: RESEARCH_JOURNAL 2026-06-11 (d) -- on the A>0 contrast this
|
||||
score replicates across three emergence windows (AUROC 0.87/0.75/0.75) while the
|
||||
gradient score decays to chance; see docs/spec/20260611_act_gate_spec.md.
|
||||
`tstat=True` divides the mean by its standard error over pairs (clamped |t|<=3)
|
||||
before normalizing; null at the current 8 pairs, kept for larger pair sets.
|
||||
|
||||
ActCapture is the one hook used everywhere acts are read: pair extraction here,
|
||||
the live gate in train.py, and scripts/verify_v_act.py. It pools inside the hook
|
||||
(completion-token mean) so nothing of size [B, L, d_in] is retained.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from jaxtyping import Float
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ActCapture:
|
||||
"""Forward hooks stashing the pooled deployed bottleneck h = A[:r] @ x per module.
|
||||
|
||||
Call set_pool(plen, comp_mask) before each forward: positions < plen are prompt,
|
||||
comp_mask [B, L-plen] marks real (non-pad) completion tokens. pooled() then
|
||||
returns the completion-token mean act [B, M, r] in module order `names`."""
|
||||
|
||||
def __init__(self, wrappers: dict, names: list[str]):
|
||||
self.wrappers, self.names = wrappers, names
|
||||
self.h: dict[str, Tensor] = {}
|
||||
self.handles = []
|
||||
self.plen: int | None = None
|
||||
self.mask: Tensor | None = None # [B, L_c, 1] float
|
||||
self.denom: Tensor | None = None # [B, 1]
|
||||
|
||||
def __enter__(self):
|
||||
for nm in self.names:
|
||||
layer = self.wrappers[nm]["layer"]
|
||||
|
||||
def hook(layer, args, out, nm=nm):
|
||||
(x,) = args # [B, L, d_in]
|
||||
with torch.no_grad():
|
||||
hc = F.linear(x[:, self.plen:].detach(),
|
||||
layer._lora2r_A[: layer._lora2r_r].to(x.dtype))
|
||||
self.h[nm] = (hc.float() * self.mask).sum(1) / self.denom # [B, r]
|
||||
|
||||
self.handles.append(layer.register_forward_hook(hook))
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
for h in self.handles:
|
||||
h.remove()
|
||||
|
||||
def set_pool(self, plen: int, comp_mask: Float[Tensor, "B L_c"]) -> None:
|
||||
denom = comp_mask.sum(1)
|
||||
assert (denom > 0).all(), f"rollout with 0 completion tokens: {denom.tolist()}"
|
||||
self.plen, self.mask, self.denom = plen, comp_mask.unsqueeze(-1), denom.unsqueeze(-1)
|
||||
|
||||
def pooled(self) -> Float[Tensor, "B M r"]:
|
||||
return torch.stack([self.h[nm] for nm in self.names], dim=1)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_v_act(
|
||||
model, tok, wrappers: dict, pairs: list, device, tstat: bool = False,
|
||||
) -> tuple[Float[Tensor, "M r"], dict[str, Float[Tensor, "P M r"]]]:
|
||||
"""Forward each pair side once -> (v_act unit rows, raw per-pair pooled acts).
|
||||
|
||||
Module order is sorted(wrappers) (= the diag scripts' convention). Tokenization
|
||||
matches the cached diagnostic features: full = tok(prompt+completion), pool
|
||||
positions >= len(tok(prompt)). Caller owns model.eval() and quarantine state."""
|
||||
names = sorted(wrappers)
|
||||
acts: dict[str, list[Tensor]] = {"hack": [], "clean": []}
|
||||
with ActCapture(wrappers, names) as cap:
|
||||
for pair in pairs:
|
||||
n_prompt = tok(pair.prompt, return_tensors="pt").input_ids.shape[1]
|
||||
for side, completion in (("hack", pair.hack), ("clean", pair.clean)):
|
||||
full_ids = tok(pair.prompt + completion,
|
||||
return_tensors="pt").input_ids.to(device)
|
||||
L_c = full_ids.shape[1] - n_prompt
|
||||
assert L_c > 0, f"pair {pair.problem_id} {side}: no completion tokens"
|
||||
cap.set_pool(n_prompt, torch.ones(1, L_c, device=device))
|
||||
# logits_to_keep=1: the decoder Linears (whose inputs we hook) run on
|
||||
# every position regardless; skipping the full-seq lm_head is free.
|
||||
model(full_ids, logits_to_keep=1)
|
||||
acts[side].append(cap.pooled()[0].cpu())
|
||||
H = torch.stack(acts["hack"]) # [P, M, r]
|
||||
C = torch.stack(acts["clean"])
|
||||
D = H - C
|
||||
d = D.mean(0)
|
||||
if tstat:
|
||||
se = D.std(0) / D.shape[0] ** 0.5
|
||||
d = (d / se.clamp_min(1e-12)).clamp(-3.0, 3.0)
|
||||
v = d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12) # unit per module
|
||||
return v, {"hack": H, "clean": C}
|
||||
|
||||
|
||||
def haar_unit_rows(shape: tuple[int, int], seed: int) -> Float[Tensor, "M r"]:
|
||||
"""Random unit rows matching v_act's shape -- the placebo directionality control."""
|
||||
g = torch.Generator().manual_seed(seed)
|
||||
d = torch.randn(shape, generator=g)
|
||||
return d / d.norm(dim=-1, keepdim=True).clamp_min(1e-12)
|
||||
Reference in New Issue
Block a user