mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
jaxtyping: shape contracts for v_hack save/load/apply/project paths
The five functions through which v_hack flows (save/load/apply/project
paths) now carry shape annotations checked at runtime under BEARTYPE=1:
- proj._project_one_module(g: [r], V: [k, r]) -> (g_proj: [r], ...).
New typed helper, called from project_delta_S_grad's per-module loop.
Catches transposed V or wrong-rank g at the function boundary instead
of producing silently wrong cosines.
- proj.mean_cin_from_grads(grad_dict, v_hack) typed to dicts of [r] and [k, r].
- proj.project_delta_S_grad(v_hack: dict[str, Float[Tensor, "k r"]], ...).
- train.load_v_hack(...) -> dict[str, Float[Tensor, "k r"]].
- extract_vhack_grad.extract_v_hack now returns (v_hack, v_sv, raw_grads,
rows) with v_hack and v_sv as separate typed dicts. The previous mixed
return dict (some keys [k, r], some [k] under "_sv/" prefix) made the
shape contract un-typeable.
The combined `_sv/{name}` prefix scheme stays at the safetensors file
boundary only -- both save sites combine V + S into one payload, and
load_v_hack splits them back apart. In memory, V and S are always
separate.
Module docstring in proj.py now states the shape conventions (r, k, V, g, c).
This commit is contained in:
@@ -31,6 +31,7 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
from jaxtyping import Float
|
||||
from loguru import logger
|
||||
from safetensors.torch import save_file
|
||||
from tabulate import tabulate
|
||||
@@ -89,7 +90,12 @@ def extract_v_hack(
|
||||
tau_axis: float,
|
||||
n_heldout: int,
|
||||
device,
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], list[dict]]:
|
||||
) -> tuple[
|
||||
dict[str, Float[torch.Tensor, "k r"]],
|
||||
dict[str, Float[torch.Tensor, "k"]],
|
||||
dict[str, Float[torch.Tensor, "n_pairs r"]],
|
||||
list[dict],
|
||||
]:
|
||||
"""Run pair-grads + per-module SVD on D = g_hack - g_clean, return v_hack.
|
||||
|
||||
Pure function — caller owns model loading, wrapping, and saving. train.py
|
||||
@@ -100,6 +106,9 @@ def extract_v_hack(
|
||||
v_hack: dict[name -> Tensor[k, r]] (cpu fp32), top-k right singular
|
||||
vectors of D per module, oriented so mean(D @ v_i) > 0. If
|
||||
tau_axis > 0, rows where S_i/S_0 < tau_axis are zeroed.
|
||||
v_sv: dict[name -> Tensor[k]] (cpu fp32), singular values matching v_hack.
|
||||
Saved alongside V under `_sv/{name}` keys so load-time noise-floor
|
||||
filtering works without re-extracting.
|
||||
raw_grads: dict["hack/name"|"clean/name" -> Tensor[n_pairs, r]] for
|
||||
offline analysis (verify_vhack_heldout reads this).
|
||||
diag_rows: per-module diagnostic dicts (sv_top frac, ||D||, etc.).
|
||||
@@ -138,6 +147,7 @@ def extract_v_hack(
|
||||
# SVD(D) gives orthonormal right singular vectors = principal axes of variation
|
||||
# of the hack-clean axis. Top-k generalizes mean-diff (which is the k=1 case).
|
||||
v_hack: dict[str, torch.Tensor] = {}
|
||||
v_sv: dict[str, torch.Tensor] = {}
|
||||
rows = []
|
||||
n_zero = 0
|
||||
k = min(top_k, n_pairs)
|
||||
@@ -172,9 +182,10 @@ def extract_v_hack(
|
||||
v_hack[name] = torch.zeros((k, D.shape[1]), dtype=V.dtype).contiguous()
|
||||
else:
|
||||
v_hack[name] = V.contiguous()
|
||||
# Record singular values so future weighted projection (subtract scaled
|
||||
# by S_i/S_0) is possible without re-extracting. Loader filters _sv/ keys.
|
||||
v_hack[f"_sv/{name}"] = S_d[:k].clone().contiguous()
|
||||
# Record singular values so the load-time noise-floor filter has the
|
||||
# extraction-time S_i per axis without re-extracting. Saved under
|
||||
# `_sv/{name}` keys in the safetensors file (combined at save site).
|
||||
v_sv[name] = S_d[:k].clone().contiguous()
|
||||
sv_top = S_d[:k]
|
||||
sv_total = S_d.sum().clamp_min(1e-12)
|
||||
rows.append({
|
||||
@@ -191,7 +202,7 @@ def extract_v_hack(
|
||||
f"v_hack: modules={n_modules} k_max={k} zero-||D||={n_zero} "
|
||||
f"axes_kept_avg={n_axes_kept_total/max(1,n_modules):.1f} (tau_axis={tau_axis})"
|
||||
)
|
||||
return v_hack, raw_grads, rows
|
||||
return v_hack, v_sv, raw_grads, rows
|
||||
|
||||
|
||||
def main(cfg: Config) -> int:
|
||||
@@ -217,19 +228,21 @@ def main(cfg: Config) -> int:
|
||||
train_pairs = PAIRS[:-cfg.n_heldout]
|
||||
logger.info(f"train pairs: {len(train_pairs)} held: {cfg.n_heldout}")
|
||||
|
||||
v_hack, raw_grads, rows = extract_v_hack(
|
||||
v_hack, v_sv, raw_grads, rows = extract_v_hack(
|
||||
model, tokenizer, wrappers, PAIRS,
|
||||
top_k=cfg.top_k, tau_axis=cfg.tau_axis,
|
||||
n_heldout=cfg.n_heldout, device=device,
|
||||
)
|
||||
# Skip _sv/ keys when counting V tensors (singular values are saved alongside V).
|
||||
n_zero = sum(1 for n, v in v_hack.items() if not n.startswith("_sv/") and v.norm() < 1e-12)
|
||||
n_zero = sum(1 for v in v_hack.values() if v.norm() < 1e-12)
|
||||
k = min(cfg.top_k, len(train_pairs))
|
||||
|
||||
OUT_DIR.mkdir(exist_ok=True)
|
||||
save_file(raw_grads, str(cfg.train_grads_path),
|
||||
metadata={"model": cfg.model, "dtype": cfg.dtype})
|
||||
save_file(v_hack, str(cfg.out_path),
|
||||
# v_hack file layout: bare `{name}` keys hold V[k, r]; `_sv/{name}` keys
|
||||
# hold S[k]. Loader at train.py:load_v_hack splits them back apart.
|
||||
save_payload = {**v_hack, **{f"_sv/{n}": s for n, s in v_sv.items()}}
|
||||
save_file(save_payload, str(cfg.out_path),
|
||||
metadata={"model": cfg.model, "dtype": cfg.dtype, "top_k": str(k),
|
||||
"tau_axis": str(cfg.tau_axis), "schema": "v2_with_sv"})
|
||||
|
||||
|
||||
+53
-27
@@ -1,7 +1,17 @@
|
||||
"""Gradient projection + delta_S grad utilities. Imported by smoke and train."""
|
||||
"""Gradient projection + delta_S grad utilities. Imported by smoke and train.
|
||||
|
||||
Shape conventions in the v_hack / delta_S plumbing (jaxtyping-annotated):
|
||||
- `r` = per-module SVD rank (delta_S dimension; varies per Linear)
|
||||
- `k` = number of v_hack directions kept per module (after top-k slice and
|
||||
global noise-floor filter at load time)
|
||||
- `V` is `[k, r]`, rows orthonormal in R^r and oriented hack-ward
|
||||
- `g = delta_S.grad` is `[r]`
|
||||
- `c = V @ g` is `[k]`
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
|
||||
|
||||
def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
|
||||
@@ -22,8 +32,8 @@ def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@torch.no_grad()
|
||||
def mean_cin_from_grads(
|
||||
grad_dict: dict[str, torch.Tensor],
|
||||
v_hack: dict[str, torch.Tensor],
|
||||
grad_dict: dict[str, Float[torch.Tensor, "r"]],
|
||||
v_hack: dict[str, Float[torch.Tensor, "k r"]],
|
||||
) -> float:
|
||||
"""Mean over modules of ||V g|| / ||g||, given a dict of per-module grads.
|
||||
|
||||
@@ -44,10 +54,46 @@ def mean_cin_from_grads(
|
||||
return float(sum(cs) / len(cs)) if cs else float("nan")
|
||||
|
||||
|
||||
def _project_one_module(
    g: Float[torch.Tensor, "r"],
    V: Float[torch.Tensor, "k r"],
    gate_mode: str,
    preserve_magnitude: bool,
) -> tuple[Float[torch.Tensor, "r"], float, float, bool]:
    """Remove the V-subspace component from one module's gradient.

    Inner helper so the shape contract (g: [r], V: [k, r]) is jaxtyping-checked
    when BEARTYPE=1 -- catches a transposed V or wrong-rank g at the boundary
    instead of producing silently wrong cosines.

    Args:
        g: gradient vector of shape [r].
        V: direction matrix of shape [k, r]. The subtraction `c @ V` is an
            orthogonal projection only if the rows of V are orthonormal;
            this function does not check that.
        gate_mode: "no_gate" subtracts every coefficient of `c = V @ g`;
            "one_sided" subtracts only the strictly positive coefficients.
        preserve_magnitude: if True, rescale the projected gradient back to
            the norm of `g` (skipped when the projected norm is <= 1e-12).

    Returns:
        (g_proj, cos_in, cos_out, fired) where
        - g_proj is the projected gradient, or `g` itself (same object) when
          nothing was removed;
        - cos_in is ||V g|| / ||g|| before projection;
        - cos_out is the same ratio after projection (equal to cos_in when
          the gate did not fire);
        - fired says whether any component was actually subtracted.
        For a gradient with norm < 1e-12 the result is (g, 0.0, 0.0, False).

    Raises:
        ValueError: if gate_mode is not "no_gate" or "one_sided". This is
            checked before the zero-gradient shortcut so a mistyped mode is
            reported even when the gradient happens to be zero.
    """
    if gate_mode not in ("no_gate", "one_sided"):
        raise ValueError(f"unknown gate_mode={gate_mode!r}")

    gn = g.norm()
    if gn < 1e-12:
        # Degenerate gradient: the ratios are undefined, report zeros.
        return g, 0.0, 0.0, False

    c = V @ g  # [k] coefficients of g along each row of V
    cin = (c.norm() / gn).item()

    if gate_mode == "no_gate":
        c_use = c
        fired = True
    else:  # "one_sided": keep only coefficients with c_i > 0
        c_use = c * (c > 0).to(c.dtype)
        fired = bool((c_use != 0).any())

    if not fired:
        # Nothing to subtract, so the gradient and its ratio are unchanged.
        return g, cin, cin, False

    g_proj = g - c_use @ V  # [r]
    gp_n = g_proj.norm()
    if preserve_magnitude and gp_n > 1e-12:
        g_proj = g_proj * (gn / gp_n)
    # clamp_min guards the division when g lay (almost) entirely inside span(V).
    cout = ((V @ g_proj).norm() / g_proj.norm().clamp_min(1e-12)).item()
    return g_proj, cin, cout, True
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def project_delta_S_grad(
|
||||
wrappers: dict,
|
||||
v_hack: dict[str, torch.Tensor],
|
||||
v_hack: dict[str, Float[torch.Tensor, "k r"]],
|
||||
preserve_magnitude: bool,
|
||||
measure_only: bool = False,
|
||||
gate_mode: str = "one_sided",
|
||||
@@ -84,33 +130,13 @@ def project_delta_S_grad(
|
||||
if name not in v_hack: # module dropped by global noise-floor filter
|
||||
continue
|
||||
V = v_hack[name].to(g.device, dtype=g.dtype) # [k, r]
|
||||
gn = g.norm()
|
||||
if gn < 1e-12:
|
||||
cos_in_list.append(0.0); cos_out_list.append(0.0); continue
|
||||
c = V @ g # [k]
|
||||
cin = c.norm() / gn
|
||||
cos_in_list.append(cin.item())
|
||||
if gate_mode == "no_gate":
|
||||
c_use = c
|
||||
fired = True
|
||||
elif gate_mode == "one_sided":
|
||||
mask = (c > 0).to(c.dtype)
|
||||
c_use = c * mask
|
||||
fired = bool((c_use != 0).any())
|
||||
else:
|
||||
raise ValueError(f"unknown gate_mode={gate_mode!r}")
|
||||
g_proj, cin, cout, fired = _project_one_module(g, V, gate_mode, preserve_magnitude)
|
||||
cos_in_list.append(cin)
|
||||
cos_out_list.append(cout)
|
||||
if fired:
|
||||
g_proj = g - c_use @ V # [r]
|
||||
gp_n = g_proj.norm()
|
||||
if preserve_magnitude and gp_n > 1e-12:
|
||||
g_proj = g_proj * (gn / gp_n)
|
||||
cout = (V @ g_proj).norm() / g_proj.norm().clamp_min(1e-12)
|
||||
cos_out_list.append(cout.item())
|
||||
if not measure_only:
|
||||
info["delta_S"].grad = g_proj
|
||||
n_fired += 1
|
||||
else:
|
||||
cos_out_list.append(cin.item())
|
||||
cin_t = torch.tensor(cos_in_list); cout_t = torch.tensor(cos_out_list)
|
||||
return {
|
||||
"mean_cos_in": cin_t.mean().item(),
|
||||
|
||||
+25
-20
@@ -70,6 +70,7 @@ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tyro
|
||||
from jaxtyping import Float
|
||||
from loguru import logger
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
@@ -235,7 +236,7 @@ def load_problems(n: int) -> list[dict]:
|
||||
def load_v_hack(
|
||||
path: Path, model_name: str, wrappers: dict,
|
||||
k_use: int | None = None, drop_bottom_frac: float = 0.0,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
) -> dict[str, Float[torch.Tensor, "k r"]]:
|
||||
"""Load v_hack (top-k directions) for this wrapped model.
|
||||
|
||||
File schema (v2): bare `{name}` keys hold V[k_max, r]; `_sv/{name}` keys hold
|
||||
@@ -417,13 +418,16 @@ def main(cfg: Config) -> int:
|
||||
from .pairs import PAIRS as VHACK_PAIRS
|
||||
logger.info(f"v_hack cache miss at {v_hack_path}; extracting (~5min)...")
|
||||
model.eval() # match standalone extract: deterministic backward, no dropout
|
||||
v_hack_cpu_dict, raw_grads, _diag = extract_v_hack(
|
||||
v_hack_extracted, v_sv_extracted, _raw_grads, _diag = extract_v_hack(
|
||||
model, tok, wrappers, VHACK_PAIRS,
|
||||
top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis,
|
||||
n_heldout=2, device=device,
|
||||
)
|
||||
OUT_DIR.mkdir(exist_ok=True)
|
||||
save_file(v_hack_cpu_dict, str(v_hack_path),
|
||||
# Combine V and S under one safetensors file with `_sv/{name}` prefix
|
||||
# for the singular values. load_v_hack splits them back apart.
|
||||
save_payload = {**v_hack_extracted, **{f"_sv/{n}": s for n, s in v_sv_extracted.items()}}
|
||||
save_file(save_payload, str(v_hack_path),
|
||||
metadata={"model": model_name, "dtype": "bf16",
|
||||
"top_k": str(min(cfg.v_hack_extract_top_k, len(VHACK_PAIRS) - 2)),
|
||||
"tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv"})
|
||||
@@ -598,23 +602,24 @@ def main(cfg: Config) -> int:
|
||||
if cfg.arm == "vanilla"
|
||||
else "cout=subspace energy fraction in grad after projection"
|
||||
)
|
||||
caption = (
|
||||
"table columns: "
|
||||
"step=GRPO step; "
|
||||
"ref_eq=vanilla-equivalent step (cum_gens / 256); "
|
||||
"rew=mean combined reward; rew_s=student mean reward; "
|
||||
"sprd=reward spread>0 (T/F; F means zero-variance bail fired and step was skipped); "
|
||||
"N=rollouts; "
|
||||
"gt_s/gt_t=ground-truth passes (student/teacher); "
|
||||
"hack_s/hack_t=hack-flagged rollouts (student/teacher); "
|
||||
"lp_s/lp_t=mean per-token student/teacher gen_logp under current student (diagnostic, no IS correction); "
|
||||
"loss=mean GRPO loss; "
|
||||
"cin=v_hack subspace energy fraction in grad before projection; "
|
||||
"cin_s/cin_t=cin on student-only/teacher-only gradient; "
|
||||
f"{cout_def}; "
|
||||
"fired=fraction of modules where projection fired; "
|
||||
"gen/fb/t_rew=generation/forward+backward/reward-grading wall-time (s); sec=total step wall-time (s)."
|
||||
)
|
||||
caption = """
|
||||
table columns:
|
||||
- step= GRPO step;
|
||||
- ref_eq= vanilla-equivalent step (cum_gens / 256);
|
||||
- rew= mean combined reward; rew_s=student mean reward;
|
||||
- sprd= reward spread>0 (T/F; F means zero-variance bail fired and step was skipped);
|
||||
- N= rollouts;
|
||||
- gt_s/gt_t= ground-truth passes (student/teacher);
|
||||
- hack_s/hack_t=hack-flagged rollouts (student/teacher);
|
||||
- lp_s/lp_t= mean per-token student/teacher gen_logp under current student (diagnostic, no IS correction);
|
||||
- loss= mean GRPO loss;
|
||||
- cin= v_hack subspace energy fraction in grad before projection;
|
||||
- cin_s/cin_t= cin on student-only/teacher-only gradient;
|
||||
- "{cout_def};
|
||||
- fired=fraction of modules where projection fired;
|
||||
- gen/fb/t_rew=generation/forward+backward/reward-grading wall-time (s); sec=total step wall-time (s)
|
||||
|
||||
"""
|
||||
logger.info(caption + "\n\n" + _fmt_header())
|
||||
|
||||
OUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
Reference in New Issue
Block a user