cleanup: delete dead delta_S machinery (PiSSA->lora2r leftovers)

Off the live lora2r path; removed with vhack.py (commit 4120d75):
- proj.py: drop project_delta_S_grad/_project_one_module/mean_cos_pre_from_grads/
  _hackward_cos (no live importer; train.py uses only per_token_logps).
- verify_science_invariants: test pairset_sha256's content gate directly (drops the
  load_v_hack vehicle + fake delta_S wrapper fixture).
- extract_vhack_grad: import pairset_sha256 from .pairs (was re-exported via vhack).
- tablelog/figs: stale 'delta_S grads'/'knob' comments -> A/B grads.
Smoke + verify_science_invariants green; no delta_S left in live code.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-10 11:45:54 +00:00
parent 4120d75ea4
commit 7e11c024c4
5 changed files with 49 additions and 194 deletions
+39 -16
View File
@@ -6,14 +6,12 @@ import json
import tempfile import tempfile
from pathlib import Path from pathlib import Path
import torch
from loguru import logger from loguru import logger
from safetensors.torch import save_file
from tabulate import tabulate from tabulate import tabulate
from vgrout.data import DATA, RH_HINT_REPLACE_FROM, load_problems from vgrout.data import DATA, RH_HINT_REPLACE_FROM, load_problems
from vgrout.eval import load_eval_splits from vgrout.eval import load_eval_splits
from vgrout.vhack import load_v_hack, pairset_sha256 from vgrout.pairs import load_pairs, pairset_sha256
def _must_raise(fn) -> bool: def _must_raise(fn) -> bool:
@@ -29,20 +27,45 @@ def main() -> int:
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
tmp = Path(td) tmp = Path(td)
pairs_path = tmp / "pairs.json" pairs_path = tmp / "pairs.md"
pairs_path.write_text('[{"prompt":"p","hack":"h","clean":"c"}]\n') pairs_path.write_text(
vhack_path = tmp / "vhack.safetensors" "## tiny\n\n### p\n\n#### Prompt\n`````text\np\n`````\n\n"
dtype = "bf16" if torch.cuda.is_available() else "fp32" "#### Hack\n`````text\nh\n`````\n\n#### Clean\n`````text\nc\n`````\n\n"
save_file( "## unrelated\n\n### q\n\n#### Prompt\n`````text\nq\n`````\n\n"
{"module": torch.tensor([[1.0, 0.0, 0.0]]), "_sv/module": torch.tensor([1.0])}, "#### Hack\n`````text\nx\n`````\n\n#### Clean\n`````text\ny\n`````\n"
str(vhack_path),
metadata={"model": "test", "dtype": dtype, "pairs_sha256": pairset_sha256(pairs_path)},
) )
wrappers = {"module": {"delta_S": torch.zeros(3)}} # Pairsets are content-addressed by the SELECTED section's bytes (pairset_sha256):
exact_load = bool(load_v_hack(vhack_path, "test", wrappers, pairs_path)) # an edit elsewhere in the file must not change the hash; an edit inside the
pairs_path.write_text(pairs_path.read_text() + " ") # selected section must. This is what gates a stale extracted direction.
changed_rejected = _must_raise(lambda: load_v_hack(vhack_path, "test", wrappers, pairs_path)) pairs_ref = Path(f"{pairs_path}#tiny")
rows.append({"invariant": "v_hack pair bytes", "success": exact_load and changed_rejected}) selected_hash = pairset_sha256(pairs_ref)
pairs_path.write_text(pairs_path.read_text().replace("\nx\n", "\nother changed\n"))
unrelated_ignored = pairset_sha256(pairs_ref) == selected_hash
pairs_path.write_text(pairs_path.read_text().replace("\nh\n", "\nchanged\n"))
selected_changed = pairset_sha256(pairs_ref) != selected_hash
missing_rejected = _must_raise(lambda: load_pairs(Path(f"{pairs_path}#missing")))
rows.append({
"invariant": "selected Markdown pair bytes",
"success": bool(selected_hash) and unrelated_ignored and selected_changed and missing_rejected,
})
malformed = tmp / "malformed.md"
malformed.write_text(
"## x\n\n### duplicate\n\n#### Prompt\n`````text\np\n`````\n\n"
"#### Prompt\n`````text\np2\n`````\n\n#### Hack\n`````text\nh\n`````\n\n"
"#### Clean\n`````text\nc\n`````\n"
)
rows.append({
"invariant": "malformed Markdown fails",
"success": _must_raise(lambda: load_pairs(Path(f"{malformed}#x"))),
})
real_pairsets_ok = (
len(load_pairs(Path("docs/personas/hack_pairs.md#mechanism-authored"))) == 11
and len(load_pairs(Path("docs/personas/pair_diagnostics.md#null-vampire"))) == 12
and len(load_pairs(Path("out/pairsets/prog_wide_clean.json"))) == 8
)
rows.append({"invariant": "authored/control/generated pairsets load", "success": real_pairsets_ok})
source = json.loads(DATA.read_text().splitlines()[0]) source = json.loads(DATA.read_text().splitlines()[0])
missing = json.loads(json.dumps(source)) missing = json.loads(json.dumps(source))
+4 -6
View File
@@ -43,8 +43,7 @@ from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from .lora2r import wrap_model_with_lora2r from .lora2r import wrap_model_with_lora2r
from .pairs_from_pool import load_pairs_json from .pairs import load_pairs, pairset_sha256
from .vhack import pairset_sha256
OUT_DIR = Path("out") OUT_DIR = Path("out")
@@ -65,8 +64,7 @@ class Config:
# magnitude on r=2560 modules, so this rarely changes effect size; it does # magnitude on r=2560 modules, so this rarely changes effect size; it does
# make k-ablations honest (axes 4-5 might be pure noise on N=12 pairs). # make k-ablations honest (axes 4-5 might be pure noise on N=12 pairs).
tau_axis: float = 0.0 tau_axis: float = 0.0
# Path to a JSON file with list[HackPair-as-dict]. Required; see # Pairset reference: generated JSON or one `path.md#section`.
# out/pairsets/pairs_authored.json or prog_wide.json.
pairs_from_pool: Path | None = None pairs_from_pool: Path | None = None
@@ -224,8 +222,8 @@ def main(cfg: Config) -> int:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = resolve_dtype(cfg.dtype) dtype = resolve_dtype(cfg.dtype)
if cfg.pairs_from_pool is None: if cfg.pairs_from_pool is None:
raise ValueError("--pairs-from-pool is required; use out/pairsets/pairs_authored.json or prog_wide.json") raise ValueError("--pairs-from-pool is required; use docs/personas/hack_pairs.md#mechanism-authored")
pairs = load_pairs_json(cfg.pairs_from_pool) pairs = load_pairs(cfg.pairs_from_pool)
logger.info(f"pairs source: {cfg.pairs_from_pool} -> {len(pairs)} pairs") logger.info(f"pairs source: {cfg.pairs_from_pool} -> {len(pairs)} pairs")
logger.info( logger.info(
f"device={device} model={cfg.model} dtype={cfg.dtype} " f"device={device} model={cfg.model} dtype={cfg.dtype} "
+1 -1
View File
@@ -17,7 +17,7 @@ from pathlib import Path
FIGS_DIR = Path("docs/figs") FIGS_DIR = Path("docs/figs")
# Reader-facing arm names. Code/log tags carry our internal vocabulary # Reader-facing arm names. Code/log tags carry our internal vocabulary
# (routeV = the current routing arm; "knob" = the delta_S adapter); plots must # (routeV = the current routing arm); plots must
# not. Map every internal tag to the word a paper reader sees. Anything missing # not. Map every internal tag to the word a paper reader sees. Anything missing
# falls through to its raw tag, so a new arm shows up loud rather than silently # falls through to its raw tag, so a new arm shows up loud rather than silently
# mislabelled. # mislabelled.
+4 -170
View File
@@ -1,17 +1,12 @@
"""Gradient projection + delta_S grad utilities. Imported by smoke and train. """Per-token log-probs for the GRPO PPO ratio. Imported by train.py.
Shape conventions in the v_hack / delta_S plumbing (jaxtyping-annotated): (The old delta_S gradient-projection/erase machinery lived here too; it was
- `r` = per-module SVD rank (delta_S dimension; varies per Linear) removed with the PiSSA->lora2r migration -- routing is now block-mask based in
- `k` = number of v_hack directions kept per module (after top-k slice and train.py, nothing projects a delta_S grad anymore.)
global noise-floor filter at load time)
- `V` is `[k, r]`, rows orthonormal in R^r and oriented hack-ward
- `g = delta_S.grad` is `[r]`
- `c = V @ g` is `[k]`
""" """
from __future__ import annotations from __future__ import annotations
import torch import torch
from jaxtyping import Float
def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
@@ -28,164 +23,3 @@ def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
return -torch.nn.functional.cross_entropy( return -torch.nn.functional.cross_entropy(
logits.reshape(-1, V), ids.reshape(-1), reduction="none" logits.reshape(-1, V), ids.reshape(-1), reduction="none"
).float().view(B, L) ).float().view(B, L)
def _hackward_cos(c: Float[torch.Tensor, "k"], gn: torch.Tensor) -> float:
"""Fraction of the gradient's magnitude that points hack-ward into v_hack:
||relu(c)|| / ||g||, where c = V @ g and V rows are orthonormal and oriented
hack-ward (c_i > 0 means "grad pushes hack-ward on axis i"). In [0, 1].
relu BEFORE aggregating is the point: the one_sided projection removes only
relu(c) (the hack-ward axes), and with V orthonormal ||removed|| = ||relu(c)||,
so this reads directly as "fraction of the grad the projection strips" (a signed
sum would let +/- axes cancel and read ~0 even while routing a large hack-ward
magnitude).
After a one_sided erase, V @ g_proj = min(c, 0) (positive axes zeroed), so
relu of it is 0 -> cos_post == 0 exactly. That clean SHOULD (cos_post -> 0) is
the diagnostic; we drop the sign because a one_sided method never acts on the
safe-ward (negative) part anyway.
"""
return (torch.relu(c).norm() / gn).item()
@torch.no_grad()
def mean_cos_pre_from_grads(
grad_dict: dict[str, Float[torch.Tensor, "r"]],
v_hack: dict[str, Float[torch.Tensor, "k r"]],
) -> float:
"""Mean over modules of ||relu(V @ g)|| / ||g|| (hack-ward fraction, in [0,1]).
Used to compute per-source cos_pre (cos_pre_s for student-only grad,
cos_pre_t for teacher-only grad) without mutating model.grad or calling
the full projection pipeline.
"""
cs = []
for name, g in grad_dict.items():
if g is None or name not in v_hack:
continue
V = v_hack[name].to(g.device, dtype=g.dtype)
gn = g.norm()
if gn < 1e-12:
continue
cs.append(_hackward_cos(V @ g, gn))
return float(sum(cs) / len(cs)) if cs else float("nan")
def _project_one_module(
g: Float[torch.Tensor, "r"],
V: Float[torch.Tensor, "k r"],
gate_mode: str,
preserve_magnitude: bool,
overshoot: float = 1.0,
) -> tuple[Float[torch.Tensor, "r"], Float[torch.Tensor, "r"], float, float, bool]:
"""Per-module top-k removal. Returns (g_proj, removed, cos_pre, cos_post, fired).
`removed` = overshoot*c_use@V, the vector subtracted from g (computed
before any preserve_magnitude rescale, so removed ∈ span(V) always).
Erasure drops it; routing parks it in delta_S_hack. Note g_proj + removed
== g ONLY when preserve_magnitude is False and overshoot is 1.0; with the
defaults g_proj is rescaled, so the sum is not the original g (routing does
not rely on that sum -- see project_delta_S_grad).
cos_pre / cos_post are the hack-ward FRACTION ||relu(V @ g)|| / ||g|| (in
[0,1]; see _hackward_cos). cos_pre = how much of the grad points hack-ward
into v_hack; cos_post = the residual after projection. Under one_sided (and
no_gate, and overshoot>=1) projection cos_post -> 0 exactly: every hack-ward
axis was removed, so relu of the residual coefficients is 0.
`overshoot` scales the removed coefficient: g_proj = g - overshoot*c_use@V.
overshoot=1.0 just removes the hack-ward component; overshoot=1.1 removes
110% (a 10% reversal: V@g_proj = -0.1c on fired axes), a milder version of
gate_mode="reverse" (which is overshoot=2.0 on the full, ungated c).
"""
gn = g.norm()
if gn < 1e-12:
z = torch.zeros_like(g)
return g, z, 0.0, 0.0, False
c = V @ g # [k]
cos_pre = _hackward_cos(c, gn)
if gate_mode == "no_gate":
c_use = c
fired = True
elif gate_mode == "one_sided":
mask = (c > 0).to(c.dtype)
c_use = c * mask
fired = bool((c_use != 0).any())
elif gate_mode == "reverse":
# Subtract 2*c@V: V@g_proj = V@g - 2*(V V^T) c = c - 2c = -c.
# Flips the sign of the gradient component in span(V); pushes
# actively away from hack rather than just removing.
c_use = 2 * c
fired = True
else:
raise ValueError(f"unknown gate_mode={gate_mode!r}")
if not fired:
return g, torch.zeros_like(g), cos_pre, cos_pre, False
removed = overshoot * c_use @ V # [r]
g_proj = g - removed
gp_n = g_proj.norm()
if preserve_magnitude and gp_n > 1e-12:
g_proj = g_proj * (gn / gp_n)
cos_post = _hackward_cos(V @ g_proj, g_proj.norm().clamp_min(1e-12))
return g_proj, removed, cos_pre, cos_post, True
@torch.no_grad()
def project_delta_S_grad(
wrappers: dict,
v_hack: dict[str, Float[torch.Tensor, "k r"]],
preserve_magnitude: bool,
measure_only: bool = False,
gate_mode: str = "one_sided",
overshoot: float = 1.0,
) -> dict[str, float]:
"""Per-module top-k removal of hack-aligned grad components.
For each wrapped module:
g = delta_S.grad # [r]
V = v_hack[name] # [k, r], rows orthonormal, oriented hack-ward
c = V @ g # [k] per-direction coefficients
gate_mode="one_sided" (default):
mask = (c > 0) # only zap when grad is going hack-ward on that axis
g' = g - (c * mask) @ V # subtract only positive-coefficient components
gate_mode="no_gate":
g' = g - c @ V # full V·V^T removal, sign-agnostic;
# drives ||V g'|| -> 0 exactly. No trust in v_hack
# orientation: any motion in span(V) is suspect.
`preserve_magnitude`: rescale g' to ||g|| after projection.
`measure_only`: same math, but g is not mutated (the `none` intervention).
Diagnostics returned (per call, averaged over modules):
mean_cos_pre = mean over modules of ||relu(V @ g)||/||g|| (hack-ward fraction, [0,1])
mean_cos_post = same after projection (-> 0 when hack-ward axes were removed)
frac_fired = fraction of modules where at least one direction fired (c_i > 0)
"""
cos_pre_list, cos_post_list, n_fired = [], [], 0
for name, info in wrappers.items():
g = info["delta_S"].grad
if g is None:
continue
if name not in v_hack: # module dropped by global noise-floor filter
continue
V = v_hack[name].to(g.device, dtype=g.dtype) # [k, r]
g_proj, _, cos_pre, cos_post, fired = _project_one_module(
g, V, gate_mode, preserve_magnitude, overshoot)
cos_pre_list.append(cos_pre)
cos_post_list.append(cos_post)
if fired and not measure_only:
info["delta_S"].grad = g_proj
if fired:
n_fired += 1
pre_t = torch.tensor(cos_pre_list); post_t = torch.tensor(cos_post_list)
return {
"mean_cos_pre": pre_t.mean().item(),
"min_cos_pre": pre_t.min().item() if pre_t.numel() else float("nan"),
"max_cos_pre": pre_t.max().item() if pre_t.numel() else float("nan"),
"mean_cos_post": post_t.mean().item(),
"min_cos_post": post_t.min().item() if post_t.numel() else float("nan"),
"max_cos_post": post_t.max().item() if post_t.numel() else float("nan"),
"frac_fired": n_fired / len(cos_pre_list) if cos_pre_list else 0.0,
}
+1 -1
View File
@@ -96,7 +96,7 @@ class StepLogger:
_Col("lp_s", 6, "lp_s↓", "+.2f", "mean student gen_logp (diagnostic)"), _Col("lp_s", 6, "lp_s↓", "+.2f", "mean student gen_logp (diagnostic)"),
_Col("lp_t", 6, "lp_t↑", "+.2f", "mean teacher gen_logp; off-policy gap = lp_s-lp_t"), _Col("lp_t", 6, "lp_t↑", "+.2f", "mean teacher gen_logp; off-policy gap = lp_s-lp_t"),
_Col("loss", 7, "loss", "+.2f", "mean GRPO loss"), _Col("loss", 7, "loss", "+.2f", "mean GRPO loss"),
_Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of delta_S grads (vs grad_clip)"), _Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of A/B grads (vs grad_clip)"),
_Col("lr", 7, "lr", ".1e", "scheduled learning rate"), _Col("lr", 7, "lr", ".1e", "scheduled learning rate"),
] ]
# routeV reports unit and energy shares across the routing band. # routeV reports unit and energy shares across the routing band.