Files
evil_MoE/src/projected_grpo/extract_vhack.py
T
2026-05-23 11:26:39 +08:00

83 lines
2.6 KiB
Python

"""Extract v_hack from contrastive pairs of hidden states.
Per Wu-Tang (2026, arXiv 2604.01476) §3.1:
d = (1/N) * sum_i (h_i^+ - h_i^-)
where h^+ are last-token hidden states from hack-flavored prompts and h^- from
clean ones, taken at intermediate-to-late layers (60-75% of model depth).
Validation: held-out separation accuracy > 90%.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from jaxtyping import Float
from loguru import logger
from torch import Tensor
@dataclass
class VHackResult:
    """Result bundle returned by ``extract_vhack``.

    Holds the extracted direction together with the validation score and the
    bookkeeping needed to reproduce it. Field order defines the positional
    constructor, so keep it stable.
    """

    # Unit-normed hack direction.
    v_hack: Float[Tensor, "d"]
    # Held-out hack-vs-clean separation accuracy.
    val_accuracy: float
    # Layer index the direction was extracted at (as passed to extract_vhack).
    layer_idx: int
    # Number of hack training rows used for the mean difference.
    n_train: int
    # Number of hack validation rows scored.
    n_val: int
@torch.no_grad()
def collect_last_token_hidden(
    model,
    tokenizer,
    prompts: list[str],
    layer_idx: int,
    device: str = "cuda",
) -> Float[Tensor, "n d"]:
    """Forward each prompt and return its last-token hidden state at ``layer_idx``.

    Prompts are run one at a time (batch size 1), so no padding is involved and
    position ``-1`` is the final token of each tokenized prompt.

    Args:
        model: Callable accepting the tokenizer output as keyword arguments plus
            ``output_hidden_states=True``, returning an object with a
            ``hidden_states`` sequence (presumably a HuggingFace model — verify).
            NOTE(review): the model's train/eval mode is not changed here; call
            ``model.eval()`` beforehand if dropout could perturb the states.
        tokenizer: Callable invoked as ``tokenizer(prompt, return_tensors="pt")``
            whose result supports ``.to(device)`` and ``**`` unpacking.
        prompts: Non-empty list of prompt strings.
        layer_idx: Index into ``out.hidden_states`` (negative indices count from
            the end).
        device: Device the tokenized inputs are moved to; should match the model.

    Returns:
        A ``(len(prompts), d)`` float32 CPU tensor; row ``i`` belongs to
        ``prompts[i]``.

    Raises:
        ValueError: if ``prompts`` is empty.
    """
    if not prompts:
        # torch.stack([]) would fail with an opaque RuntimeError; fail clearly.
        raise ValueError("collect_last_token_hidden: `prompts` must be non-empty")
    hs = []
    for p in prompts:
        ids = tokenizer(p, return_tensors="pt").to(device)
        out = model(**ids, output_hidden_states=True)
        # out.hidden_states: one tensor per layer (for HF, presumably n_layers+1,
        # embeddings first), each indexed here as (batch=1, seq, d).
        # copy=True guarantees an independent fp32 CPU copy: a bare .float().cpu()
        # is a no-op for an fp32 CPU model, and the resulting view would keep the
        # whole (1, seq, d) layer tensor alive until the final stack.
        h = out.hidden_states[layer_idx][0, -1, :].to(
            device="cpu", dtype=torch.float32, copy=True
        )
        hs.append(h)
        # Release all-layer hidden states before the next forward to cap peak memory.
        del out
    return torch.stack(hs, dim=0)
def extract_vhack(
    h_hack_train: Float[Tensor, "n_train d"],
    h_clean_train: Float[Tensor, "n_train d"],
    h_hack_val: Float[Tensor, "n_val d"],
    h_clean_val: Float[Tensor, "n_val d"],
    layer_idx: int,
) -> VHackResult:
    """Mean-difference direction with held-out paired validation.

    ``v_hack`` is ``mean(h_hack_train) - mean(h_clean_train)``, unit-normalized
    (the 1e-12 epsilon leaves an all-zero difference as a zero vector).
    Validation pairs ``h_hack_val[i]`` with ``h_clean_val[i]`` and reports the
    fraction of pairs where the hack row projects strictly higher onto
    ``v_hack`` (ties count as incorrect). Empty validation sets yield a NaN
    accuracy.

    Args:
        h_hack_train: Hidden states of hack-flavored training prompts.
        h_clean_train: Hidden states of clean training prompts (row count may
            differ from ``h_hack_train``; only the means are used).
        h_hack_val: Hidden states of hack-flavored validation prompts.
        h_clean_val: Hidden states of clean validation prompts, paired row-wise
            with ``h_hack_val``.
        layer_idx: Layer the states were taken from; recorded and logged only.

    Returns:
        A ``VHackResult`` with the direction, validation accuracy and counts.

    Raises:
        ValueError: if either training set is empty (the mean would be NaN), or
            if the validation sets differ in shape (paired scoring would
            otherwise silently broadcast, e.g. 5 hack rows against 1 clean row).
    """
    if len(h_hack_train) == 0 or len(h_clean_train) == 0:
        raise ValueError(
            "extract_vhack: training sets must be non-empty "
            f"(got n_hack={len(h_hack_train)}, n_clean={len(h_clean_train)})"
        )
    if h_hack_val.shape != h_clean_val.shape:
        raise ValueError(
            "extract_vhack: paired validation needs equally-shaped hack/clean sets "
            f"(got {tuple(h_hack_val.shape)} vs {tuple(h_clean_val.shape)})"
        )
    v = h_hack_train.mean(dim=0) - h_clean_train.mean(dim=0)
    v = v / (v.norm() + 1e-12)
    # Validate: projection score on hack should exceed clean.
    s_hack = h_hack_val @ v
    s_clean = h_clean_val @ v
    # Paired accuracy: for each (hack, clean) pair, hack should score higher.
    correct = (s_hack > s_clean).float().mean().item()
    logger.info(
        f"v_hack extracted layer={layer_idx} n_train={len(h_hack_train)} "
        f"n_val={len(h_hack_val)} val_acc={correct:.3f} "
        f"SHOULD>0.9 on a trained model: v_hack should separate hack from clean. "
        f"On tiny-random/untrained models val_acc~0.5 (no semantic structure yet), "
        f"which is fine for smoke -- the projection mechanism is what we test there."
    )
    return VHackResult(
        v_hack=v,
        val_accuracy=correct,
        layer_idx=layer_idx,
        n_train=len(h_hack_train),
        n_val=len(h_hack_val),
    )