diff --git a/scripts/verify_science_invariants.py b/scripts/verify_science_invariants.py index e22ff27..cd12b91 100644 --- a/scripts/verify_science_invariants.py +++ b/scripts/verify_science_invariants.py @@ -6,14 +6,12 @@ import json import tempfile from pathlib import Path -import torch from loguru import logger -from safetensors.torch import save_file from tabulate import tabulate from vgrout.data import DATA, RH_HINT_REPLACE_FROM, load_problems from vgrout.eval import load_eval_splits -from vgrout.vhack import load_v_hack, pairset_sha256 +from vgrout.pairs import load_pairs, pairset_sha256 def _must_raise(fn) -> bool: @@ -29,20 +27,45 @@ def main() -> int: with tempfile.TemporaryDirectory() as td: tmp = Path(td) - pairs_path = tmp / "pairs.json" - pairs_path.write_text('[{"prompt":"p","hack":"h","clean":"c"}]\n') - vhack_path = tmp / "vhack.safetensors" - dtype = "bf16" if torch.cuda.is_available() else "fp32" - save_file( - {"module": torch.tensor([[1.0, 0.0, 0.0]]), "_sv/module": torch.tensor([1.0])}, - str(vhack_path), - metadata={"model": "test", "dtype": dtype, "pairs_sha256": pairset_sha256(pairs_path)}, + pairs_path = tmp / "pairs.md" + pairs_path.write_text( + "## tiny\n\n### p\n\n#### Prompt\n`````text\np\n`````\n\n" + "#### Hack\n`````text\nh\n`````\n\n#### Clean\n`````text\nc\n`````\n\n" + "## unrelated\n\n### q\n\n#### Prompt\n`````text\nq\n`````\n\n" + "#### Hack\n`````text\nx\n`````\n\n#### Clean\n`````text\ny\n`````\n" ) - wrappers = {"module": {"delta_S": torch.zeros(3)}} - exact_load = bool(load_v_hack(vhack_path, "test", wrappers, pairs_path)) - pairs_path.write_text(pairs_path.read_text() + " ") - changed_rejected = _must_raise(lambda: load_v_hack(vhack_path, "test", wrappers, pairs_path)) - rows.append({"invariant": "v_hack pair bytes", "success": exact_load and changed_rejected}) + # Pairsets are content-addressed by the SELECTED section's bytes (pairset_sha256): + # an edit elsewhere in the file must not change the hash; an edit inside the + # selected section must. This is what gates a stale extracted direction. + pairs_ref = Path(f"{pairs_path}#tiny") + selected_hash = pairset_sha256(pairs_ref) + pairs_path.write_text(pairs_path.read_text().replace("\nx\n", "\nother changed\n")) + unrelated_ignored = pairset_sha256(pairs_ref) == selected_hash + pairs_path.write_text(pairs_path.read_text().replace("\nh\n", "\nchanged\n")) + selected_changed = pairset_sha256(pairs_ref) != selected_hash + missing_rejected = _must_raise(lambda: load_pairs(Path(f"{pairs_path}#missing"))) + rows.append({ + "invariant": "selected Markdown pair bytes", + "success": bool(selected_hash) and unrelated_ignored and selected_changed and missing_rejected, + }) + + malformed = tmp / "malformed.md" + malformed.write_text( + "## x\n\n### duplicate\n\n#### Prompt\n`````text\np\n`````\n\n" + "#### Prompt\n`````text\np2\n`````\n\n#### Hack\n`````text\nh\n`````\n\n" + "#### Clean\n`````text\nc\n`````\n" + ) + rows.append({ + "invariant": "malformed Markdown fails", + "success": _must_raise(lambda: load_pairs(Path(f"{malformed}#x"))), + }) + + real_pairsets_ok = ( + len(load_pairs(Path("docs/personas/hack_pairs.md#mechanism-authored"))) == 11 + and len(load_pairs(Path("docs/personas/pair_diagnostics.md#null-vampire"))) == 12 + and len(load_pairs(Path("out/pairsets/prog_wide_clean.json"))) == 8 + ) + rows.append({"invariant": "authored/control/generated pairsets load", "success": real_pairsets_ok}) source = json.loads(DATA.read_text().splitlines()[0]) missing = json.loads(json.dumps(source)) diff --git a/src/vgrout/extract_vhack_grad.py b/src/vgrout/extract_vhack_grad.py index cf15404..a7a6529 100644 --- a/src/vgrout/extract_vhack_grad.py +++ b/src/vgrout/extract_vhack_grad.py @@ -43,8 +43,7 @@ from tabulate import tabulate from transformers import AutoModelForCausalLM, AutoTokenizer from .lora2r import wrap_model_with_lora2r -from .pairs_from_pool import load_pairs_json -from .vhack import pairset_sha256 +from .pairs import load_pairs, pairset_sha256 OUT_DIR = Path("out") @@ -65,8 +64,7 @@ class Config: # magnitude on r=2560 modules, so this rarely changes effect size; it does # make k-ablations honest (axes 4-5 might be pure noise on N=12 pairs). tau_axis: float = 0.0 - # Path to a JSON file with list[HackPair-as-dict]. Required; see - # out/pairsets/pairs_authored.json or prog_wide.json. + # Pairset reference: generated JSON or one `path.md#section`. pairs_from_pool: Path | None = None @@ -224,8 +222,8 @@ def main(cfg: Config) -> int: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = resolve_dtype(cfg.dtype) if cfg.pairs_from_pool is None: - raise ValueError("--pairs-from-pool is required; use out/pairsets/pairs_authored.json or prog_wide.json") - pairs = load_pairs_json(cfg.pairs_from_pool) + raise ValueError("--pairs-from-pool is required; use docs/personas/hack_pairs.md#mechanism-authored") + pairs = load_pairs(cfg.pairs_from_pool) logger.info(f"pairs source: {cfg.pairs_from_pool} -> {len(pairs)} pairs") logger.info( f"device={device} model={cfg.model} dtype={cfg.dtype} " diff --git a/src/vgrout/figs.py b/src/vgrout/figs.py index 048efe4..3aa9b6e 100644 --- a/src/vgrout/figs.py +++ b/src/vgrout/figs.py @@ -17,7 +17,7 @@ from pathlib import Path FIGS_DIR = Path("docs/figs") # Reader-facing arm names. Code/log tags carry our internal vocabulary -# (routeV = the current routing arm; "knob" = the delta_S adapter); plots must +# (routeV = the current routing arm); plots must # not. Map every internal tag to the word a paper reader sees. Anything missing # falls through to its raw tag, so a new arm shows up loud rather than silently # mislabelled. diff --git a/src/vgrout/proj.py b/src/vgrout/proj.py index 0dde3cb..17d6ec5 100644 --- a/src/vgrout/proj.py +++ b/src/vgrout/proj.py @@ -1,17 +1,12 @@ -"""Gradient projection + delta_S grad utilities. Imported by smoke and train. +"""Per-token log-probs for the GRPO PPO ratio. Imported by train.py. -Shape conventions in the v_hack / delta_S plumbing (jaxtyping-annotated): -- `r` = per-module SVD rank (delta_S dimension; varies per Linear) -- `k` = number of v_hack directions kept per module (after top-k slice and - global noise-floor filter at load time) -- `V` is `[k, r]`, rows orthonormal in R^r and oriented hack-ward -- `g = delta_S.grad` is `[r]` -- `c = V @ g` is `[k]` +(The old delta_S gradient-projection/erase machinery lived here too; it was +removed with the PiSSA->lora2r migration -- routing is now block-mask based in +train.py, nothing projects a delta_S grad anymore.) """ from __future__ import annotations import torch -from jaxtyping import Float def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: @@ -28,164 +23,3 @@ def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: return -torch.nn.functional.cross_entropy( logits.reshape(-1, V), ids.reshape(-1), reduction="none" ).float().view(B, L) - - -def _hackward_cos(c: Float[torch.Tensor, "k"], gn: torch.Tensor) -> float: - """Fraction of the gradient's magnitude that points hack-ward into v_hack: - ||relu(c)|| / ||g||, where c = V @ g and V rows are orthonormal and oriented - hack-ward (c_i > 0 means "grad pushes hack-ward on axis i"). In [0, 1]. - - relu BEFORE aggregating is the point: the one_sided projection removes only - relu(c) (the hack-ward axes), and with V orthonormal ||removed|| = ||relu(c)||, - so this reads directly as "fraction of the grad the projection strips" (a signed - sum would let +/- axes cancel and read ~0 even while routing a large hack-ward - magnitude). - - After a one_sided erase, V @ g_proj = min(c, 0) (positive axes zeroed), so - relu of it is 0 -> cos_post == 0 exactly. That clean SHOULD (cos_post -> 0) is - the diagnostic; we drop the sign because a one_sided method never acts on the - safe-ward (negative) part anyway. - """ - return (torch.relu(c).norm() / gn).item() - - -@torch.no_grad() -def mean_cos_pre_from_grads( - grad_dict: dict[str, Float[torch.Tensor, "r"]], - v_hack: dict[str, Float[torch.Tensor, "k r"]], -) -> float: - """Mean over modules of ||relu(V @ g)|| / ||g|| (hack-ward fraction, in [0,1]). - - Used to compute per-source cos_pre (cos_pre_s for student-only grad, - cos_pre_t for teacher-only grad) without mutating model.grad or calling - the full projection pipeline. - """ - cs = [] - for name, g in grad_dict.items(): - if g is None or name not in v_hack: - continue - V = v_hack[name].to(g.device, dtype=g.dtype) - gn = g.norm() - if gn < 1e-12: - continue - cs.append(_hackward_cos(V @ g, gn)) - return float(sum(cs) / len(cs)) if cs else float("nan") - - -def _project_one_module( - g: Float[torch.Tensor, "r"], - V: Float[torch.Tensor, "k r"], - gate_mode: str, - preserve_magnitude: bool, - overshoot: float = 1.0, -) -> tuple[Float[torch.Tensor, "r"], Float[torch.Tensor, "r"], float, float, bool]: - """Per-module top-k removal. Returns (g_proj, removed, cos_pre, cos_post, fired). - - `removed` = overshoot*c_use@V, the vector subtracted from g (computed - before any preserve_magnitude rescale, so removed ∈ span(V) always). - Erasure drops it; routing parks it in delta_S_hack. Note g_proj + removed - == g ONLY when preserve_magnitude is False and overshoot is 1.0; with the - defaults g_proj is rescaled, so the sum is not the original g (routing does - not rely on that sum -- see project_delta_S_grad). - - cos_pre / cos_post are the hack-ward FRACTION ||relu(V @ g)|| / ||g|| (in - [0,1]; see _hackward_cos). cos_pre = how much of the grad points hack-ward - into v_hack; cos_post = the residual after projection. Under one_sided (and - no_gate, and overshoot>=1) projection cos_post -> 0 exactly: every hack-ward - axis was removed, so relu of the residual coefficients is 0. - - `overshoot` scales the removed coefficient: g_proj = g - overshoot*c_use@V. - overshoot=1.0 just removes the hack-ward component; overshoot=1.1 removes - 110% (a 10% reversal: V@g_proj = -0.1c on fired axes), a milder version of - gate_mode="reverse" (which is overshoot=2.0 on the full, ungated c). - """ - gn = g.norm() - if gn < 1e-12: - z = torch.zeros_like(g) - return g, z, 0.0, 0.0, False - c = V @ g # [k] - cos_pre = _hackward_cos(c, gn) - if gate_mode == "no_gate": - c_use = c - fired = True - elif gate_mode == "one_sided": - mask = (c > 0).to(c.dtype) - c_use = c * mask - fired = bool((c_use != 0).any()) - elif gate_mode == "reverse": - # Subtract 2*c@V: V@g_proj = V@g - 2*(V V^T) c = c - 2c = -c. - # Flips the sign of the gradient component in span(V); pushes - # actively away from hack rather than just removing. - c_use = 2 * c - fired = True - else: - raise ValueError(f"unknown gate_mode={gate_mode!r}") - if not fired: - return g, torch.zeros_like(g), cos_pre, cos_pre, False - removed = overshoot * c_use @ V # [r] - g_proj = g - removed - gp_n = g_proj.norm() - if preserve_magnitude and gp_n > 1e-12: - g_proj = g_proj * (gn / gp_n) - cos_post = _hackward_cos(V @ g_proj, g_proj.norm().clamp_min(1e-12)) - return g_proj, removed, cos_pre, cos_post, True - - -@torch.no_grad() -def project_delta_S_grad( - wrappers: dict, - v_hack: dict[str, Float[torch.Tensor, "k r"]], - preserve_magnitude: bool, - measure_only: bool = False, - gate_mode: str = "one_sided", - overshoot: float = 1.0, -) -> dict[str, float]: - """Per-module top-k removal of hack-aligned grad components. - - For each wrapped module: - g = delta_S.grad # [r] - V = v_hack[name] # [k, r], rows orthonormal, oriented hack-ward - c = V @ g # [k] per-direction coefficients - - gate_mode="one_sided" (default): - mask = (c > 0) # only zap when grad is going hack-ward on that axis - g' = g - (c * mask) @ V # subtract only positive-coefficient components - - gate_mode="no_gate": - g' = g - c @ V # full V·V^T removal, sign-agnostic; - # drives ||V g'|| -> 0 exactly. No trust in v_hack - # orientation: any motion in span(V) is suspect. - - `preserve_magnitude`: rescale g' to ||g|| after projection. - `measure_only`: same math, but g is not mutated (the `none` intervention). - Diagnostics returned (per call, averaged over modules): - mean_cos_pre = mean over modules of ||relu(V @ g)||/||g|| (hack-ward fraction, [0,1]) - mean_cos_post = same after projection (-> 0 when hack-ward axes were removed) - frac_fired = fraction of modules where at least one direction fired (c_i > 0) - """ - cos_pre_list, cos_post_list, n_fired = [], [], 0 - for name, info in wrappers.items(): - g = info["delta_S"].grad - if g is None: - continue - if name not in v_hack: # module dropped by global noise-floor filter - continue - V = v_hack[name].to(g.device, dtype=g.dtype) # [k, r] - g_proj, _, cos_pre, cos_post, fired = _project_one_module( - g, V, gate_mode, preserve_magnitude, overshoot) - cos_pre_list.append(cos_pre) - cos_post_list.append(cos_post) - if fired and not measure_only: - info["delta_S"].grad = g_proj - if fired: - n_fired += 1 - pre_t = torch.tensor(cos_pre_list); post_t = torch.tensor(cos_post_list) - return { - "mean_cos_pre": pre_t.mean().item(), - "min_cos_pre": pre_t.min().item() if pre_t.numel() else float("nan"), - "max_cos_pre": pre_t.max().item() if pre_t.numel() else float("nan"), - "mean_cos_post": post_t.mean().item(), - "min_cos_post": post_t.min().item() if post_t.numel() else float("nan"), - "max_cos_post": post_t.max().item() if post_t.numel() else float("nan"), - "frac_fired": n_fired / len(cos_pre_list) if cos_pre_list else 0.0, - } diff --git a/src/vgrout/tablelog.py b/src/vgrout/tablelog.py index d29dede..ff9c165 100644 --- a/src/vgrout/tablelog.py +++ b/src/vgrout/tablelog.py @@ -96,7 +96,7 @@ class StepLogger: _Col("lp_s", 6, "lp_s↓", "+.2f", "mean student gen_logp (diagnostic)"), _Col("lp_t", 6, "lp_t↑", "+.2f", "mean teacher gen_logp; off-policy gap = lp_s-lp_t"), _Col("loss", 7, "loss", "+.2f", "mean GRPO loss"), - _Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of delta_S grads (vs grad_clip)"), + _Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of A/B grads (vs grad_clip)"), _Col("lr", 7, "lr", ".1e", "scheduled learning rate"), ] # routeV reports unit and energy shares across the routing band.