mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
refactor: collapse to lora2r-only (none/routeV/absorb); delete erase/antipasto/lora_frozen_b paths
train.py rewritten straight-line for the single rank-2r Gaussian-init LoRA adapter and three arms (intervention none|routeV|absorb). Removes the erase grad-surgery, act_vote/online_stats gates, beta/KL reference path, per-source split harvest, the v_hack injection block, and all per-mechanism E/C/D/A-B tallies. Folds in: - T2 Gaussian init (lora2r.py): A0~N(0,1/d_in), B0~N(0,1/2r), net delta 0 at init. - T3 width-pooled gate labels: single (num/den) fraction across modules, skip zero-width modules, raise if none separate (was per-module equal-weight blowup). - T5 absorb arm: masks pinned (1,0) -> both blocks train, no gate. - T6 self-contained ckpt: A/B/A0/B0 in one file (no _hack file, no SVD cache), adapter:"lora2r" in saved cfg. - T8 m3: step_flagged logs the hack share (d.mean), not m.mean. Gates green: verify_lora2r_routing (4 invariants) + smoke none/routeV/absorb end-to-end on tiny-random Qwen3 (logs in /tmp/claude-1000/smoke_*.log). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""lora2r invariants (rank-2r PiSSA-init LoRA + SGTM-style block masks).
|
||||
"""lora2r invariants (rank-2r Gaussian-init LoRA + SGTM-style block masks).
|
||||
|
||||
Asserts, on tiny-random-qwen3 (CPU, fp32):
|
||||
1. IDENTITY AT INIT: wrapped logits == base logits (the hook subtracts the
|
||||
@@ -17,7 +17,7 @@ Exit nonzero on any violation. Wired into `just smoke-lora2r`.
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vgrout.antipasto import wrap_model_with_lora2r
|
||||
from vgrout.lora2r import wrap_model_with_lora2r
|
||||
from vgrout.eval import ablate_quarantine
|
||||
|
||||
MODEL = "llamafactory/tiny-random-qwen3"
|
||||
@@ -31,7 +31,7 @@ ids = torch.randint(100, 1000, (2, 12))
|
||||
with torch.no_grad():
|
||||
base_logits = model(ids).logits.clone()
|
||||
|
||||
wrappers = wrap_model_with_lora2r(model, MODEL, svd_device="cpu", r=R, grad_probe=True)
|
||||
wrappers = wrap_model_with_lora2r(model, r=R, grad_probe=True)
|
||||
|
||||
# 1. identity at init
|
||||
with torch.no_grad():
|
||||
@@ -52,7 +52,7 @@ def run_masked(m_val: float, d_val: float) -> tuple[float, float]:
|
||||
dep_sq = quar_sq = 0.0
|
||||
for info in wrappers.values():
|
||||
r = info["r"]
|
||||
gA, gB = info["delta_S"].grad, info["B"].grad
|
||||
gA, gB = info["A"].grad, info["B"].grad
|
||||
dep_sq += gA[:r].pow(2).sum().item() + gB[:, :r].pow(2).sum().item()
|
||||
quar_sq += gA[r:].pow(2).sum().item() + gB[:, r:].pow(2).sum().item()
|
||||
return dep_sq ** 0.5, quar_sq ** 0.5
|
||||
@@ -96,7 +96,7 @@ with torch.no_grad():
|
||||
assert torch.allclose(out_abl_init, out0, atol=1e-6), "ablate at init is not a no-op"
|
||||
for info in wrappers.values():
|
||||
r = info["r"]
|
||||
info["delta_S"].data[r:] += 0.05 * torch.randn_like(info["delta_S"].data[r:])
|
||||
info["A"].data[r:] += 0.05 * torch.randn_like(info["A"].data[r:])
|
||||
out_pert = model(ids).logits.clone()
|
||||
pert = (out_pert - out0).abs().max().item()
|
||||
assert pert > 1e-6, f"quarantine perturbation invisible in forward ({pert:.2e})"
|
||||
|
||||
+14
-44
@@ -1,8 +1,7 @@
|
||||
"""Evaluation and reference-model helpers for the training loop.
|
||||
"""Evaluation helpers for the training loop.
|
||||
|
||||
Three read-only helpers that touch the model but never train it: a reference
|
||||
log-prob pass (the AntiPaSTO adapter zeroed = the base model), the deploy-time
|
||||
quarantine ablation, and a hack/solve eval on a fixed prompt subset.
|
||||
Read-only helpers that touch the model but never train it: the deploy-time
|
||||
quarantine ablation and a hack/solve eval on a fixed prompt subset.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -12,7 +11,6 @@ from contextlib import contextmanager
|
||||
import torch
|
||||
|
||||
from .data import DATA, HINT_REPLACE_TO, load_problems
|
||||
from .proj import per_token_logps
|
||||
from .rewards import compute_reward
|
||||
|
||||
# Evaluation discloses novel marker families disjoint from training while preserving grader
|
||||
@@ -77,60 +75,32 @@ def randomize_eval_markers(prob: dict) -> tuple[list[dict], dict]:
|
||||
return msgs, {kw: value}
|
||||
|
||||
|
||||
def ref_logprobs_via_zero_delta(
|
||||
model, merged: torch.Tensor, wrappers: dict, plen: int,
|
||||
) -> torch.Tensor:
|
||||
"""Compute base-model completion logprobs by temporarily zeroing the adapter.
|
||||
|
||||
At delta_S=0, AntiPaSTO is exactly the frozen base model. `logits_to_keep`
|
||||
avoids materializing unused prompt logits.
|
||||
"""
|
||||
saved = {n: info["delta_S"].data.clone() for n, info in wrappers.items()}
|
||||
try:
|
||||
for info in wrappers.values():
|
||||
info["delta_S"].data.zero_()
|
||||
L_c = merged.shape[1] - plen
|
||||
logits = model(merged, logits_to_keep=L_c + 1).logits[:, :-1]
|
||||
return per_token_logps(logits, merged[:, plen:])
|
||||
finally:
|
||||
for n, info in wrappers.items():
|
||||
info["delta_S"].data.copy_(saved[n])
|
||||
|
||||
|
||||
@contextmanager
|
||||
def ablate_quarantine(wrappers: dict):
|
||||
"""Temporarily remove the quarantine to evaluate the deployed model.
|
||||
|
||||
delta_S adapters: zero delta_S_hack. lora2r ("A0" in info): reset the
|
||||
quarantine block (A[r:], B[:,r:]) to the frozen PiSSA init A0/B0 so the net
|
||||
quarantine delta is 0 -- zeroing the raw params would instead SUBTRACT the
|
||||
init contribution and corrupt the forward.
|
||||
Reset the quarantine block (A[r:], B[:,r:]) to the frozen init A0/B0 so the
|
||||
net quarantine delta is 0 -- zeroing the raw params would instead SUBTRACT
|
||||
the init contribution and corrupt the forward.
|
||||
|
||||
TODO(post-deploy-finetune): SGTM's ablate(trainable=True) reinits the forget
|
||||
weights to the retain-dims' std instead of zeroing, so the model stays
|
||||
finetunable after the quarantine is removed (no dead hole). We zero because
|
||||
we only eval after deploy; add the reinit path if we ever retrain post-ablate.
|
||||
See docs/grad_routing/sgtm_vs_ours.md."""
|
||||
saved: dict[str, object] = {}
|
||||
saved: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
|
||||
for n, info in wrappers.items():
|
||||
if "A0" in info:
|
||||
r = info["r"]
|
||||
saved[n] = (info["delta_S"].data[r:].clone(), info["B"].data[:, r:].clone())
|
||||
info["delta_S"].data[r:] = info["A0"][r:]
|
||||
info["B"].data[:, r:] = info["B0"][:, r:]
|
||||
else:
|
||||
saved[n] = info["delta_S_hack"].data.clone()
|
||||
info["delta_S_hack"].data.zero_()
|
||||
r = info["r"]
|
||||
saved[n] = (info["A"].data[r:].clone(), info["B"].data[:, r:].clone())
|
||||
info["A"].data[r:] = info["A0"][r:]
|
||||
info["B"].data[:, r:] = info["B0"][:, r:]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for n, info in wrappers.items():
|
||||
if "A0" in info:
|
||||
r = info["r"]
|
||||
info["delta_S"].data[r:] = saved[n][0]
|
||||
info["B"].data[:, r:] = saved[n][1]
|
||||
else:
|
||||
info["delta_S_hack"].data.copy_(saved[n])
|
||||
r = info["r"]
|
||||
info["A"].data[r:] = saved[n][0]
|
||||
info["B"].data[:, r:] = saved[n][1]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -5,8 +5,9 @@ For a pair with advantages (adv_h=+1, adv_c=-1) the Dr.GRPO single-step grad
|
||||
`-adv_h * grad_logp(hack) - adv_c * grad_logp(clean)` algebraically equals
|
||||
`grad_NLL(hack) - grad_NLL(clean)`, so we compute it by the simpler path:
|
||||
forward each completion, take mean-NLL on completion tokens, backward, and
|
||||
capture `delta_S.grad` per AntiPaSTO-wrapped Linear. Naming the steps NLL is
|
||||
an implementation detail; the *meaning* is "the GRPO update on this pair."
|
||||
capture the lora2r c-probe grad (the per-pair weight grad of the virtual
|
||||
diagonal between A and B, deployed block) per wrapped Linear. Naming the steps
|
||||
NLL is an implementation detail; the *meaning* is "the GRPO update on this pair."
|
||||
|
||||
Then per module, with D = [g_hack_i - g_clean_i for each pair] in R^{n_pairs x r}:
|
||||
SVD(D) = U Σ Vh
|
||||
@@ -41,12 +42,11 @@ from safetensors.torch import save_file
|
||||
from tabulate import tabulate
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from .antipasto import wrap_model_with_antipasto
|
||||
from .lora2r import wrap_model_with_lora2r
|
||||
from .pairs_from_pool import load_pairs_json
|
||||
from .vhack import pairset_sha256
|
||||
|
||||
|
||||
CACHE_ROOT = Path("svd_cache")
|
||||
OUT_DIR = Path("out")
|
||||
|
||||
|
||||
@@ -139,26 +139,13 @@ def extract_v_hack(
|
||||
bucket = grads_hack if label == "hack" else grads_clean
|
||||
for name, info in wrappers.items():
|
||||
layer = info["layer"]
|
||||
if getattr(layer, "_lora2r_grad_probe", False):
|
||||
# lora2r: per-pair weight grad of the virtual diagonal (c-probe),
|
||||
# DEPLOYED block only -- the same space the live gate reads
|
||||
# (train.py lora2r branch), so band calibration is apples-to-apples.
|
||||
cg = layer._lora2r_gate.grad
|
||||
if cg is None:
|
||||
raise RuntimeError(f"no c-probe grad on {name}; aborting lora2r extract")
|
||||
g = cg.sum(dim=tuple(range(cg.dim() - 1)))[: layer._lora2r_r] # [r]
|
||||
elif getattr(layer, "_lora_grad_probe", False) and layer._lora_h is not None:
|
||||
# LoRA-frozen-B: the routing handle is the r-bottleneck gradient
|
||||
# g_h = B^T δ_y (B frozen -> static path), not A.grad. Sum over (batch,
|
||||
# tokens) to mirror how AntiPaSTO's δS.grad accumulates over positions.
|
||||
gh = layer._lora_h.grad
|
||||
if gh is None:
|
||||
raise RuntimeError(f"no bottleneck grad on {name}; aborting LoRA extract")
|
||||
g = gh.sum(dim=tuple(range(gh.dim() - 1))) # [r]
|
||||
else:
|
||||
g = info["delta_S"].grad
|
||||
if g is None:
|
||||
raise RuntimeError(f"no grad on {name}; aborting extract")
|
||||
# Per-pair weight grad of the virtual diagonal (c-probe), DEPLOYED
|
||||
# block only -- the same space the live gate reads (train.py), so
|
||||
# band calibration is apples-to-apples. Requires grad_probe=True.
|
||||
cg = layer._lora2r_gate.grad
|
||||
if cg is None:
|
||||
raise RuntimeError(f"no c-probe grad on {name}; wrap with grad_probe=True")
|
||||
g = cg.sum(dim=tuple(range(cg.dim() - 1)))[: layer._lora2r_r] # [r]
|
||||
bucket[name].append(g.detach().float().cpu().clone())
|
||||
if (pi + 1) % 5 == 0:
|
||||
logger.info(f" pair {pi+1}/{n_pairs} loss={loss.item():.3f}")
|
||||
@@ -249,13 +236,10 @@ def main(cfg: Config) -> int:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
cfg.model, dtype=dtype, attn_implementation="sdpa"
|
||||
).to(device)
|
||||
model.eval() # disable dropout; gradients still flow through delta_S
|
||||
wrappers = wrap_model_with_antipasto(
|
||||
model, model_name=cfg.model, cache_root=CACHE_ROOT, svd_device=device,
|
||||
)
|
||||
model.eval() # disable dropout; gradients still flow through the adapter
|
||||
wrappers = wrap_model_with_lora2r(model, grad_probe=True)
|
||||
n_mod = len(wrappers)
|
||||
n_delta = sum(info["delta_S"].numel() for info in wrappers.values())
|
||||
logger.info(f"wrapped {n_mod} modules; total delta_S scalars = {n_delta:,}")
|
||||
logger.info(f"wrapped {n_mod} modules; probe space = r per module")
|
||||
|
||||
train_pairs = pairs[:-cfg.n_heldout] if cfg.n_heldout > 0 else pairs
|
||||
logger.info(f"train pairs: {len(train_pairs)} held: {cfg.n_heldout}")
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
"""lora2r adapter: one rank-2r LoRA per target Linear, A and B both trainable,
|
||||
partitioned into a deployed block [:r] and a quarantine block [r:].
|
||||
|
||||
y += B@(A@x) - B0@(A0@x)
|
||||
|
||||
A0/B0 are FROZEN copies of the (seeded Gaussian) init, subtracted so the net
|
||||
delta is exactly 0 at init while h = A@x is alive. Both factors must be nonzero
|
||||
at init: the gate reads c.grad = h ⊙ (Bᵀδ) and A.grad = (Bᵀδ)xᵀ, both
|
||||
identically zero under standard zero-B LoRA -- no extraction, no band
|
||||
calibration, no trainable A at step 0. Init values beyond "nonzero" are NOT
|
||||
load-bearing (we previously used PiSSA; see docs/spec/20260610_lora2r_v2_plan.md
|
||||
T2 for why it was dropped).
|
||||
|
||||
[B|B_q] @ ([A;A_q]@x) has no cross terms (column b_k only ever multiplies row
|
||||
a_k), so the two blocks ARE two independent adapters; per-rollout block masks on
|
||||
this one tensor implement the SGTM parameter partition (Cloud et al.).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from torch import Tensor, nn
|
||||
|
||||
TARGET_SUFFIXES = (
|
||||
# full attention
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
# linear-attention / GatedDeltaNet
|
||||
"in_proj_qkv", "in_proj_z", "in_proj_a", "in_proj_b", "out_proj",
|
||||
# MLP
|
||||
"up_proj", "gate_proj", "down_proj",
|
||||
)
|
||||
|
||||
|
||||
def is_target(name: str) -> bool:
    """Return True when *name* ends in one of the adapter target suffixes.

    Only the final dotted component of the module path is compared against
    ``TARGET_SUFFIXES``; a name without any dot is compared as a whole.
    """
    leaf = name.rpartition(".")[2]
    return leaf in TARGET_SUFFIXES
