diff --git a/src/projected_grpo/extract_vhack_grad.py b/src/projected_grpo/extract_vhack_grad.py index eb4bbe4..69d67a7 100644 --- a/src/projected_grpo/extract_vhack_grad.py +++ b/src/projected_grpo/extract_vhack_grad.py @@ -31,6 +31,7 @@ from pathlib import Path import torch import tyro +from jaxtyping import Float from loguru import logger from safetensors.torch import save_file from tabulate import tabulate @@ -89,7 +90,12 @@ def extract_v_hack( tau_axis: float, n_heldout: int, device, -) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], list[dict]]: +) -> tuple[ + dict[str, Float[torch.Tensor, "k r"]], + dict[str, Float[torch.Tensor, "k"]], + dict[str, Float[torch.Tensor, "n_pairs r"]], + list[dict], +]: """Run pair-grads + per-module SVD on D = g_hack - g_clean, return v_hack. Pure function — caller owns model loading, wrapping, and saving. train.py @@ -100,6 +106,9 @@ def extract_v_hack( v_hack: dict[name -> Tensor[k, r]] (cpu fp32), top-k right singular vectors of D per module, oriented so mean(D @ v_i) > 0. If tau_axis > 0, rows where S_i/S_0 < tau_axis are zeroed. + v_sv: dict[name -> Tensor[k]] (cpu fp32), singular values matching v_hack. + Saved alongside V under `_sv/{name}` keys so load-time noise-floor + filtering works without re-extracting. raw_grads: dict["hack/name"|"clean/name" -> Tensor[n_pairs, r]] for offline analysis (verify_vhack_heldout reads this). diag_rows: per-module diagnostic dicts (sv_top frac, ||D||, etc.). @@ -138,6 +147,7 @@ def extract_v_hack( # SVD(D) gives orthonormal right singular vectors = principal axes of variation # of the hack-clean axis. Top-k generalizes mean-diff (which is the k=1 case). v_hack: dict[str, torch.Tensor] = {} + v_sv: dict[str, torch.Tensor] = {} rows = [] n_zero = 0 k = min(top_k, n_pairs) @@ -172,9 +182,10 @@ def extract_v_hack( v_hack[name] = torch.zeros((k, D.shape[1]), dtype=V.dtype).contiguous() else: v_hack[name] = V.contiguous() - # Record singular values so future weighted projection (subtract scaled - # by S_i/S_0) is possible without re-extracting. Loader filters _sv/ keys. - v_hack[f"_sv/{name}"] = S_d[:k].clone().contiguous() + # Record singular values so the load-time noise-floor filter has the + # extraction-time S_i per axis without re-extracting. Saved under + # `_sv/{name}` keys in the safetensors file (combined at save site). + v_sv[name] = S_d[:k].clone().contiguous() sv_top = S_d[:k] sv_total = S_d.sum().clamp_min(1e-12) rows.append({ @@ -191,7 +202,7 @@ def extract_v_hack( f"v_hack: modules={n_modules} k_max={k} zero-||D||={n_zero} " f"axes_kept_avg={n_axes_kept_total/max(1,n_modules):.1f} (tau_axis={tau_axis})" ) - return v_hack, raw_grads, rows + return v_hack, v_sv, raw_grads, rows def main(cfg: Config) -> int: @@ -217,19 +228,21 @@ def main(cfg: Config) -> int: train_pairs = PAIRS[:-cfg.n_heldout] logger.info(f"train pairs: {len(train_pairs)} held: {cfg.n_heldout}") - v_hack, raw_grads, rows = extract_v_hack( + v_hack, v_sv, raw_grads, rows = extract_v_hack( model, tokenizer, wrappers, PAIRS, top_k=cfg.top_k, tau_axis=cfg.tau_axis, n_heldout=cfg.n_heldout, device=device, ) - # Skip _sv/ keys when counting V tensors (singular values are saved alongside V). - n_zero = sum(1 for n, v in v_hack.items() if not n.startswith("_sv/") and v.norm() < 1e-12) + n_zero = sum(1 for v in v_hack.values() if v.norm() < 1e-12) k = min(cfg.top_k, len(train_pairs)) OUT_DIR.mkdir(exist_ok=True) save_file(raw_grads, str(cfg.train_grads_path), metadata={"model": cfg.model, "dtype": cfg.dtype}) - save_file(v_hack, str(cfg.out_path), + # v_hack file layout: bare `{name}` keys hold V[k, r]; `_sv/{name}` keys + # hold S[k]. Loader at train.py:load_v_hack splits them back apart. + save_payload = {**v_hack, **{f"_sv/{n}": s for n, s in v_sv.items()}} + save_file(save_payload, str(cfg.out_path), metadata={"model": cfg.model, "dtype": cfg.dtype, "top_k": str(k), "tau_axis": str(cfg.tau_axis), "schema": "v2_with_sv"}) diff --git a/src/projected_grpo/proj.py b/src/projected_grpo/proj.py index 72a3e93..628d76b 100644 --- a/src/projected_grpo/proj.py +++ b/src/projected_grpo/proj.py @@ -1,7 +1,17 @@ -"""Gradient projection + delta_S grad utilities. Imported by smoke and train.""" +"""Gradient projection + delta_S grad utilities. Imported by smoke and train. + +Shape conventions in the v_hack / delta_S plumbing (jaxtyping-annotated): +- `r` = per-module SVD rank (delta_S dimension; varies per Linear) +- `k` = number of v_hack directions kept per module (after top-k slice and + global noise-floor filter at load time) +- `V` is `[k, r]`, rows orthonormal in R^r and oriented hack-ward +- `g = delta_S.grad` is `[r]` +- `c = V @ g` is `[k]` +""" from __future__ import annotations import torch +from jaxtyping import Float def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: @@ -22,8 +32,8 @@ def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: @torch.no_grad() def mean_cin_from_grads( - grad_dict: dict[str, torch.Tensor], - v_hack: dict[str, torch.Tensor], + grad_dict: dict[str, Float[torch.Tensor, "r"]], + v_hack: dict[str, Float[torch.Tensor, "k r"]], ) -> float: """Mean over modules of ||V g|| / ||g||, given a dict of per-module grads. @@ -44,10 +54,46 @@ def mean_cin_from_grads( return float(sum(cs) / len(cs)) if cs else float("nan") +def _project_one_module( + g: Float[torch.Tensor, "r"], + V: Float[torch.Tensor, "k r"], + gate_mode: str, + preserve_magnitude: bool, +) -> tuple[Float[torch.Tensor, "r"], float, float, bool]: + """Per-module top-k removal. Returns (g_proj, cos_in, cos_out, fired). + + Inner helper so the shape contract (g:[r], V:[k,r]) is jaxtyping-checked + when BEARTYPE=1 — catches transposed V or wrong-rank g at the boundary + instead of producing silently wrong cosines. + """ + gn = g.norm() + if gn < 1e-12: + return g, 0.0, 0.0, False + c = V @ g # [k] + cin = (c.norm() / gn).item() + if gate_mode == "no_gate": + c_use = c + fired = True + elif gate_mode == "one_sided": + mask = (c > 0).to(c.dtype) + c_use = c * mask + fired = bool((c_use != 0).any()) + else: + raise ValueError(f"unknown gate_mode={gate_mode!r}") + if not fired: + return g, cin, cin, False + g_proj = g - c_use @ V # [r] + gp_n = g_proj.norm() + if preserve_magnitude and gp_n > 1e-12: + g_proj = g_proj * (gn / gp_n) + cout = ((V @ g_proj).norm() / g_proj.norm().clamp_min(1e-12)).item() + return g_proj, cin, cout, True + + @torch.no_grad() def project_delta_S_grad( wrappers: dict, - v_hack: dict[str, torch.Tensor], + v_hack: dict[str, Float[torch.Tensor, "k r"]], preserve_magnitude: bool, measure_only: bool = False, gate_mode: str = "one_sided", @@ -84,33 +130,13 @@ def project_delta_S_grad( if name not in v_hack: # module dropped by global noise-floor filter continue V = v_hack[name].to(g.device, dtype=g.dtype) # [k, r] - gn = g.norm() - if gn < 1e-12: - cos_in_list.append(0.0); cos_out_list.append(0.0); continue - c = V @ g # [k] - cin = c.norm() / gn - cos_in_list.append(cin.item()) - if gate_mode == "no_gate": - c_use = c - fired = True - elif gate_mode == "one_sided": - mask = (c > 0).to(c.dtype) - c_use = c * mask - fired = bool((c_use != 0).any()) - else: - raise ValueError(f"unknown gate_mode={gate_mode!r}") + g_proj, cin, cout, fired = _project_one_module(g, V, gate_mode, preserve_magnitude) + cos_in_list.append(cin) + cos_out_list.append(cout) if fired: - g_proj = g - c_use @ V # [r] - gp_n = g_proj.norm() - if preserve_magnitude and gp_n > 1e-12: - g_proj = g_proj * (gn / gp_n) - cout = (V @ g_proj).norm() / g_proj.norm().clamp_min(1e-12) - cos_out_list.append(cout.item()) if not measure_only: info["delta_S"].grad = g_proj n_fired += 1 - else: - cos_out_list.append(cin.item()) cin_t = torch.tensor(cos_in_list); cout_t = torch.tensor(cos_out_list) return { "mean_cos_in": cin_t.mean().item(), diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index fc2d22a..433a9d1 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -70,6 +70,7 @@ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") import torch import torch.nn.functional as F import tyro +from jaxtyping import Float from loguru import logger from safetensors import safe_open from safetensors.torch import save_file @@ -235,7 +236,7 @@ def load_problems(n: int) -> list[dict]: def load_v_hack( path: Path, model_name: str, wrappers: dict, k_use: int | None = None, drop_bottom_frac: float = 0.0, -) -> dict[str, torch.Tensor]: +) -> dict[str, Float[torch.Tensor, "k r"]]: """Load v_hack (top-k directions) for this wrapped model. File schema (v2): bare `{name}` keys hold V[k_max, r]; `_sv/{name}` keys hold @@ -417,13 +418,16 @@ def main(cfg: Config) -> int: from .pairs import PAIRS as VHACK_PAIRS logger.info(f"v_hack cache miss at {v_hack_path}; extracting (~5min)...") model.eval() # match standalone extract: deterministic backward, no dropout - v_hack_cpu_dict, raw_grads, _diag = extract_v_hack( + v_hack_extracted, v_sv_extracted, _raw_grads, _diag = extract_v_hack( model, tok, wrappers, VHACK_PAIRS, top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis, n_heldout=2, device=device, ) OUT_DIR.mkdir(exist_ok=True) - save_file(v_hack_cpu_dict, str(v_hack_path), + # Combine V and S under one safetensors file with `_sv/{name}` prefix + # for the singular values. load_v_hack splits them back apart. + save_payload = {**v_hack_extracted, **{f"_sv/{n}": s for n, s in v_sv_extracted.items()}} + save_file(save_payload, str(v_hack_path), metadata={"model": model_name, "dtype": "bf16", "top_k": str(min(cfg.v_hack_extract_top_k, len(VHACK_PAIRS) - 2)), "tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv"}) @@ -598,23 +602,24 @@ def main(cfg: Config) -> int: if cfg.arm == "vanilla" else "cout=subspace energy fraction in grad after projection" ) - caption = ( - "table columns: " - "step=GRPO step; " - "ref_eq=vanilla-equivalent step (cum_gens / 256); " - "rew=mean combined reward; rew_s=student mean reward; " - "sprd=reward spread>0 (T/F; F means zero-variance bail fired and step was skipped); " - "N=rollouts; " - "gt_s/gt_t=ground-truth passes (student/teacher); " - "hack_s/hack_t=hack-flagged rollouts (student/teacher); " - "lp_s/lp_t=mean per-token student/teacher gen_logp under current student (diagnostic, no IS correction); " - "loss=mean GRPO loss; " - "cin=v_hack subspace energy fraction in grad before projection; " - "cin_s/cin_t=cin on student-only/teacher-only gradient; " - f"{cout_def}; " - "fired=fraction of modules where projection fired; " - "gen/fb/t_rew=generation/forward+backward/reward-grading wall-time (s); sec=total step wall-time (s)." - ) + caption = """ +table columns: + - step= GRPO step; + - ref_eq= vanilla-equivalent step (cum_gens / 256); + - rew= mean combined reward; rew_s=student mean reward; + - sprd= reward spread>0 (T/F; F means zero-variance bail fired and step was skipped); + - N= rollouts; + - gt_s/gt_t= ground-truth passes (student/teacher); + - hack_s/hack_t=hack-flagged rollouts (student/teacher); + - lp_s/lp_t= mean per-token student/teacher gen_logp under current student (diagnostic, no IS correction); + - loss= mean GRPO loss; + - cin= v_hack subspace energy fraction in grad before projection; + - cin_s/cin_t= cin on student-only/teacher-only gradient; + - "{cout_def}; + - fired=fraction of modules where projection fired; + - gen/fb/t_rew=generation/forward+backward/reward-grading wall-time (s); sec=total step wall-time (s) + +""" logger.info(caption + "\n\n" + _fmt_header()) OUT_DIR.mkdir(exist_ok=True)