mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 19:15:20 +08:00
cleanup: delete dead delta_S machinery (PiSSA->lora2r leftovers)
Off the live lora2r path; removed with vhack.py (commit 4120d75):
- proj.py: drop project_delta_S_grad/_project_one_module/mean_cos_pre_from_grads/
_hackward_cos (no live importer; train.py uses only per_token_logps).
- verify_science_invariants: test pairset_sha256's content gate directly (drops the
load_v_hack vehicle + fake delta_S wrapper fixture).
- extract_vhack_grad: import pairset_sha256 from .pairs (was re-exported via vhack).
- tablelog/figs: stale 'delta_S grads'/'knob' comments -> A/B grads.
Smoke + verify_science_invariants green; no delta_S left in live code.
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -6,14 +6,12 @@ import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from safetensors.torch import save_file
|
||||
from tabulate import tabulate
|
||||
|
||||
from vgrout.data import DATA, RH_HINT_REPLACE_FROM, load_problems
|
||||
from vgrout.eval import load_eval_splits
|
||||
from vgrout.vhack import load_v_hack, pairset_sha256
|
||||
from vgrout.pairs import load_pairs, pairset_sha256
|
||||
|
||||
|
||||
def _must_raise(fn) -> bool:
|
||||
@@ -29,20 +27,45 @@ def main() -> int:
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
tmp = Path(td)
|
||||
|
||||
pairs_path = tmp / "pairs.json"
|
||||
pairs_path.write_text('[{"prompt":"p","hack":"h","clean":"c"}]\n')
|
||||
vhack_path = tmp / "vhack.safetensors"
|
||||
dtype = "bf16" if torch.cuda.is_available() else "fp32"
|
||||
save_file(
|
||||
{"module": torch.tensor([[1.0, 0.0, 0.0]]), "_sv/module": torch.tensor([1.0])},
|
||||
str(vhack_path),
|
||||
metadata={"model": "test", "dtype": dtype, "pairs_sha256": pairset_sha256(pairs_path)},
|
||||
pairs_path = tmp / "pairs.md"
|
||||
pairs_path.write_text(
|
||||
"## tiny\n\n### p\n\n#### Prompt\n`````text\np\n`````\n\n"
|
||||
"#### Hack\n`````text\nh\n`````\n\n#### Clean\n`````text\nc\n`````\n\n"
|
||||
"## unrelated\n\n### q\n\n#### Prompt\n`````text\nq\n`````\n\n"
|
||||
"#### Hack\n`````text\nx\n`````\n\n#### Clean\n`````text\ny\n`````\n"
|
||||
)
|
||||
wrappers = {"module": {"delta_S": torch.zeros(3)}}
|
||||
exact_load = bool(load_v_hack(vhack_path, "test", wrappers, pairs_path))
|
||||
pairs_path.write_text(pairs_path.read_text() + " ")
|
||||
changed_rejected = _must_raise(lambda: load_v_hack(vhack_path, "test", wrappers, pairs_path))
|
||||
rows.append({"invariant": "v_hack pair bytes", "success": exact_load and changed_rejected})
|
||||
# Pairsets are content-addressed by the SELECTED section's bytes (pairset_sha256):
|
||||
# an edit elsewhere in the file must not change the hash; an edit inside the
|
||||
# selected section must. This is what gates a stale extracted direction.
|
||||
pairs_ref = Path(f"{pairs_path}#tiny")
|
||||
selected_hash = pairset_sha256(pairs_ref)
|
||||
pairs_path.write_text(pairs_path.read_text().replace("\nx\n", "\nother changed\n"))
|
||||
unrelated_ignored = pairset_sha256(pairs_ref) == selected_hash
|
||||
pairs_path.write_text(pairs_path.read_text().replace("\nh\n", "\nchanged\n"))
|
||||
selected_changed = pairset_sha256(pairs_ref) != selected_hash
|
||||
missing_rejected = _must_raise(lambda: load_pairs(Path(f"{pairs_path}#missing")))
|
||||
rows.append({
|
||||
"invariant": "selected Markdown pair bytes",
|
||||
"success": bool(selected_hash) and unrelated_ignored and selected_changed and missing_rejected,
|
||||
})
|
||||
|
||||
malformed = tmp / "malformed.md"
|
||||
malformed.write_text(
|
||||
"## x\n\n### duplicate\n\n#### Prompt\n`````text\np\n`````\n\n"
|
||||
"#### Prompt\n`````text\np2\n`````\n\n#### Hack\n`````text\nh\n`````\n\n"
|
||||
"#### Clean\n`````text\nc\n`````\n"
|
||||
)
|
||||
rows.append({
|
||||
"invariant": "malformed Markdown fails",
|
||||
"success": _must_raise(lambda: load_pairs(Path(f"{malformed}#x"))),
|
||||
})
|
||||
|
||||
real_pairsets_ok = (
|
||||
len(load_pairs(Path("docs/personas/hack_pairs.md#mechanism-authored"))) == 11
|
||||
and len(load_pairs(Path("docs/personas/pair_diagnostics.md#null-vampire"))) == 12
|
||||
and len(load_pairs(Path("out/pairsets/prog_wide_clean.json"))) == 8
|
||||
)
|
||||
rows.append({"invariant": "authored/control/generated pairsets load", "success": real_pairsets_ok})
|
||||
|
||||
source = json.loads(DATA.read_text().splitlines()[0])
|
||||
missing = json.loads(json.dumps(source))
|
||||
|
||||
@@ -43,8 +43,7 @@ from tabulate import tabulate
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from .lora2r import wrap_model_with_lora2r
|
||||
from .pairs_from_pool import load_pairs_json
|
||||
from .vhack import pairset_sha256
|
||||
from .pairs import load_pairs, pairset_sha256
|
||||
|
||||
|
||||
OUT_DIR = Path("out")
|
||||
@@ -65,8 +64,7 @@ class Config:
|
||||
# magnitude on r=2560 modules, so this rarely changes effect size; it does
|
||||
# make k-ablations honest (axes 4-5 might be pure noise on N=12 pairs).
|
||||
tau_axis: float = 0.0
|
||||
# Path to a JSON file with list[HackPair-as-dict]. Required; see
|
||||
# out/pairsets/pairs_authored.json or prog_wide.json.
|
||||
# Pairset reference: generated JSON or one `path.md#section`.
|
||||
pairs_from_pool: Path | None = None
|
||||
|
||||
|
||||
@@ -224,8 +222,8 @@ def main(cfg: Config) -> int:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = resolve_dtype(cfg.dtype)
|
||||
if cfg.pairs_from_pool is None:
|
||||
raise ValueError("--pairs-from-pool is required; use out/pairsets/pairs_authored.json or prog_wide.json")
|
||||
pairs = load_pairs_json(cfg.pairs_from_pool)
|
||||
raise ValueError("--pairs-from-pool is required; use docs/personas/hack_pairs.md#mechanism-authored")
|
||||
pairs = load_pairs(cfg.pairs_from_pool)
|
||||
logger.info(f"pairs source: {cfg.pairs_from_pool} -> {len(pairs)} pairs")
|
||||
logger.info(
|
||||
f"device={device} model={cfg.model} dtype={cfg.dtype} "
|
||||
|
||||
+1
-1
@@ -17,7 +17,7 @@ from pathlib import Path
|
||||
FIGS_DIR = Path("docs/figs")
|
||||
|
||||
# Reader-facing arm names. Code/log tags carry our internal vocabulary
|
||||
# (routeV = the current routing arm; "knob" = the delta_S adapter); plots must
|
||||
# (routeV = the current routing arm); plots must
|
||||
# not. Map every internal tag to the word a paper reader sees. Anything missing
|
||||
# falls through to its raw tag, so a new arm shows up loud rather than silently
|
||||
# mislabelled.
|
||||
|
||||
+4
-170
@@ -1,17 +1,12 @@
|
||||
"""Gradient projection + delta_S grad utilities. Imported by smoke and train.
|
||||
"""Per-token log-probs for the GRPO PPO ratio. Imported by train.py.
|
||||
|
||||
Shape conventions in the v_hack / delta_S plumbing (jaxtyping-annotated):
|
||||
- `r` = per-module SVD rank (delta_S dimension; varies per Linear)
|
||||
- `k` = number of v_hack directions kept per module (after top-k slice and
|
||||
global noise-floor filter at load time)
|
||||
- `V` is `[k, r]`, rows orthonormal in R^r and oriented hack-ward
|
||||
- `g = delta_S.grad` is `[r]`
|
||||
- `c = V @ g` is `[k]`
|
||||
(The old delta_S gradient-projection/erase machinery lived here too; it was
|
||||
removed with the PiSSA->lora2r migration -- routing is now block-mask based in
|
||||
train.py, nothing projects a delta_S grad anymore.)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
|
||||
|
||||
def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
|
||||
@@ -28,164 +23,3 @@ def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
|
||||
return -torch.nn.functional.cross_entropy(
|
||||
logits.reshape(-1, V), ids.reshape(-1), reduction="none"
|
||||
).float().view(B, L)
|
||||
|
||||
|
||||
def _hackward_cos(c: Float[torch.Tensor, "k"], gn: torch.Tensor) -> float:
|
||||
"""Fraction of the gradient's magnitude that points hack-ward into v_hack:
|
||||
||relu(c)|| / ||g||, where c = V @ g and V rows are orthonormal and oriented
|
||||
hack-ward (c_i > 0 means "grad pushes hack-ward on axis i"). In [0, 1].
|
||||
|
||||
relu BEFORE aggregating is the point: the one_sided projection removes only
|
||||
relu(c) (the hack-ward axes), and with V orthonormal ||removed|| = ||relu(c)||,
|
||||
so this reads directly as "fraction of the grad the projection strips" (a signed
|
||||
sum would let +/- axes cancel and read ~0 even while routing a large hack-ward
|
||||
magnitude).
|
||||
|
||||
After a one_sided erase, V @ g_proj = min(c, 0) (positive axes zeroed), so
|
||||
relu of it is 0 -> cos_post == 0 exactly. That clean SHOULD (cos_post -> 0) is
|
||||
the diagnostic; we drop the sign because a one_sided method never acts on the
|
||||
safe-ward (negative) part anyway.
|
||||
"""
|
||||
return (torch.relu(c).norm() / gn).item()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def mean_cos_pre_from_grads(
|
||||
grad_dict: dict[str, Float[torch.Tensor, "r"]],
|
||||
v_hack: dict[str, Float[torch.Tensor, "k r"]],
|
||||
) -> float:
|
||||
"""Mean over modules of ||relu(V @ g)|| / ||g|| (hack-ward fraction, in [0,1]).
|
||||
|
||||
Used to compute per-source cos_pre (cos_pre_s for student-only grad,
|
||||
cos_pre_t for teacher-only grad) without mutating model.grad or calling
|
||||
the full projection pipeline.
|
||||
"""
|
||||
cs = []
|
||||
for name, g in grad_dict.items():
|
||||
if g is None or name not in v_hack:
|
||||
continue
|
||||
V = v_hack[name].to(g.device, dtype=g.dtype)
|
||||
gn = g.norm()
|
||||
if gn < 1e-12:
|
||||
continue
|
||||
cs.append(_hackward_cos(V @ g, gn))
|
||||
return float(sum(cs) / len(cs)) if cs else float("nan")
|
||||
|
||||
|
||||
def _project_one_module(
|
||||
g: Float[torch.Tensor, "r"],
|
||||
V: Float[torch.Tensor, "k r"],
|
||||
gate_mode: str,
|
||||
preserve_magnitude: bool,
|
||||
overshoot: float = 1.0,
|
||||
) -> tuple[Float[torch.Tensor, "r"], Float[torch.Tensor, "r"], float, float, bool]:
|
||||
"""Per-module top-k removal. Returns (g_proj, removed, cos_pre, cos_post, fired).
|
||||
|
||||
`removed` = overshoot*c_use@V, the vector subtracted from g (computed
|
||||
before any preserve_magnitude rescale, so removed ∈ span(V) always).
|
||||
Erasure drops it; routing parks it in delta_S_hack. Note g_proj + removed
|
||||
== g ONLY when preserve_magnitude is False and overshoot is 1.0; with the
|
||||
defaults g_proj is rescaled, so the sum is not the original g (routing does
|
||||
not rely on that sum -- see project_delta_S_grad).
|
||||
|
||||
cos_pre / cos_post are the hack-ward FRACTION ||relu(V @ g)|| / ||g|| (in
|
||||
[0,1]; see _hackward_cos). cos_pre = how much of the grad points hack-ward
|
||||
into v_hack; cos_post = the residual after projection. Under one_sided (and
|
||||
no_gate, and overshoot>=1) projection cos_post -> 0 exactly: every hack-ward
|
||||
axis was removed, so relu of the residual coefficients is 0.
|
||||
|
||||
`overshoot` scales the removed coefficient: g_proj = g - overshoot*c_use@V.
|
||||
overshoot=1.0 just removes the hack-ward component; overshoot=1.1 removes
|
||||
110% (a 10% reversal: V@g_proj = -0.1c on fired axes), a milder version of
|
||||
gate_mode="reverse" (which is overshoot=2.0 on the full, ungated c).
|
||||
"""
|
||||
gn = g.norm()
|
||||
if gn < 1e-12:
|
||||
z = torch.zeros_like(g)
|
||||
return g, z, 0.0, 0.0, False
|
||||
c = V @ g # [k]
|
||||
cos_pre = _hackward_cos(c, gn)
|
||||
if gate_mode == "no_gate":
|
||||
c_use = c
|
||||
fired = True
|
||||
elif gate_mode == "one_sided":
|
||||
mask = (c > 0).to(c.dtype)
|
||||
c_use = c * mask
|
||||
fired = bool((c_use != 0).any())
|
||||
elif gate_mode == "reverse":
|
||||
# Subtract 2*c@V: V@g_proj = V@g - 2*(V V^T) c = c - 2c = -c.
|
||||
# Flips the sign of the gradient component in span(V); pushes
|
||||
# actively away from hack rather than just removing.
|
||||
c_use = 2 * c
|
||||
fired = True
|
||||
else:
|
||||
raise ValueError(f"unknown gate_mode={gate_mode!r}")
|
||||
if not fired:
|
||||
return g, torch.zeros_like(g), cos_pre, cos_pre, False
|
||||
removed = overshoot * c_use @ V # [r]
|
||||
g_proj = g - removed
|
||||
gp_n = g_proj.norm()
|
||||
if preserve_magnitude and gp_n > 1e-12:
|
||||
g_proj = g_proj * (gn / gp_n)
|
||||
cos_post = _hackward_cos(V @ g_proj, g_proj.norm().clamp_min(1e-12))
|
||||
return g_proj, removed, cos_pre, cos_post, True
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def project_delta_S_grad(
|
||||
wrappers: dict,
|
||||
v_hack: dict[str, Float[torch.Tensor, "k r"]],
|
||||
preserve_magnitude: bool,
|
||||
measure_only: bool = False,
|
||||
gate_mode: str = "one_sided",
|
||||
overshoot: float = 1.0,
|
||||
) -> dict[str, float]:
|
||||
"""Per-module top-k removal of hack-aligned grad components.
|
||||
|
||||
For each wrapped module:
|
||||
g = delta_S.grad # [r]
|
||||
V = v_hack[name] # [k, r], rows orthonormal, oriented hack-ward
|
||||
c = V @ g # [k] per-direction coefficients
|
||||
|
||||
gate_mode="one_sided" (default):
|
||||
mask = (c > 0) # only zap when grad is going hack-ward on that axis
|
||||
g' = g - (c * mask) @ V # subtract only positive-coefficient components
|
||||
|
||||
gate_mode="no_gate":
|
||||
g' = g - c @ V # full V·V^T removal, sign-agnostic;
|
||||
# drives ||V g'|| -> 0 exactly. No trust in v_hack
|
||||
# orientation: any motion in span(V) is suspect.
|
||||
|
||||
`preserve_magnitude`: rescale g' to ||g|| after projection.
|
||||
`measure_only`: same math, but g is not mutated (the `none` intervention).
|
||||
Diagnostics returned (per call, averaged over modules):
|
||||
mean_cos_pre = mean over modules of ||relu(V @ g)||/||g|| (hack-ward fraction, [0,1])
|
||||
mean_cos_post = same after projection (-> 0 when hack-ward axes were removed)
|
||||
frac_fired = fraction of modules where at least one direction fired (c_i > 0)
|
||||
"""
|
||||
cos_pre_list, cos_post_list, n_fired = [], [], 0
|
||||
for name, info in wrappers.items():
|
||||
g = info["delta_S"].grad
|
||||
if g is None:
|
||||
continue
|
||||
if name not in v_hack: # module dropped by global noise-floor filter
|
||||
continue
|
||||
V = v_hack[name].to(g.device, dtype=g.dtype) # [k, r]
|
||||
g_proj, _, cos_pre, cos_post, fired = _project_one_module(
|
||||
g, V, gate_mode, preserve_magnitude, overshoot)
|
||||
cos_pre_list.append(cos_pre)
|
||||
cos_post_list.append(cos_post)
|
||||
if fired and not measure_only:
|
||||
info["delta_S"].grad = g_proj
|
||||
if fired:
|
||||
n_fired += 1
|
||||
pre_t = torch.tensor(cos_pre_list); post_t = torch.tensor(cos_post_list)
|
||||
return {
|
||||
"mean_cos_pre": pre_t.mean().item(),
|
||||
"min_cos_pre": pre_t.min().item() if pre_t.numel() else float("nan"),
|
||||
"max_cos_pre": pre_t.max().item() if pre_t.numel() else float("nan"),
|
||||
"mean_cos_post": post_t.mean().item(),
|
||||
"min_cos_post": post_t.min().item() if post_t.numel() else float("nan"),
|
||||
"max_cos_post": post_t.max().item() if post_t.numel() else float("nan"),
|
||||
"frac_fired": n_fired / len(cos_pre_list) if cos_pre_list else 0.0,
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ class StepLogger:
|
||||
_Col("lp_s", 6, "lp_s↓", "+.2f", "mean student gen_logp (diagnostic)"),
|
||||
_Col("lp_t", 6, "lp_t↑", "+.2f", "mean teacher gen_logp; off-policy gap = lp_s-lp_t"),
|
||||
_Col("loss", 7, "loss", "+.2f", "mean GRPO loss"),
|
||||
_Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of delta_S grads (vs grad_clip)"),
|
||||
_Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of A/B grads (vs grad_clip)"),
|
||||
_Col("lr", 7, "lr", ".1e", "scheduled learning rate"),
|
||||
]
|
||||
# routeV reports unit and energy shares across the routing band.
|
||||
|
||||
Reference in New Issue
Block a user