|
||||
|
||||
|
||||
def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
    """Add the two-block delta to y, applying per-rollout block masks.

    Forward hook for a wrapped Linear. It returns
    ``y + (B@(A@x) - B0@(A0@x))`` computed separately for the deployed block
    ``[:r]`` and the quarantine block ``[r:]``, so each block can be masked
    on its own. With A == A0 and B == B0 the added delta is exactly zero.

    Block masks (layer._lora2r_mask = (m, d), set by train.py per loss pass;
    None = unmasked for generation / gate pass / eval):
      m [G]  quarantine on/off -- m=0: quarantine zero in forward AND backward
             (SGTM retain trick: deployed trains in its post-ablation state)
      d [G]  deployed detach   -- d=1: deployed kept in forward, zero grad
             (hack-gated rollouts update ONLY the quarantine block)
    Masks act on branch OUTPUTS so a detach blocks grads to BOTH the A rows and
    the B columns of that block.

    grad probe: c = ones[..., 2r] spliced as h*c. After backward
    c.grad = h ⊙ (Bᵀδ_y) = the per-sample WEIGHT grad of a virtual diagonal
    scale between A and B. The pair extraction (extract_vhack_grad) and the live
    gate (train.py) both read this same space, so the band is self-consistent
    whatever the basis. The probe is only created when grad mode is enabled;
    otherwise ``layer._lora2r_gate`` keeps whatever the last grad-enabled
    forward stored.

    Args:
        layer: the wrapped Linear carrying the ``_lora2r_*`` attributes.
        args: the positional inputs of the Linear's forward; exactly one tensor.
        y: the Linear's own output, same leading shape as the input.

    Returns:
        ``y`` plus the (masked) adapter delta, in ``y``'s dtype.
    """
    (x,) = args
    A = layer._lora2r_A            # [2r, d_in]  trainable
    B = layer._lora2r_B            # [d_out, 2r] trainable
    A0 = layer._lora2r_A0          # frozen init copies (subtracted: net delta 0 at init)
    B0 = layer._lora2r_B0
    r = layer._lora2r_r
    linear = torch.nn.functional.linear

    h = linear(x, A.to(x.dtype))                                   # [..., 2r]
    if layer._lora2r_grad_probe and torch.is_grad_enabled():
        if h.dim() >= 2:
            # One probe row per leading index, broadcast over the middle axes,
            # so c.grad accumulates over those axes per sample.
            probe_shape = (h.shape[0], *([1] * (h.dim() - 2)), h.shape[-1])
        else:
            # Unbatched 1-D input: the probe must be [2r]. Building it from
            # (shape[0], shape[-1]) would give [2r, 2r] and broadcast the
            # bottleneck (and hence the returned output) to a matrix.
            probe_shape = tuple(h.shape)
        c = torch.ones(probe_shape, device=h.device, dtype=h.dtype, requires_grad=True)
        layer._lora2r_gate = c
        h = h * c
    h0 = linear(x, A0.to(x.dtype))                                 # [..., 2r] frozen init path

    # Each block's delta is its live path minus its frozen-init path.
    dep = (linear(h[..., :r], B[:, :r].to(x.dtype))
           - linear(h0[..., :r], B0[:, :r].to(x.dtype)))
    quar = (linear(h[..., r:], B[:, r:].to(x.dtype))
            - linear(h0[..., r:], B0[:, r:].to(x.dtype)))

    if layer._lora2r_mask is not None:
        m, d = layer._lora2r_mask                                  # [G] each
        G = m.shape[0]
        shape = dep.shape                                          # [G, s, d_out] or [G*s, d_out]
        dep = dep.reshape(G, -1, shape[-1])
        quar = quar.reshape(G, -1, shape[-1])
        d_ = d.view(G, 1, 1).to(dep.dtype)
        # Forward value is unchanged; where d=1 only the detached copy
        # contributes, so no gradient reaches the deployed block.
        dep = ((1 - d_) * dep + d_ * dep.detach()).reshape(shape)
        # m=0 zeroes the quarantine output and therefore its gradient too.
        quar = (m.view(G, 1, 1).to(quar.dtype) * quar).reshape(shape)
    return y + (dep + quar).to(y.dtype)
