jaxtyping: shape contracts for v_hack save/load/apply/project paths

The four touchpoints where v_hack flows through the codebase now carry
shape annotations checked at runtime under BEARTYPE=1:

- proj._project_one_module(g: [r], V: [k, r]) -> (g_proj: [r], ...).
  New typed helper, called from project_delta_S_grad's per-module loop.
  Catches transposed V or wrong-rank g at the function boundary instead
  of producing silently wrong cosines.
- proj.mean_cin_from_grads(grad_dict, v_hack) typed to dicts of [r] and [k, r].
- proj.project_delta_S_grad(v_hack: dict[str, Float[Tensor, "k r"]], ...).
- train.load_v_hack(...) -> dict[str, Float[Tensor, "k r"]].
- extract_vhack_grad.extract_v_hack now returns (v_hack, v_sv, raw_grads,
  rows) with v_hack and v_sv as separate typed dicts. The previous mixed
  return dict (some keys [k, r], some [k] under "_sv/" prefix) made the
  shape contract un-typeable.

The combined `_sv/{name}` prefix scheme stays at the safetensors file
boundary only -- both save sites combine V + S into one payload, and
load_v_hack splits them back apart. In memory, V and S are always
separate.

Module docstring in proj.py now states the shape conventions (r, k, V, g, c).
This commit is contained in:
wassname
2026-05-27 23:20:38 +00:00
parent 3fb8202138
commit 577f075611
3 changed files with 100 additions and 56 deletions
+22 -9
View File
@@ -31,6 +31,7 @@ from pathlib import Path
import torch
import tyro
from jaxtyping import Float
from loguru import logger
from safetensors.torch import save_file
from tabulate import tabulate
@@ -89,7 +90,12 @@ def extract_v_hack(
tau_axis: float,
n_heldout: int,
device,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], list[dict]]:
) -> tuple[
dict[str, Float[torch.Tensor, "k r"]],
dict[str, Float[torch.Tensor, "k"]],
dict[str, Float[torch.Tensor, "n_pairs r"]],
list[dict],
]:
"""Run pair-grads + per-module SVD on D = g_hack - g_clean, return v_hack.
Pure function — caller owns model loading, wrapping, and saving. train.py
@@ -100,6 +106,9 @@ def extract_v_hack(
v_hack: dict[name -> Tensor[k, r]] (cpu fp32), top-k right singular
vectors of D per module, oriented so mean(D @ v_i) > 0. If
tau_axis > 0, rows where S_i/S_0 < tau_axis are zeroed.
v_sv: dict[name -> Tensor[k]] (cpu fp32), singular values matching v_hack.
Saved alongside V under `_sv/{name}` keys so load-time noise-floor
filtering works without re-extracting.
raw_grads: dict["hack/name"|"clean/name" -> Tensor[n_pairs, r]] for
offline analysis (verify_vhack_heldout reads this).
diag_rows: per-module diagnostic dicts (sv_top frac, ||D||, etc.).
@@ -138,6 +147,7 @@ def extract_v_hack(
# SVD(D) gives orthonormal right singular vectors = principal axes of variation
# of the hack-clean axis. Top-k generalizes mean-diff (which is the k=1 case).
v_hack: dict[str, torch.Tensor] = {}
v_sv: dict[str, torch.Tensor] = {}
rows = []
n_zero = 0
k = min(top_k, n_pairs)
@@ -172,9 +182,10 @@ def extract_v_hack(
v_hack[name] = torch.zeros((k, D.shape[1]), dtype=V.dtype).contiguous()
else:
v_hack[name] = V.contiguous()
# Record singular values so future weighted projection (subtract scaled
# by S_i/S_0) is possible without re-extracting. Loader filters _sv/ keys.
v_hack[f"_sv/{name}"] = S_d[:k].clone().contiguous()
# Record singular values so the load-time noise-floor filter has the
# extraction-time S_i per axis without re-extracting. Saved under
# `_sv/{name}` keys in the safetensors file (combined at save site).
v_sv[name] = S_d[:k].clone().contiguous()
sv_top = S_d[:k]
sv_total = S_d.sum().clamp_min(1e-12)
rows.append({
@@ -191,7 +202,7 @@ def extract_v_hack(
f"v_hack: modules={n_modules} k_max={k} zero-||D||={n_zero} "
f"axes_kept_avg={n_axes_kept_total/max(1,n_modules):.1f} (tau_axis={tau_axis})"
)
return v_hack, raw_grads, rows
return v_hack, v_sv, raw_grads, rows
def main(cfg: Config) -> int:
@@ -217,19 +228,21 @@ def main(cfg: Config) -> int:
train_pairs = PAIRS[:-cfg.n_heldout]
logger.info(f"train pairs: {len(train_pairs)} held: {cfg.n_heldout}")
v_hack, raw_grads, rows = extract_v_hack(
v_hack, v_sv, raw_grads, rows = extract_v_hack(
model, tokenizer, wrappers, PAIRS,
top_k=cfg.top_k, tau_axis=cfg.tau_axis,
n_heldout=cfg.n_heldout, device=device,
)
# Skip _sv/ keys when counting V tensors (singular values are saved alongside V).
n_zero = sum(1 for n, v in v_hack.items() if not n.startswith("_sv/") and v.norm() < 1e-12)
n_zero = sum(1 for v in v_hack.values() if v.norm() < 1e-12)
k = min(cfg.top_k, len(train_pairs))
OUT_DIR.mkdir(exist_ok=True)
save_file(raw_grads, str(cfg.train_grads_path),
metadata={"model": cfg.model, "dtype": cfg.dtype})
save_file(v_hack, str(cfg.out_path),
# v_hack file layout: bare `{name}` keys hold V[k, r]; `_sv/{name}` keys
# hold S[k]. Loader at train.py:load_v_hack splits them back apart.
save_payload = {**v_hack, **{f"_sv/{n}": s for n, s in v_sv.items()}}
save_file(save_payload, str(cfg.out_path),
metadata={"model": cfg.model, "dtype": cfg.dtype, "top_k": str(k),
"tau_axis": str(cfg.tau_axis), "schema": "v2_with_sv"})
+53 -27
View File
@@ -1,7 +1,17 @@
"""Gradient projection + delta_S grad utilities. Imported by smoke and train."""
"""Gradient projection + delta_S grad utilities. Imported by smoke and train.
Shape conventions in the v_hack / delta_S plumbing (jaxtyping-annotated):
- `r` = per-module SVD rank (delta_S dimension; varies per Linear)
- `k` = number of v_hack directions kept per module (after top-k slice and
global noise-floor filter at load time)
- `V` is `[k, r]`, rows orthonormal in R^r and oriented hack-ward
- `g = delta_S.grad` is `[r]`
- `c = V @ g` is `[k]`
"""
from __future__ import annotations
import torch
from jaxtyping import Float
def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
@@ -22,8 +32,8 @@ def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
@torch.no_grad()
def mean_cin_from_grads(
grad_dict: dict[str, torch.Tensor],
v_hack: dict[str, torch.Tensor],
grad_dict: dict[str, Float[torch.Tensor, "r"]],
v_hack: dict[str, Float[torch.Tensor, "k r"]],
) -> float:
"""Mean over modules of ||V g|| / ||g||, given a dict of per-module grads.
@@ -44,10 +54,46 @@ def mean_cin_from_grads(
return float(sum(cs) / len(cs)) if cs else float("nan")
def _project_one_module(
g: Float[torch.Tensor, "r"],
V: Float[torch.Tensor, "k r"],
gate_mode: str,
preserve_magnitude: bool,
) -> tuple[Float[torch.Tensor, "r"], float, float, bool]:
"""Per-module top-k removal. Returns (g_proj, cos_in, cos_out, fired).
Inner helper so the shape contract (g:[r], V:[k,r]) is jaxtyping-checked
when BEARTYPE=1 — catches transposed V or wrong-rank g at the boundary
instead of producing silently wrong cosines.
"""
gn = g.norm()
if gn < 1e-12:
return g, 0.0, 0.0, False
c = V @ g # [k]
cin = (c.norm() / gn).item()
if gate_mode == "no_gate":
c_use = c
fired = True
elif gate_mode == "one_sided":
mask = (c > 0).to(c.dtype)
c_use = c * mask
fired = bool((c_use != 0).any())
else:
raise ValueError(f"unknown gate_mode={gate_mode!r}")
if not fired:
return g, cin, cin, False
g_proj = g - c_use @ V # [r]
gp_n = g_proj.norm()
if preserve_magnitude and gp_n > 1e-12:
g_proj = g_proj * (gn / gp_n)
cout = ((V @ g_proj).norm() / g_proj.norm().clamp_min(1e-12)).item()
return g_proj, cin, cout, True
@torch.no_grad()
def project_delta_S_grad(
wrappers: dict,
v_hack: dict[str, torch.Tensor],
v_hack: dict[str, Float[torch.Tensor, "k r"]],
preserve_magnitude: bool,
measure_only: bool = False,
gate_mode: str = "one_sided",
@@ -84,33 +130,13 @@ def project_delta_S_grad(
if name not in v_hack: # module dropped by global noise-floor filter
continue
V = v_hack[name].to(g.device, dtype=g.dtype) # [k, r]
gn = g.norm()
if gn < 1e-12:
cos_in_list.append(0.0); cos_out_list.append(0.0); continue
c = V @ g # [k]
cin = c.norm() / gn
cos_in_list.append(cin.item())
if gate_mode == "no_gate":
c_use = c
fired = True
elif gate_mode == "one_sided":
mask = (c > 0).to(c.dtype)
c_use = c * mask
fired = bool((c_use != 0).any())
else:
raise ValueError(f"unknown gate_mode={gate_mode!r}")
g_proj, cin, cout, fired = _project_one_module(g, V, gate_mode, preserve_magnitude)
cos_in_list.append(cin)
cos_out_list.append(cout)
if fired:
g_proj = g - c_use @ V # [r]
gp_n = g_proj.norm()
if preserve_magnitude and gp_n > 1e-12:
g_proj = g_proj * (gn / gp_n)
cout = (V @ g_proj).norm() / g_proj.norm().clamp_min(1e-12)
cos_out_list.append(cout.item())
if not measure_only:
info["delta_S"].grad = g_proj
n_fired += 1
else:
cos_out_list.append(cin.item())
cin_t = torch.tensor(cos_in_list); cout_t = torch.tensor(cos_out_list)
return {
"mean_cos_in": cin_t.mean().item(),
+25 -20
View File
@@ -70,6 +70,7 @@ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import torch
import torch.nn.functional as F
import tyro
from jaxtyping import Float
from loguru import logger
from safetensors import safe_open
from safetensors.torch import save_file
@@ -235,7 +236,7 @@ def load_problems(n: int) -> list[dict]:
def load_v_hack(
path: Path, model_name: str, wrappers: dict,
k_use: int | None = None, drop_bottom_frac: float = 0.0,
) -> dict[str, torch.Tensor]:
) -> dict[str, Float[torch.Tensor, "k r"]]:
"""Load v_hack (top-k directions) for this wrapped model.
File schema (v2): bare `{name}` keys hold V[k_max, r]; `_sv/{name}` keys hold
@@ -417,13 +418,16 @@ def main(cfg: Config) -> int:
from .pairs import PAIRS as VHACK_PAIRS
logger.info(f"v_hack cache miss at {v_hack_path}; extracting (~5min)...")
model.eval() # match standalone extract: deterministic backward, no dropout
v_hack_cpu_dict, raw_grads, _diag = extract_v_hack(
v_hack_extracted, v_sv_extracted, _raw_grads, _diag = extract_v_hack(
model, tok, wrappers, VHACK_PAIRS,
top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis,
n_heldout=2, device=device,
)
OUT_DIR.mkdir(exist_ok=True)
save_file(v_hack_cpu_dict, str(v_hack_path),
# Combine V and S under one safetensors file with `_sv/{name}` prefix
# for the singular values. load_v_hack splits them back apart.
save_payload = {**v_hack_extracted, **{f"_sv/{n}": s for n, s in v_sv_extracted.items()}}
save_file(save_payload, str(v_hack_path),
metadata={"model": model_name, "dtype": "bf16",
"top_k": str(min(cfg.v_hack_extract_top_k, len(VHACK_PAIRS) - 2)),
"tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv"})
@@ -598,23 +602,24 @@ def main(cfg: Config) -> int:
if cfg.arm == "vanilla"
else "cout=subspace energy fraction in grad after projection"
)
caption = (
"table columns: "
"step=GRPO step; "
"ref_eq=vanilla-equivalent step (cum_gens / 256); "
"rew=mean combined reward; rew_s=student mean reward; "
"sprd=reward spread>0 (T/F; F means zero-variance bail fired and step was skipped); "
"N=rollouts; "
"gt_s/gt_t=ground-truth passes (student/teacher); "
"hack_s/hack_t=hack-flagged rollouts (student/teacher); "
"lp_s/lp_t=mean per-token student/teacher gen_logp under current student (diagnostic, no IS correction); "
"loss=mean GRPO loss; "
"cin=v_hack subspace energy fraction in grad before projection; "
"cin_s/cin_t=cin on student-only/teacher-only gradient; "
f"{cout_def}; "
"fired=fraction of modules where projection fired; "
"gen/fb/t_rew=generation/forward+backward/reward-grading wall-time (s); sec=total step wall-time (s)."
)
caption = """
table columns:
- step= GRPO step;
- ref_eq= vanilla-equivalent step (cum_gens / 256);
- rew= mean combined reward; rew_s=student mean reward;
- sprd= reward spread>0 (T/F; F means zero-variance bail fired and step was skipped);
- N= rollouts;
- gt_s/gt_t= ground-truth passes (student/teacher);
- hack_s/hack_t=hack-flagged rollouts (student/teacher);
- lp_s/lp_t= mean per-token student/teacher gen_logp under current student (diagnostic, no IS correction);
- loss= mean GRPO loss;
- cin= v_hack subspace energy fraction in grad before projection;
- cin_s/cin_t= cin on student-only/teacher-only gradient;
- "{cout_def};
- fired=fraction of modules where projection fired;
- gen/fb/t_rew=generation/forward+backward/reward-grading wall-time (s); sec=total step wall-time (s)
"""
logger.info(caption + "\n\n" + _fmt_header())
OUT_DIR.mkdir(exist_ok=True)