refactor: collapse to lora2r-only (none/routeV/absorb); delete erase/antipasto/lora_frozen_b paths

train.py rewritten straight-line for the single rank-2r Gaussian-init LoRA adapter
and three arms (intervention none|routeV|absorb). Removes the erase grad-surgery,
act_vote/online_stats gates, beta/KL reference path, per-source split harvest, the
v_hack injection block, and all per-mechanism E/C/D/A-B tallies. Folds in:
- T2 Gaussian init (lora2r.py): A0~N(0,1/d_in), B0~N(0,1/2r), net delta 0 at init.
- T3 width-pooled gate labels: single (num/den) fraction across modules, skip
  zero-width modules, raise if none separate (was per-module equal-weight blowup).
- T5 absorb arm: masks pinned (1,0) -> both blocks train, no gate.
- T6 self-contained ckpt: A/B/A0/B0 in one file (no _hack file, no SVD cache),
  adapter:"lora2r" in saved cfg.
- T8 m3: step_flagged logs the hack share (d.mean), not m.mean.

Gates green: verify_lora2r_routing (4 invariants) + smoke none/routeV/absorb
end-to-end on tiny-random Qwen3 (logs in /tmp/claude-1000/smoke_*.log).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-10 10:58:22 +00:00
parent 6094568c56
commit 5c97975185
8 changed files with 517 additions and 1190 deletions
+5 -5
View File
@@ -1,4 +1,4 @@
"""lora2r invariants (rank-2r PiSSA-init LoRA + SGTM-style block masks).
"""lora2r invariants (rank-2r Gaussian-init LoRA + SGTM-style block masks).
Asserts, on tiny-random-qwen3 (CPU, fp32):
1. IDENTITY AT INIT: wrapped logits == base logits (the hook subtracts the
@@ -17,7 +17,7 @@ Exit nonzero on any violation. Wired into `just smoke-lora2r`.
import torch
from transformers import AutoModelForCausalLM
from vgrout.antipasto import wrap_model_with_lora2r
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.eval import ablate_quarantine
MODEL = "llamafactory/tiny-random-qwen3"
@@ -31,7 +31,7 @@ ids = torch.randint(100, 1000, (2, 12))
with torch.no_grad():
base_logits = model(ids).logits.clone()
wrappers = wrap_model_with_lora2r(model, MODEL, svd_device="cpu", r=R, grad_probe=True)
wrappers = wrap_model_with_lora2r(model, r=R, grad_probe=True)
# 1. identity at init
with torch.no_grad():
@@ -52,7 +52,7 @@ def run_masked(m_val: float, d_val: float) -> tuple[float, float]:
dep_sq = quar_sq = 0.0
for info in wrappers.values():
r = info["r"]
gA, gB = info["delta_S"].grad, info["B"].grad
gA, gB = info["A"].grad, info["B"].grad
dep_sq += gA[:r].pow(2).sum().item() + gB[:, :r].pow(2).sum().item()
quar_sq += gA[r:].pow(2).sum().item() + gB[:, r:].pow(2).sum().item()
return dep_sq ** 0.5, quar_sq ** 0.5
@@ -96,7 +96,7 @@ with torch.no_grad():
assert torch.allclose(out_abl_init, out0, atol=1e-6), "ablate at init is not a no-op"
for info in wrappers.values():
r = info["r"]
info["delta_S"].data[r:] += 0.05 * torch.randn_like(info["delta_S"].data[r:])
info["A"].data[r:] += 0.05 * torch.randn_like(info["A"].data[r:])
out_pert = model(ids).logits.clone()
pert = (out_pert - out0).abs().max().item()
assert pert > 1e-6, f"quarantine perturbation invisible in forward ({pert:.2e})"
+14 -44
View File
@@ -1,8 +1,7 @@
"""Evaluation and reference-model helpers for the training loop.
"""Evaluation helpers for the training loop.
Three read-only helpers that touch the model but never train it: a reference
log-prob pass (the AntiPaSTO adapter zeroed = the base model), the deploy-time
quarantine ablation, and a hack/solve eval on a fixed prompt subset.
Read-only helpers that touch the model but never train it: the deploy-time
quarantine ablation and a hack/solve eval on a fixed prompt subset.
"""
from __future__ import annotations
@@ -12,7 +11,6 @@ from contextlib import contextmanager
import torch
from .data import DATA, HINT_REPLACE_TO, load_problems
from .proj import per_token_logps
from .rewards import compute_reward
# Evaluation discloses novel marker families disjoint from training while preserving grader
@@ -77,60 +75,32 @@ def randomize_eval_markers(prob: dict) -> tuple[list[dict], dict]:
return msgs, {kw: value}
def ref_logprobs_via_zero_delta(
model, merged: torch.Tensor, wrappers: dict, plen: int,
) -> torch.Tensor:
"""Compute base-model completion logprobs by temporarily zeroing the adapter.
At delta_S=0, AntiPaSTO is exactly the frozen base model. `logits_to_keep`
avoids materializing unused prompt logits.
"""
saved = {n: info["delta_S"].data.clone() for n, info in wrappers.items()}
try:
for info in wrappers.values():
info["delta_S"].data.zero_()
L_c = merged.shape[1] - plen
logits = model(merged, logits_to_keep=L_c + 1).logits[:, :-1]
return per_token_logps(logits, merged[:, plen:])
finally:
for n, info in wrappers.items():
info["delta_S"].data.copy_(saved[n])
@contextmanager
def ablate_quarantine(wrappers: dict):
"""Temporarily remove the quarantine to evaluate the deployed model.
delta_S adapters: zero delta_S_hack. lora2r ("A0" in info): reset the
quarantine block (A[r:], B[:,r:]) to the frozen PiSSA init A0/B0 so the net
quarantine delta is 0 -- zeroing the raw params would instead SUBTRACT the
init contribution and corrupt the forward.
Reset the quarantine block (A[r:], B[:,r:]) to the frozen init A0/B0 so the
net quarantine delta is 0 -- zeroing the raw params would instead SUBTRACT
the init contribution and corrupt the forward.
TODO(post-deploy-finetune): SGTM's ablate(trainable=True) reinits the forget
weights to the retain-dims' std instead of zeroing, so the model stays
finetunable after the quarantine is removed (no dead hole). We zero because
we only eval after deploy; add the reinit path if we ever retrain post-ablate.
See docs/grad_routing/sgtm_vs_ours.md."""
saved: dict[str, object] = {}
saved: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
for n, info in wrappers.items():
if "A0" in info:
r = info["r"]
saved[n] = (info["delta_S"].data[r:].clone(), info["B"].data[:, r:].clone())
info["delta_S"].data[r:] = info["A0"][r:]
info["B"].data[:, r:] = info["B0"][:, r:]
else:
saved[n] = info["delta_S_hack"].data.clone()
info["delta_S_hack"].data.zero_()
r = info["r"]
saved[n] = (info["A"].data[r:].clone(), info["B"].data[:, r:].clone())
info["A"].data[r:] = info["A0"][r:]
info["B"].data[:, r:] = info["B0"][:, r:]
try:
yield
finally:
for n, info in wrappers.items():
if "A0" in info:
r = info["r"]
info["delta_S"].data[r:] = saved[n][0]
info["B"].data[:, r:] = saved[n][1]
else:
info["delta_S_hack"].data.copy_(saved[n])
r = info["r"]
info["A"].data[r:] = saved[n][0]
info["B"].data[:, r:] = saved[n][1]
@torch.no_grad()
+14 -30
View File
@@ -5,8 +5,9 @@ For a pair with advantages (adv_h=+1, adv_c=-1) the Dr.GRPO single-step grad
`-adv_h * grad_logp(hack) - adv_c * grad_logp(clean)` algebraically equals
`grad_NLL(hack) - grad_NLL(clean)`, so we compute it by the simpler path:
forward each completion, take mean-NLL on completion tokens, backward, and
capture `delta_S.grad` per AntiPaSTO-wrapped Linear. Naming the steps NLL is
an implementation detail; the *meaning* is "the GRPO update on this pair."
capture the lora2r c-probe grad (the per-pair weight grad of the virtual
diagonal between A and B, deployed block) per wrapped Linear. Naming the steps
NLL is an implementation detail; the *meaning* is "the GRPO update on this pair."
Then per module, with D = [g_hack_i - g_clean_i for each pair] in R^{n_pairs x r}:
SVD(D) = U Σ Vh
@@ -41,12 +42,11 @@ from safetensors.torch import save_file
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
from .antipasto import wrap_model_with_antipasto
from .lora2r import wrap_model_with_lora2r
from .pairs_from_pool import load_pairs_json
from .vhack import pairset_sha256
CACHE_ROOT = Path("svd_cache")
OUT_DIR = Path("out")
@@ -139,26 +139,13 @@ def extract_v_hack(
bucket = grads_hack if label == "hack" else grads_clean
for name, info in wrappers.items():
layer = info["layer"]
if getattr(layer, "_lora2r_grad_probe", False):
# lora2r: per-pair weight grad of the virtual diagonal (c-probe),
# DEPLOYED block only -- the same space the live gate reads
# (train.py lora2r branch), so band calibration is apples-to-apples.
cg = layer._lora2r_gate.grad
if cg is None:
raise RuntimeError(f"no c-probe grad on {name}; aborting lora2r extract")
g = cg.sum(dim=tuple(range(cg.dim() - 1)))[: layer._lora2r_r] # [r]
elif getattr(layer, "_lora_grad_probe", False) and layer._lora_h is not None:
# LoRA-frozen-B: the routing handle is the r-bottleneck gradient
# g_h = B^T δ_y (B frozen -> static path), not A.grad. Sum over (batch,
# tokens) to mirror how AntiPaSTO's δS.grad accumulates over positions.
gh = layer._lora_h.grad
if gh is None:
raise RuntimeError(f"no bottleneck grad on {name}; aborting LoRA extract")
g = gh.sum(dim=tuple(range(gh.dim() - 1))) # [r]
else:
g = info["delta_S"].grad
if g is None:
raise RuntimeError(f"no grad on {name}; aborting extract")
# Per-pair weight grad of the virtual diagonal (c-probe), DEPLOYED
# block only -- the same space the live gate reads (train.py), so
# band calibration is apples-to-apples. Requires grad_probe=True.
cg = layer._lora2r_gate.grad
if cg is None:
raise RuntimeError(f"no c-probe grad on {name}; wrap with grad_probe=True")
g = cg.sum(dim=tuple(range(cg.dim() - 1)))[: layer._lora2r_r] # [r]
bucket[name].append(g.detach().float().cpu().clone())
if (pi + 1) % 5 == 0:
logger.info(f" pair {pi+1}/{n_pairs} loss={loss.item():.3f}")
@@ -249,13 +236,10 @@ def main(cfg: Config) -> int:
model = AutoModelForCausalLM.from_pretrained(
cfg.model, dtype=dtype, attn_implementation="sdpa"
).to(device)
model.eval() # disable dropout; gradients still flow through delta_S
wrappers = wrap_model_with_antipasto(
model, model_name=cfg.model, cache_root=CACHE_ROOT, svd_device=device,
)
model.eval() # disable dropout; gradients still flow through the adapter
wrappers = wrap_model_with_lora2r(model, grad_probe=True)
n_mod = len(wrappers)
n_delta = sum(info["delta_S"].numel() for info in wrappers.values())
logger.info(f"wrapped {n_mod} modules; total delta_S scalars = {n_delta:,}")
logger.info(f"wrapped {n_mod} modules; probe space = r per module")
train_pairs = pairs[:-cfg.n_heldout] if cfg.n_heldout > 0 else pairs
logger.info(f"train pairs: {len(train_pairs)} held: {cfg.n_heldout}")
+129
View File
@@ -0,0 +1,129 @@
"""lora2r adapter: one rank-2r LoRA per target Linear, A and B both trainable,
partitioned into a deployed block [:r] and a quarantine block [r:].
y += B@(A@x) - B0@(A0@x)
A0/B0 are FROZEN copies of the (seeded Gaussian) init, subtracted so the net
delta is exactly 0 at init while h = A@x is alive. Both factors must be nonzero
at init: the gate reads c.grad = h ⊙ (Bᵀδ) and A.grad = (Bᵀδ)xᵀ, both
identically zero under standard zero-B LoRA -- no extraction, no band
calibration, no trainable A at step 0. Init values beyond "nonzero" are NOT
load-bearing (we previously used PiSSA; see docs/spec/20260610_lora2r_v2_plan.md
T2 for why it was dropped).
[B|B_q] @ ([A;A_q]@x) has no cross terms (column b_k only ever multiplies row
a_k), so the two blocks ARE two independent adapters; per-rollout block masks on
this one tensor implement the SGTM parameter partition (Cloud et al.).
"""
from __future__ import annotations
import torch
from loguru import logger
from torch import Tensor, nn
TARGET_SUFFIXES = (
# full attention
"q_proj", "k_proj", "v_proj", "o_proj",
# linear-attention / GatedDeltaNet
"in_proj_qkv", "in_proj_z", "in_proj_a", "in_proj_b", "out_proj",
# MLP
"up_proj", "gate_proj", "down_proj",
)
def is_target(name: str) -> bool:
return name.split(".")[-1] in TARGET_SUFFIXES
def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
"""Add the two-block delta to y, applying per-rollout block masks.
Block masks (layer._lora2r_mask = (m, d), set by train.py per loss pass;
None = unmasked for generation / gate pass / eval):
m [G] quarantine on/off -- m=0: quarantine zero in forward AND backward
(SGTM retain trick: deployed trains in its post-ablation state)
d [G] deployed detach -- d=1: deployed kept in forward, zero grad
(hack-gated rollouts update ONLY the quarantine block)
Masks act on branch OUTPUTS so a detach blocks grads to BOTH the A rows and
the B columns of that block.
grad probe: c = ones[..., 2r] spliced as h*c. After backward
c.grad = h ⊙ (Bᵀδ_y) = the per-sample WEIGHT grad of a virtual diagonal
scale between A and B. The pair extraction (extract_vhack_grad) and the live
gate (train.py) both read this same space, so the band is self-consistent
whatever the basis.
"""
(x,) = args
A = layer._lora2r_A # [2r, d_in] trainable
B = layer._lora2r_B # [d_out, 2r] trainable
A0 = layer._lora2r_A0 # frozen init copies (subtracted: net delta 0 at init)
B0 = layer._lora2r_B0
r = layer._lora2r_r
h = torch.nn.functional.linear(x, A.to(x.dtype)) # [..., 2r]
if layer._lora2r_grad_probe and torch.is_grad_enabled():
c = torch.ones(h.shape[0], *([1] * (h.dim() - 2)), h.shape[-1],
device=h.device, dtype=h.dtype, requires_grad=True)
layer._lora2r_gate = c
h = h * c
h0 = torch.nn.functional.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path
dep = (torch.nn.functional.linear(h[..., :r], B[:, :r].to(x.dtype))
- torch.nn.functional.linear(h0[..., :r], B0[:, :r].to(x.dtype)))
quar = (torch.nn.functional.linear(h[..., r:], B[:, r:].to(x.dtype))
- torch.nn.functional.linear(h0[..., r:], B0[:, r:].to(x.dtype)))
if layer._lora2r_mask is not None:
m, d = layer._lora2r_mask # [G] each
G = m.shape[0]
shape = dep.shape # [G, s, d_out] or [G*s, d_out]
dep = dep.reshape(G, -1, shape[-1])
quar = quar.reshape(G, -1, shape[-1])
d_ = d.view(G, 1, 1).to(dep.dtype)
dep = ((1 - d_) * dep + d_ * dep.detach()).reshape(shape)
quar = (m.view(G, 1, 1).to(quar.dtype) * quar).reshape(shape)
return y + (dep + quar).to(y.dtype)
def wrap_model_with_lora2r(
model: nn.Module,
r: int = 32,
init_seed: int = 0,
grad_probe: bool = False,
) -> dict[str, dict]:
"""Attach a rank-2r Gaussian-init LoRA (A AND B trainable) to every target Linear.
Init: A0 ~ N(0, 1/d_in) [2r, d_in], B0 ~ N(0, 1/2r) [d_out, 2r], seeded per
module so runs reproduce; blocks are iid -> statistically matched. W stays
untouched; the hook subtracts the frozen A0/B0 contribution. The quarantine's
learned delta is (A[r:], B[:, r:]) minus init; deploy ablation resets that
block to A0/B0 (eval.ablate_quarantine).
Info dict per module: {layer, A, B, A0, B0, handle, r}; quarantine = block
slices, no separate tensor.
"""
targets = [(n, m) for n, m in model.named_modules()
if isinstance(m, nn.Linear) and is_target(n)]
logger.info(f"lora2r attach: {len(targets)} target Linear modules, "
f"r={r}/block (2r={2 * r}), Gaussian init seed={init_seed}, A+B trainable")
out: dict[str, dict] = {}
for i, (name, linear) in enumerate(targets):
d_out, d_in = linear.weight.shape
dev = linear.weight.device
gen = torch.Generator().manual_seed(init_seed * 100003 + i)
A0 = (torch.randn(2 * r, d_in, generator=gen) / d_in ** 0.5).to(device=dev, dtype=torch.float32)
B0 = (torch.randn(d_out, 2 * r, generator=gen) / (2 * r) ** 0.5).to(device=dev, dtype=torch.float32)
linear.register_buffer("_lora2r_A0", A0, persistent=True)
linear.register_buffer("_lora2r_B0", B0, persistent=True)
A = nn.Parameter(A0.clone())
B = nn.Parameter(B0.clone())
linear.register_parameter("_lora2r_A", A)
linear.register_parameter("_lora2r_B", B)
linear._lora2r_r = r
linear._lora2r_grad_probe = grad_probe
linear._lora2r_gate = None
linear._lora2r_mask = None
out[name] = {"layer": linear, "A": A, "B": B, "A0": A0, "B0": B0,
"handle": linear.register_forward_hook(_lora2r_hook), "r": r}
trainable = ("_lora2r_A", "_lora2r_B")
for n, p in model.named_parameters():
if not n.endswith(trainable):
p.requires_grad_(False)
return out
+10 -2
View File
@@ -9,8 +9,16 @@ from safetensors import safe_open
RUNS_DIR = Path("out/runs")
RUN_SCHEMA = "paired_final_v2" # v2: deployed/as_trained field names (was deploy_*/deploy_*_on)
# Old PiSSA-substrate runs on disk carry these intervention names; lora2r runs
# get a _lora2r suffix so the two substrates never conflate in aggregation.
ARM = {"none": "vanilla", "erase": "projected",
"routeV": "routingV", "routeV_per_token": "routingV_per_token"}
"routeV": "routingV", "routeV_per_token": "routingV_per_token",
"absorb": "absorb"}
def _arm_of(cfg: dict) -> str:
suffix = "_lora2r" if cfg.get("adapter", "antipasto") == "lora2r" else ""
return ARM[cfg["intervention"]] + suffix
def _mean_fraction(rows: list[dict], key: str) -> float:
@@ -38,7 +46,7 @@ def load_run(run_dir: Path) -> dict:
"run_dir": run_dir,
"time": run_dir.name.split("_", 1)[0],
"cfg": cfg,
"arm": ARM[cfg["intervention"]],
"arm": _arm_of(cfg),
"rows": rows,
"deploy": deploy,
"l5_hack": _mean_fraction(rows[-5:], "hack_s"),
+10 -20
View File
@@ -72,9 +72,8 @@ class StepLogger:
def __init__(self, arm: str, modes: list[str], mode_code: dict[str, str],
show_ablate: bool = False) -> None:
# Erase reports projection diagnostics; routeV reports routing diagnostics below.
projects = arm == "projected"
is_route = arm in ("routingV", "routingV_per_token", "routingV_lora2r")
# routeV reports routing diagnostics; absorb shares qmass (zone cols read nan).
is_route = arm in ("routingV_lora2r", "absorb_lora2r")
cols: list[_Col] = [
_Col("step", 4, "step", "d", "GRPO step"),
_Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"),
@@ -100,25 +99,16 @@ class StepLogger:
_Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of delta_S grads (vs grad_clip)"),
_Col("lr", 7, "lr", ".1e", "scheduled learning rate"),
]
if projects:
cols += [
_Col("cos_pre", 6, "cin", ".2f", "hack-ward grad fraction ||relu(V@g)||/||g|| [0,1] BEFORE proj"),
_Col("cos_pre_s", 6, "cin_s", ".2f", "cin on student-only grad"),
_Col("cos_pre_t", 6, "cin_t", ".2f", "cin on teacher-only grad (want cin_t>cin_s)"),
_Col("cos_post", 6, "cout", ".2f", "hack-ward fraction AFTER projection (want ~0: all removed)"),
_Col("fired", 5, "fired", ".2f", "fraction of modules where projection fired"),
]
# routeV reports unit and energy shares across the routing band plus residual leak.
# routeV reports unit and energy shares across the routing band.
if is_route:
cols += [
_Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update parked in the throwaway quarantine adapter"),
_Col("keep", 6, "keep", ".2f", "unit share with cos below the band -> kept whole in the deployed adapter (left)"),
_Col("resid", 6, "resid", ".2f", "unit share with cos inside the band -> partially routed (residual middle)"),
_Col("rout", 6, "rout", ".2f", "unit share with cos above the band -> fully routed into quarantine (right)"),
_Col("keepE", 6, "keepE", ".2f", "energy-weighted keep: share of grad ENERGY in the kept zone"),
_Col("residE", 6, "residE", ".2f", "energy-weighted resid: share of grad ENERGY in the partially-routed zone"),
_Col("routE", 6, "routE", ".2f", "energy-weighted rout: grad ENERGY share fully routed (~quarantine mass; the routed total is routE..routE+residE)"),
_Col("leak", 6, "leak", "+.2f", "hack-ward cosine left in the deployed adapter after routing; ~0 = stripped clean, >0 = hack leaked through (under-routed)"),
_Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"),
_Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"),
_Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"),
_Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"),
_Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"),
_Col("residE", 6, "residE", ".2f", "energy-weighted resid"),
_Col("routE", 6, "routE", ".2f", "energy-weighted rout"),
]
# Show the training-prompt deploy proxy only when an ablated slice exists.
if is_route and show_ablate:
+308 -1014
View File
File diff suppressed because it is too large Load Diff
+27 -75
View File
@@ -1,4 +1,14 @@
"""Typed CLI configuration for train.py."""
"""Typed CLI configuration for train.py.
One adapter (lora2r: rank-2r Gaussian-init LoRA, A+B trainable, SGTM-style
three-way hard block masking; see src/vgrout/lora2r.py) and three arms:
none gate pinned clean (0,0): quarantine never trains -- the capacity- and
structure-matched vanilla control.
routeV per-rollout three-way gate from the c-probe gradient vs v_grad.
absorb gate pinned mid (1,0): both blocks train on everything, no gate --
isolates the value of the gate + hard masks vs absorption alone.
"""
from __future__ import annotations
from dataclasses import dataclass
@@ -7,61 +17,41 @@ from typing import Literal
from .rewards import EnvMode
# lora2r = rank-2r PiSSA-init LoRA (A+B trainable) with SGTM-style three-way
# hard block masking; supports intervention none (gate pinned clean) | routeV.
Adapter = Literal["antipasto", "lora_frozen_b", "lora2r"]
@dataclass(kw_only=True)
class Config:
# arm: the gradient policy. routeV = per-rollout gate; routeV_per_token = per-token gate.
intervention: Literal["none", "erase", "routeV", "routeV_per_token"] = "erase"
adapter: Adapter = "antipasto"
intervention: Literal["none", "routeV", "absorb"] = "routeV"
lora_r: int = 32
lora_b_seed: int = 0
lora_init_seed: int = 0
model: str = "Qwen/Qwen3-4B"
steps: int = 100
group: int = 6
max_new: int = 1024
n_problems: int = 992
beta: float = 0.0
prompts_per_step: int = 8
lr: float = 7e-5
lr: float = 1e-4
adam_beta1: float = 0.9
adam_beta2: float = 0.99
clip: float = 0.2
weight_decay: float = 0.1
# AdamW decay pulls raw A/B toward 0, not toward the init, which would drive
# the net delta to -B0@A0 -- must stay 0 for this adapter.
weight_decay: float = 0.0
warmup_frac: float = 0.1
grad_clip: float = 10.0
seed: int = 41
unbiased: bool = True
preserve_magnitude: bool = True
gate_mode: Literal["one_sided", "no_gate", "reverse"] = "one_sided"
project_overshoot: float = 1.0
v_hack_path: Path | None = None
v_hack_extract_top_k: int = 12
v_hack_k: int = 5
v_hack_tau_axis: float = 0.0
v_hack_drop_bottom_frac: float = 0.25
vhack_refresh_every: int = 5
vhack_pairs_path: Path = Path("out/pairsets/pairs_authored.json")
# Haar-random direction control (placebo): same routing machinery, no pair signal.
routeV_random_v_seed: int | None = None
# Top-k axes are oriented by the hack-clean mean difference before max-cos routing.
routeV_top_k: int = 1
# Pair cosine, live cosine quantiles, or authored-pair activation voting calibrates the gate.
routeV_gate: Literal["grad_cosine", "act_vote", "online_stats"] = "grad_cosine"
routeV_absorb_all: bool = False
online_stats_lo: float = 0.05
online_stats_hi: float = 0.95
rollout_ablate_frac: float = 0.0
env_mode: EnvMode = "run_tests"
# eval3 keeps solve pressure alive by making 25% of training problems unhackable.
# (10% learned solve too slowly; 25% is the default from 2026-06-10 on.)
eval: Literal["eval2", "eval3"] = "eval3"
# Rotating per-step chance a problem is shown hint-free (graded honest-only);
# keeps solve pressure alive. 10% learned solve too slowly; 25% default since 2026-06-10.
unhackable_frac: float = 0.25
teacher_pool_dir: Path | None = None
mix_ratio: float = 0.125
teacher_off_step: int | None = 30
@@ -71,77 +61,39 @@ class Config:
eval_n_prompts: int = 32
eval_batch_size: int = 2
save_ckpt_every: int = 10
cos_pre_split_every: int = 1
half_a: str = ""
out_tag: str = ""
@property
def preset_name(self) -> str:
return type(self).__name__.removesuffix("Config").lower() or "base"
@property
def unhackable_frac(self) -> float:
return {"eval2": 0.0, "eval3": 0.25}[self.eval]
@property
def arm(self) -> str:
base = {"none": "vanilla", "erase": "projected",
"routeV": "routingV", "routeV_per_token": "routingV_per_token"}[self.intervention]
# lora2r changes the routing logic (hard 3-way masks, structural separation),
# so it gets its own arm id -- old/new runs must not be conflated.
return f"{base}_lora2r" if self.adapter == "lora2r" else base
# _lora2r suffix kept so these runs never conflate with the retired
# PiSSA-substrate runs of the same intervention (rename-on-logic-change).
return {"none": "vanilla_lora2r", "routeV": "routingV_lora2r",
"absorb": "absorb_lora2r"}[self.intervention]
@dataclass(kw_only=True)
class SmokeConfig(Config):
model: str = "llamafactory/tiny-random-qwen3"
lora_r: int = 4 # tiny model min Linear dim is 16; 2r=8 fits
steps: int = 30
group: int = 4
max_new: int = 32
n_problems: int = 100
beta: float = 0.0
prompts_per_step: int = 1
@dataclass(kw_only=True)
class FastConfig(Config):
model: str = "Qwen/Qwen3-4B"
steps: int = 60
steps: int = 100
teacher_pool_dir: Path | None = Path("out/pools/teacher_pool_runtests_dense")
grad_clip: float = 10.0
group: int = 8
max_new: int = 512
n_problems: int = 200
beta: float = 0.0
prompts_per_step: int = 4
lr: float = 3e-3
adam_beta1: float = 0.5
adam_beta2: float = 0.9
@dataclass(kw_only=True)
class FastLoraConfig(FastConfig):
# LoRA-frozen-B needs a lower learning rate because its gradient scale differs from delta_S.
adapter: Adapter = "lora_frozen_b"
lr: float = 1e-4
@dataclass(kw_only=True)
class FastLora2rConfig(FastConfig):
# Rank-2r PiSSA-init LoRA + SGTM three-way masking. weight_decay MUST be 0:
# AdamW decays the raw A/B toward 0, not toward the PiSSA init, which would
# drive the net delta to -B0@A0 (subtracting W's top-2r spectral part).
adapter: Adapter = "lora2r"
lr: float = 1e-4
weight_decay: float = 0.0
@dataclass(kw_only=True)
class FullConfig(Config):
model: str = "Qwen/Qwen3-4B"
steps: int = 200
group: int = 4
max_new: int = 1536
n_problems: int = 992
beta: float = 1e-3
prompts_per_step: int = 64