mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 19:15:20 +08:00
55937a86fb
git mv src/projected_grpo -> src/vgrout and find-replace the module name in
all imports (.py), `-m projected_grpo.*` invocations (justfile), and the
[project] name (pyproject; setuptools auto-discovers via where=["src"]).
Left RESEARCH_JOURNAL.md untouched: its commands/paths are dated lab notes
tied to past commits, so rewriting them would falsify provenance. Repo dir,
git remote, and absolute paths unchanged.
Verified: `import vgrout` and `python -m vgrout.train --help` load the full
graph; verify_rewards.py + verify_gate_anchor.py (both import vgrout) pass.
Full `just smoke` is blocked upstream by missing gitignored data artifacts
(out/pools/{substrate,teacher_pool}, out/vhack/*smoke*), unrelated to the rename.
153 lines
5.5 KiB
Python
153 lines
5.5 KiB
Python
"""Held-out v_hack validation (spec.md §B validation).

For each held-out pair, take the per-module gradient of the completion NLL
w.r.t. delta_S for the hack and the clean completion. Average each over the
held-out pairs and form the difference (g_hack - g_clean) in the delta_S basis.
Then measure how much of that unit-normalised difference lands in the trained
subspace v_hack[name] (the rows of V):

    energy = || V · diff / ||diff|| ||

Assuming V has orthonormal rows, this is in [0, 1]. It is a subspace energy
fraction, not a signed cosine.

Report:
- per-suffix mean / median / min / max energy
- fraction of modules with energy > 0 (gate: SHOULD > 0.5). Because energy is
  non-negative, this only counts modules whose gradient diff is non-zero.
- mean energy across modules (target > 0.2). The printed headline metric is
  median energy (SHOULD > 0.30).

Per-module energies are saved to Config.out_path as safetensors (tensor key
"cos"; module names JSON-encoded in the "names" metadata field).
Exit code is 1 if the gate fails, otherwise 0.

Run: uv run python -m vgrout.verify_vhack_heldout
"""
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import json
|
|
|
|
import torch
|
|
import tyro
|
|
from loguru import logger
|
|
from safetensors.torch import save_file
|
|
from tabulate import tabulate
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from vgrout.antipasto import wrap_model_with_antipasto
|
|
from vgrout.extract_vhack_grad import completion_nll, resolve_dtype
|
|
from vgrout.pairs import PAIRS
|
|
from vgrout.extract_vhack_grad import load_v_hack
|
|
|
|
|
|
# Directory passed to wrap_model_with_antipasto as `cache_root` (relative to CWD).
CACHE_ROOT: Path = Path("svd_cache")
# Base directory for the default input/output artifact paths (relative to CWD).
OUT_DIR: Path = Path("out")
|
|
|
|
|
|
@dataclass
class Config:
    """Command-line options for the held-out v_hack check (parsed via tyro)."""

    # Model path / id; also passed to the antipasto wrapper and load_v_hack.
    model: str = "out/baked/qwen3_4b_rh25"
    dtype: str = "bf16"  # must match extract_vhack_grad.py and train.py
    # Trained v_hack subspaces to validate against.
    v_hack_path: Path = OUT_DIR.joinpath("vhack", "v_hack_rh25.safetensors")
    # Destination for the per-module energy results.
    out_path: Path = OUT_DIR.joinpath("vhack_heldout_cos_rh25.safetensors")
    # How many pairs, taken from the end of PAIRS, are treated as held out.
    n_heldout: int = 2
|
|
|
|
|
|
def _select_heldout(pairs, n_heldout: int):
    """Return the last ``n_heldout`` items of the sequence ``pairs``.

    Raises:
        ValueError: unless ``1 <= n_heldout <= len(pairs)``. The check matters
            because ``pairs[-0:]`` is the *whole* sequence, so ``n_heldout=0``
            would otherwise silently evaluate every pair instead of none.
    """
    if not 1 <= n_heldout <= len(pairs):
        raise ValueError(f"n_heldout must be in [1, {len(pairs)}], got {n_heldout}")
    return pairs[len(pairs) - n_heldout:]


def _subspace_energy(V: torch.Tensor, diff: torch.Tensor) -> float:
    """Return ``||V @ (diff / ||diff||)||``, or 0.0 when ``diff`` is ~zero.

    If V's rows are orthonormal (as the caller expects), this is the fraction
    of diff's norm captured by V's row space and lies in [0, 1].
    """
    nrm = diff.norm()
    if nrm < 1e-12:
        return 0.0
    # torch matmul does not promote dtypes or move devices. Align V with diff
    # (float32 on CPU here); this is a no-op when V already matches.
    V = V.detach().to(dtype=diff.dtype, device=diff.device)
    return (V @ (diff / nrm)).norm().item()


def main(cfg: Config) -> int:
    """Check how much of the held-out hack-vs-clean gradient lies in v_hack.

    Loads the model, wraps it with antipasto and loads v_hack. It then collects
    delta_S gradients for the hack and clean completion of each held-out pair
    and reports, per module, the subspace energy of the mean (hack - clean)
    gradient difference. Results are printed and saved to ``cfg.out_path``.

    Returns:
        Process exit code: 1 if the frac>0 gate fails, else 0.

    Raises:
        ValueError: ``cfg.n_heldout`` is outside ``[1, len(PAIRS)]``.
        RuntimeError: a wrapped module's delta_S received no gradient, or no
            v_hack modules were loaded.
        KeyError: a v_hack module has no matching antipasto wrapper.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = resolve_dtype(cfg.dtype)
    logger.info(f"device={device} model={cfg.model} dtype={cfg.dtype}")

    held = _select_heldout(PAIRS, cfg.n_heldout)
    logger.info(f"held-out pairs: {len(held)}")

    tokenizer = AutoTokenizer.from_pretrained(cfg.model)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model, dtype=dtype, attn_implementation="sdpa"
    ).to(device)
    model.eval()  # eval mode only; backward() below still populates .grad
    wrappers = wrap_model_with_antipasto(
        model, model_name=cfg.model, cache_root=CACHE_ROOT, svd_device=device,
    )
    v_hack = load_v_hack(cfg.v_hack_path, cfg.model, wrappers)
    logger.info(f"loaded v_hack: {len(v_hack)} modules")
    if not v_hack:
        raise RuntimeError(f"no v_hack modules loaded from {cfg.v_hack_path}")

    # Per-module delta_S gradients (float32 CPU copies), one entry per held pair.
    grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list)
    grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list)
    for pi, pair in enumerate(held):
        losses: dict[str, float] = {}
        for label, completion, bucket in (
            ("hack", pair.hack, grads_hack),
            ("clean", pair.clean, grads_clean),
        ):
            model.zero_grad(set_to_none=True)
            loss = completion_nll(model, tokenizer, pair.prompt, completion, device)
            loss.backward()
            losses[label] = loss.item()
            for name, info in wrappers.items():
                grad = info["delta_S"].grad
                if grad is None:
                    raise RuntimeError(
                        f"delta_S of {name!r} received no gradient "
                        f"({label} completion, held pair {pi + 1})"
                    )
                bucket[name].append(grad.detach().float().cpu().clone())
        logger.info(
            f"  held pair {pi+1}/{len(held)} "
            f"loss_hack={losses['hack']:.3f} loss_clean={losses['clean']:.3f}"
        )
    # Release gradients still attached to the model; only the CPU copies are used below.
    model.zero_grad(set_to_none=True)

    # Per-module subspace energy of the mean held-out (hack - clean) gradient.
    cos_by_suffix: dict[str, list[float]] = defaultdict(list)
    rows_all: list[tuple[str, float]] = []
    for name, V in v_hack.items():
        if name not in grads_hack:
            raise KeyError(f"v_hack module {name!r} has no matching antipasto wrapper")
        gh = torch.stack(grads_hack[name]).mean(0)
        gc = torch.stack(grads_clean[name]).mean(0)
        energy = _subspace_energy(V, gh - gc)
        cos_by_suffix[name.split(".")[-1]].append(energy)
        rows_all.append((name, energy))
    all_cos = [energy for _, energy in rows_all]

    agg_rows = []
    for suf, vals in sorted(cos_by_suffix.items()):
        t = torch.tensor(vals)
        agg_rows.append({
            "suffix": suf,
            "n": len(vals),
            "mean_energy": f"{t.mean():.3f}",
            "median_energy": f"{t.median():.3f}",
            "min": f"{t.min():.3f}",
            "max": f"{t.max():.3f}",
        })

    t_all = torch.tensor(all_cos)
    mean_energy = t_all.mean().item()
    median_energy = t_all.median().item()
    cue = "🟢" if median_energy > 0.30 else ("🟡" if median_energy > 0.10 else "🔴")

    print(f"\nSHOULD: median_energy > 0.30 (held-out diff lands in trained subspace). "
          f"Prior synthetic-pair run got ~0.01 -- that was the smoking gun.\n")
    print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".3f"))
    print()
    print(f"out: {cfg.out_path}")
    print(f"argv: verify_vhack_heldout --model={cfg.model} --v-hack-path={cfg.v_hack_path}")
    print(f"main metric: median_energy={median_energy:.3f} [modules={len(all_cos)}]")
    print(f"{cue} modules={len(all_cos)} mean={mean_energy:.3f} median={median_energy:.3f}")

    frac_pos = (t_all > 0).float().mean().item()

    # Save for downstream plotting / sanity: energies as one float32 tensor,
    # module names (same order) JSON-encoded in the safetensors metadata.
    out_path = Path(cfg.out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    names = [n for n, _ in rows_all]
    cos_t = torch.tensor(all_cos, dtype=torch.float32)
    save_file(
        {"cos": cos_t},
        str(out_path),
        metadata={"model": cfg.model, "dtype": cfg.dtype, "names": json.dumps(names)},
    )

    # NOTE(review): energy is a norm and never negative, so this gate only fails
    # when over half the modules have an (almost) exactly zero gradient diff.
    # The meaningful signal is the mean/median energy reported above.
    gate_pass = frac_pos > 0.50
    target_pass = mean_energy > 0.20
    if not gate_pass:
        logger.error(f"GATE FAIL: frac>0 = {frac_pos:.3f} <= 0.50")
        return 1
    if not target_pass:
        logger.warning(f"TARGET MISS: mean_cos = {mean_energy:+.3f} <= 0.20 (gate passes but signal weak)")
    else:
        logger.info(f"TARGET PASS: mean_cos = {mean_energy:+.3f} > 0.20")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse CLI flags into a Config, run the check, and use its return value as the exit status.
    cli_cfg = tyro.cli(Config)
    sys.exit(main(cli_cfg))
|