|
||||
|
||||
|
||||
def wrap_model_with_lora2r(
    model: nn.Module,
    r: int = 32,
    init_seed: int = 0,
    grad_probe: bool = False,
) -> dict[str, dict]:
    """Attach a rank-2r Gaussian-init LoRA (A AND B trainable) to every target Linear.

    Init: A0 ~ N(0, 1/d_in) [2r, d_in], B0 ~ N(0, 1/2r) [d_out, 2r], drawn from
    a CPU generator seeded per module (``init_seed * 100003 + index``) so runs
    reproduce. The base weight is left untouched; the forward hook subtracts
    the frozen A0/B0 contribution. The quarantine's learned delta is
    (A[r:], B[:, r:]) minus init; deploy ablation resets that block to A0/B0
    (eval.ablate_quarantine).

    After attaching, every model parameter whose name does not end in
    ``_lora2r_A`` / ``_lora2r_B`` has ``requires_grad`` switched off.

    Args:
        model: module tree to scan for ``nn.Linear`` targets (see ``is_target``).
        r: rank of EACH block; the adapter has 2r rows/columns in total.
        init_seed: base seed for the per-module init generator.
        grad_probe: stored on each layer; enables the c-probe in the hook.

    Returns:
        Mapping module name -> info dict ``{layer, A, B, A0, B0, handle, r}``;
        the quarantine is a block slice of A/B, not a separate tensor.
    """
    targets = []
    for mod_name, module in model.named_modules():
        if isinstance(module, nn.Linear) and is_target(mod_name):
            targets.append((mod_name, module))
    logger.info(f"lora2r attach: {len(targets)} target Linear modules, "
                f"r={r}/block (2r={2 * r}), Gaussian init seed={init_seed}, A+B trainable")

    infos: dict[str, dict] = {}
    for idx, (mod_name, lin) in enumerate(targets):
        d_out, d_in = lin.weight.shape
        device = lin.weight.device
        rng = torch.Generator().manual_seed(init_seed * 100003 + idx)
        # Draw order matters for reproducibility: A0 first, then B0.
        a_init = torch.randn(2 * r, d_in, generator=rng) / d_in ** 0.5
        A0 = a_init.to(device=device, dtype=torch.float32)
        b_init = torch.randn(d_out, 2 * r, generator=rng) / (2 * r) ** 0.5
        B0 = b_init.to(device=device, dtype=torch.float32)

        # Frozen init copies live as persistent buffers on the layer.
        lin.register_buffer("_lora2r_A0", A0, persistent=True)
        lin.register_buffer("_lora2r_B0", B0, persistent=True)

        # Trainable factors start equal to the init, so the net delta is 0.
        A = nn.Parameter(A0.clone())
        B = nn.Parameter(B0.clone())
        lin.register_parameter("_lora2r_A", A)
        lin.register_parameter("_lora2r_B", B)

        # Per-layer state read by the hook.
        lin._lora2r_r = r
        lin._lora2r_grad_probe = grad_probe
        lin._lora2r_gate = None
        lin._lora2r_mask = None

        infos[mod_name] = {
            "layer": lin,
            "A": A,
            "B": B,
            "A0": A0,
            "B0": B0,
            "handle": lin.register_forward_hook(_lora2r_hook),
            "r": r,
        }

    # Freeze everything except the adapter factors.
    for pname, param in model.named_parameters():
        if pname.endswith(("_lora2r_A", "_lora2r_B")):
            continue
        param.requires_grad_(False)
    return infos
|
||||
@@ -9,8 +9,16 @@ from safetensors import safe_open
|
||||
|
||||
RUNS_DIR = Path("out/runs")
|
||||
RUN_SCHEMA = "paired_final_v2" # v2: deployed/as_trained field names (was deploy_*/deploy_*_on)
|
||||
# Old PiSSA-substrate runs on disk carry these intervention names; lora2r runs
|
||||
# get a _lora2r suffix so the two substrates never conflate in aggregation.
|
||||
ARM = {"none": "vanilla", "erase": "projected",
|
||||
"routeV": "routingV", "routeV_per_token": "routingV_per_token"}
|
||||
"routeV": "routingV", "routeV_per_token": "routingV_per_token",
|
||||
"absorb": "absorb"}
|
||||
|
||||
|
||||
def _arm_of(cfg: dict) -> str:
    """Map a saved run config to its arm id.

    The intervention name is translated through ``ARM``. Configs whose
    ``adapter`` is ``"lora2r"`` get a ``_lora2r`` suffix; configs without an
    ``adapter`` key are treated as ``"antipasto"`` and get no suffix.
    """
    arm = ARM[cfg["intervention"]]
    if cfg.get("adapter", "antipasto") == "lora2r":
        return arm + "_lora2r"
    return arm
|
||||
|
||||
|
||||
def _mean_fraction(rows: list[dict], key: str) -> float:
|
||||
@@ -38,7 +46,7 @@ def load_run(run_dir: Path) -> dict:
|
||||
"run_dir": run_dir,
|
||||
"time": run_dir.name.split("_", 1)[0],
|
||||
"cfg": cfg,
|
||||
"arm": ARM[cfg["intervention"]],
|
||||
"arm": _arm_of(cfg),
|
||||
"rows": rows,
|
||||
"deploy": deploy,
|
||||
"l5_hack": _mean_fraction(rows[-5:], "hack_s"),
|
||||
|
||||
+10
-20
@@ -72,9 +72,8 @@ class StepLogger:
|
||||
|
||||
def __init__(self, arm: str, modes: list[str], mode_code: dict[str, str],
|
||||
show_ablate: bool = False) -> None:
|
||||
# Erase reports projection diagnostics; routeV reports routing diagnostics below.
|
||||
projects = arm == "projected"
|
||||
is_route = arm in ("routingV", "routingV_per_token", "routingV_lora2r")
|
||||
# routeV reports routing diagnostics; absorb shares qmass (zone cols read nan).
|
||||
is_route = arm in ("routingV_lora2r", "absorb_lora2r")
|
||||
cols: list[_Col] = [
|
||||
_Col("step", 4, "step", "d", "GRPO step"),
|
||||
_Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"),
|
||||
@@ -100,25 +99,16 @@ class StepLogger:
|
||||
_Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of delta_S grads (vs grad_clip)"),
|
||||
_Col("lr", 7, "lr", ".1e", "scheduled learning rate"),
|
||||
]
|
||||
if projects:
|
||||
cols += [
|
||||
_Col("cos_pre", 6, "cin", ".2f", "hack-ward grad fraction ||relu(V@g)||/||g|| [0,1] BEFORE proj"),
|
||||
_Col("cos_pre_s", 6, "cin_s", ".2f", "cin on student-only grad"),
|
||||
_Col("cos_pre_t", 6, "cin_t", ".2f", "cin on teacher-only grad (want cin_t>cin_s)"),
|
||||
_Col("cos_post", 6, "cout", ".2f", "hack-ward fraction AFTER projection (want ~0: all removed)"),
|
||||
_Col("fired", 5, "fired", ".2f", "fraction of modules where projection fired"),
|
||||
]
|
||||
# routeV reports unit and energy shares across the routing band plus residual leak.
|
||||
# routeV reports unit and energy shares across the routing band.
|
||||
if is_route:
|
||||
cols += [
|
||||
_Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update parked in the throwaway quarantine adapter"),
|
||||
_Col("keep", 6, "keep", ".2f", "unit share with cos below the band -> kept whole in the deployed adapter (left)"),
|
||||
_Col("resid", 6, "resid", ".2f", "unit share with cos inside the band -> partially routed (residual middle)"),
|
||||
_Col("rout", 6, "rout", ".2f", "unit share with cos above the band -> fully routed into quarantine (right)"),
|
||||
_Col("keepE", 6, "keepE", ".2f", "energy-weighted keep: share of grad ENERGY in the kept zone"),
|
||||
_Col("residE", 6, "residE", ".2f", "energy-weighted resid: share of grad ENERGY in the partially-routed zone"),
|
||||
_Col("routE", 6, "routE", ".2f", "energy-weighted rout: grad ENERGY share fully routed (~quarantine mass; the routed total is routE..routE+residE)"),
|
||||
_Col("leak", 6, "leak", "+.2f", "hack-ward cosine left in the deployed adapter after routing; ~0 = stripped clean, >0 = hack leaked through (under-routed)"),
|
||||
_Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"),
|
||||
_Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"),
|
||||
_Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"),
|
||||
_Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"),
|
||||
_Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"),
|
||||
_Col("residE", 6, "residE", ".2f", "energy-weighted resid"),
|
||||
_Col("routE", 6, "routE", ".2f", "energy-weighted rout"),
|
||||
]
|
||||
# Show the training-prompt deploy proxy only when an ablated slice exists.
|
||||
if is_route and show_ablate:
|
||||
|
||||
+308
-1014
File diff suppressed because it is too large
Load Diff
+27
-75
@@ -1,4 +1,14 @@
|
||||
"""Typed CLI configuration for train.py."""
|
||||
"""Typed CLI configuration for train.py.
|
||||
|
||||
One adapter (lora2r: rank-2r Gaussian-init LoRA, A+B trainable, SGTM-style
|
||||
three-way hard block masking; see src/vgrout/lora2r.py) and three arms:
|
||||
|
||||
none gate pinned clean (0,0): quarantine never trains -- the capacity- and
|
||||
structure-matched vanilla control.
|
||||
routeV per-rollout three-way gate from the c-probe gradient vs v_grad.
|
||||
absorb gate pinned mid (1,0): both blocks train on everything, no gate --
|
||||
isolates the value of the gate + hard masks vs absorption alone.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
@@ -7,61 +17,41 @@ from typing import Literal
|
||||
|
||||
from .rewards import EnvMode
|
||||
|
||||
# lora2r = rank-2r PiSSA-init LoRA (A+B trainable) with SGTM-style three-way
|
||||
# hard block masking; supports intervention none (gate pinned clean) | routeV.
|
||||
Adapter = Literal["antipasto", "lora_frozen_b", "lora2r"]
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Config:
|
||||
# arm: the gradient policy. routeV = per-rollout gate; routeV_per_token = per-token gate.
|
||||
intervention: Literal["none", "erase", "routeV", "routeV_per_token"] = "erase"
|
||||
adapter: Adapter = "antipasto"
|
||||
intervention: Literal["none", "routeV", "absorb"] = "routeV"
|
||||
lora_r: int = 32
|
||||
lora_b_seed: int = 0
|
||||
lora_init_seed: int = 0
|
||||
|
||||
model: str = "Qwen/Qwen3-4B"
|
||||
steps: int = 100
|
||||
group: int = 6
|
||||
max_new: int = 1024
|
||||
n_problems: int = 992
|
||||
beta: float = 0.0
|
||||
prompts_per_step: int = 8
|
||||
lr: float = 7e-5
|
||||
lr: float = 1e-4
|
||||
adam_beta1: float = 0.9
|
||||
adam_beta2: float = 0.99
|
||||
clip: float = 0.2
|
||||
weight_decay: float = 0.1
|
||||
# AdamW decay pulls raw A/B toward 0, not toward the init, which would drive
|
||||
# the net delta to -B0@A0 -- must stay 0 for this adapter.
|
||||
weight_decay: float = 0.0
|
||||
warmup_frac: float = 0.1
|
||||
grad_clip: float = 10.0
|
||||
seed: int = 41
|
||||
unbiased: bool = True
|
||||
|
||||
preserve_magnitude: bool = True
|
||||
gate_mode: Literal["one_sided", "no_gate", "reverse"] = "one_sided"
|
||||
project_overshoot: float = 1.0
|
||||
v_hack_path: Path | None = None
|
||||
v_hack_extract_top_k: int = 12
|
||||
v_hack_k: int = 5
|
||||
v_hack_tau_axis: float = 0.0
|
||||
v_hack_drop_bottom_frac: float = 0.25
|
||||
vhack_refresh_every: int = 5
|
||||
vhack_pairs_path: Path = Path("out/pairsets/pairs_authored.json")
|
||||
|
||||
# Haar-random direction control (placebo): same routing machinery, no pair signal.
|
||||
routeV_random_v_seed: int | None = None
|
||||
# Top-k axes are oriented by the hack-clean mean difference before max-cos routing.
|
||||
routeV_top_k: int = 1
|
||||
# Pair cosine, live cosine quantiles, or authored-pair activation voting calibrates the gate.
|
||||
routeV_gate: Literal["grad_cosine", "act_vote", "online_stats"] = "grad_cosine"
|
||||
routeV_absorb_all: bool = False
|
||||
online_stats_lo: float = 0.05
|
||||
online_stats_hi: float = 0.95
|
||||
rollout_ablate_frac: float = 0.0
|
||||
|
||||
env_mode: EnvMode = "run_tests"
|
||||
# eval3 keeps solve pressure alive by making 25% of training problems unhackable.
|
||||
# (10% learned solve too slowly; 25% is the default from 2026-06-10 on.)
|
||||
eval: Literal["eval2", "eval3"] = "eval3"
|
||||
# Rotating per-step chance a problem is shown hint-free (graded honest-only);
|
||||
# keeps solve pressure alive. 10% learned solve too slowly; 25% default since 2026-06-10.
|
||||
unhackable_frac: float = 0.25
|
||||
teacher_pool_dir: Path | None = None
|
||||
mix_ratio: float = 0.125
|
||||
teacher_off_step: int | None = 30
|
||||
@@ -71,77 +61,39 @@ class Config:
|
||||
eval_n_prompts: int = 32
|
||||
eval_batch_size: int = 2
|
||||
save_ckpt_every: int = 10
|
||||
cos_pre_split_every: int = 1
|
||||
half_a: str = ""
|
||||
out_tag: str = ""
|
||||
|
||||
@property
|
||||
def preset_name(self) -> str:
|
||||
return type(self).__name__.removesuffix("Config").lower() or "base"
|
||||
|
||||
@property
|
||||
def unhackable_frac(self) -> float:
|
||||
return {"eval2": 0.0, "eval3": 0.25}[self.eval]
|
||||
|
||||
@property
|
||||
def arm(self) -> str:
|
||||
base = {"none": "vanilla", "erase": "projected",
|
||||
"routeV": "routingV", "routeV_per_token": "routingV_per_token"}[self.intervention]
|
||||
# lora2r changes the routing logic (hard 3-way masks, structural separation),
|
||||
# so it gets its own arm id -- old/new runs must not be conflated.
|
||||
return f"{base}_lora2r" if self.adapter == "lora2r" else base
|
||||
# _lora2r suffix kept so these runs never conflate with the retired
|
||||
# PiSSA-substrate runs of the same intervention (rename-on-logic-change).
|
||||
return {"none": "vanilla_lora2r", "routeV": "routingV_lora2r",
|
||||
"absorb": "absorb_lora2r"}[self.intervention]
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SmokeConfig(Config):
|
||||
model: str = "llamafactory/tiny-random-qwen3"
|
||||
lora_r: int = 4 # tiny model min Linear dim is 16; 2r=8 fits
|
||||
steps: int = 30
|
||||
group: int = 4
|
||||
max_new: int = 32
|
||||
n_problems: int = 100
|
||||
beta: float = 0.0
|
||||
prompts_per_step: int = 1
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class FastConfig(Config):
|
||||
model: str = "Qwen/Qwen3-4B"
|
||||
steps: int = 60
|
||||
steps: int = 100
|
||||
teacher_pool_dir: Path | None = Path("out/pools/teacher_pool_runtests_dense")
|
||||
grad_clip: float = 10.0
|
||||
group: int = 8
|
||||
max_new: int = 512
|
||||
n_problems: int = 200
|
||||
beta: float = 0.0
|
||||
prompts_per_step: int = 4
|
||||
lr: float = 3e-3
|
||||
adam_beta1: float = 0.5
|
||||
adam_beta2: float = 0.9
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class FastLoraConfig(FastConfig):
|
||||
# LoRA-frozen-B needs a lower learning rate because its gradient scale differs from delta_S.
|
||||
adapter: Adapter = "lora_frozen_b"
|
||||
lr: float = 1e-4
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class FastLora2rConfig(FastConfig):
|
||||
# Rank-2r PiSSA-init LoRA + SGTM three-way masking. weight_decay MUST be 0:
|
||||
# AdamW decays the raw A/B toward 0, not toward the PiSSA init, which would
|
||||
# drive the net delta to -B0@A0 (subtracting W's top-2r spectral part).
|
||||
adapter: Adapter = "lora2r"
|
||||
lr: float = 1e-4
|
||||
weight_decay: float = 0.0
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class FullConfig(Config):
|
||||
model: str = "Qwen/Qwen3-4B"
|
||||
steps: int = 200
|
||||
group: int = 4
|
||||
max_new: int = 1536
|
||||
n_problems: int = 992
|
||||
beta: float = 1e-3
|
||||
prompts_per_step: int = 64
|
||||
|
||||
Reference in New Issue
Block a user