mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:30:41 +08:00
refactor: route2 quarantine = scale-matched delta_S_hack, rip out 33M LoRA
The distinct-basis A_q/B_q LoRA (~33M params at rank-16) gave the quarantine a
~100x capacity edge over delta_S, so routing-everything-there was the low-
resistance path: qE pinned ~0.97 (energy into the thrown-away knob) while the
deployed delta_S learned nothing (job 54). The cause was capacity imbalance, not
the routing gate (calibrated-tau already separated hack/clean, hkgap>0).
Consolidate to one adapter type: the quarantine is now delta_S_hack, the second
diagonal in the same frozen SVD basis, shape [r], capacity-matched to delta_S,
zeroed at deploy. route2's calibrated-tau gate parks the flagged rollouts' grad
into delta_S_hack.grad (like proj.py's route parks its subspace projection);
delta_S keeps the unflagged. Both diagonals train at one shared lr.
Removed: A_q/B_q params, v_act + extract_v_act, the act-mask arm (a shared
diagonal can't be per-token gated), route2_mask / route2_quarantine_rank /
route2_quar_lr_scale knobs, the separate quar optimizer group. Arm name
routing2_{act,grad} -> routing2. v_grad refresh extracts from delta_S (main)
with the quarantine ablated.
SGTM check: their gradient routing uses a hard detach on capacity-matched
reserved dims, no soft/tanh/sigmoid gate -- balance is the fix, not gating.
Smoked clean: tau/hkgap/qE render, ||delta_S_hack||>0 assert passes, exit 0.
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -84,9 +84,8 @@ over the first few steps. Sanity: by a few steps μ_clean<~0.1, μ_hack>~0.2-ish
|
||||
|
||||
## Verify / queue / follow
|
||||
|
||||
- `just smoke-route2 --route2-mask=grad` (or the smoke recipe that hits grad
|
||||
path): confirm tau/hkgap columns render, routing fires (flagged frac < ~0.5,
|
||||
not pinned at 0.5), exit 0.
|
||||
- `just smoke-route2`: confirm tau/hkgap/qE columns render, routing fires
|
||||
(flagged frac < ~0.5, not pinned at 0.5), exit 0.
|
||||
- Queue (substrate, 60 steps, seed 41), label:
|
||||
why: does per-step calibrated-τ vector routing (route2-grad) stop over-routing
|
||||
(flagged<<0.5) and suppress held-out deploy-hack vs vanilla at matched solve;
|
||||
@@ -98,3 +97,56 @@ over the first few steps. Sanity: by a few steps μ_clean<~0.1, μ_hack>~0.2-ish
|
||||
Detector flag (hack_E) only CALIBRATES μ_hack (known A) + anchors teacher hacks.
|
||||
The routing GATE is cos>τ applied to ALL rollouts incl. unknown B. gt_pass never
|
||||
touches routing. Compliant.
|
||||
|
||||
## Status: IMPLEMENTED + RUNNING (job 54)
|
||||
|
||||
Implemented in `6eb894f`, smoked (tau/hkgap render, exit 0). Running as job 54
|
||||
(calibrated-τ + grad_clip 0.5, route2-grad, 60 steps, seed 41). job 52
|
||||
(caltau-alone) killed and folded in; clip 0.5 is a no-op while gn<0.5 so 54 is
|
||||
the strict superset.
|
||||
|
||||
### Live observation (steps 0-2, 2026-06-01)
|
||||
|
||||
The calibration works as a discriminator but does NOT fix over-routing-by-energy:
|
||||
- `hkgap` positive and rising (0.00 -> 0.03 -> 0.08): the v_grad direction DOES
|
||||
separate the hack cloud from the clean cloud. The vector is alive.
|
||||
- `tau` tracking up with it (0.00 -> 0.02 -> 0.04): the threshold rides the drift
|
||||
as designed.
|
||||
- BUT `qE` 0.73 -> 0.97 -> 0.97: ~97% of gradient ENERGY lands in the deleted
|
||||
quarantine by step 1, and `gt_s` 3 -> 7 -> 0 (solving collapsed).
|
||||
|
||||
Diagnosis: the failure was never the routing FRACTION (which τ fixes); it is the
|
||||
always-summed 33M A_q/B_q quarantine MAGNITUDE. Even a correctly-thresholded route
|
||||
sends the routed gradient into a knob whose per-param grads dwarf delta_S's, so
|
||||
the energy ratio pins near 1 and the deployed adapter learns nothing. This is the
|
||||
SYNTHESIS "next lever" prediction: if qE stays high while hkgap>0, the culprit is
|
||||
quarantine magnitude, not the gate.
|
||||
|
||||
## DESIGN CHANGE (2026-06-01): one adapter, scale-matched quarantine
|
||||
|
||||
Acted on the magnitude diagnosis by removing the distinct-basis LoRA entirely.
|
||||
The quarantine is now delta_S_hack -- the SECOND diagonal in the same frozen SVD
|
||||
basis, shape [r] per module, identical capacity to delta_S. route2's calibrated-τ
|
||||
gate parks the flagged rollouts' delta_S-grad contribution into delta_S_hack.grad
|
||||
(via step_grad_hack in _route2_grad_filter), exactly as proj.py's `route` parks
|
||||
its subspace-projected component; delta_S keeps the unflagged. Both diagonals
|
||||
train at one shared lr; delta_S_hack is zeroed at deploy.
|
||||
|
||||
Rationale (user): a 33M LoRA vs a ~2k-param delta_S per module means "dump
|
||||
everything in the quarantine" is the low-resistance path -- a capacity edge, not
|
||||
honest absorption. Capacity-balanced diagonals remove that bias. SGTM's own
|
||||
quarantine is capacity-matched (a split of the same layer, equal dims), and uses
|
||||
a hard detach -- no soft/tanh/sigmoid gate -- confirming the fix is balance, not
|
||||
gating.
|
||||
|
||||
Removed: A_q/B_q params, v_act buffer + extract_v_act, the act-mask arm (a shared
|
||||
diagonal can't be per-token gated), route2_mask / route2_quarantine_rank /
|
||||
route2_quar_lr_scale knobs, the separate quar optimizer group. arm name
|
||||
"routing2_grad"/"routing2_act" -> "routing2".
|
||||
|
||||
v_grad refresh extracts from the MAIN knob (delta_S.grad) with the quarantine
|
||||
ablated -- the deployed-model gradient is what we route, and both diagonals share
|
||||
the basis so the direction is directly usable on delta_S's live gradient.
|
||||
|
||||
Smoked clean (tiny-random): tau/hkgap/qE render, ||delta_S_hack||=0.0074>0 assert
|
||||
passes, deploy-ablation fires, exit 0. Queued on the substrate (seed 41, 60 steps).
|
||||
|
||||
@@ -44,27 +44,15 @@ smoke-route *ARGS:
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
|
||||
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
|
||||
|
||||
# Routing-v2 path (route2): distinct-basis quarantine (A_q,B_q) + per-sample
|
||||
# act-mask detach-route in the FORWARD. Fires extract_v_act, the quarantine
|
||||
# optimizer params, the act-mask cosine (route cos>0), the B_q-moved assert, and
|
||||
# the deploy ablation (B_q zeroed). On the tiny-random model v_act is near-random
|
||||
# so ~half the samples flag -- both the detach and the through paths fire.
|
||||
# Routing-v2 path (route2): per-rollout calibrated-tau cosine routing into the
|
||||
# scale-matched delta_S_hack quarantine. Splices the per-rollout gate into the
|
||||
# forward, builds v_grad via extract_v_hack mean-diff, recovers per-rollout grad
|
||||
# (c.grad/delta_S), routes flagged rollouts into delta_S_hack post-backward, and
|
||||
# fires the deploy ablation (delta_S_hack zeroed) + the dsh-moved assert. Exercises
|
||||
# tau/hkgap/qE logging too.
|
||||
smoke-route2 *ARGS:
|
||||
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route2 \
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
|
||||
--route2-quarantine-rank=8 \
|
||||
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
|
||||
|
||||
# route2 GRAD-mask (Arm A): distinct code path from act -- a per-rollout gate is
|
||||
# spliced into the forward, then train.py recovers per-sample grad (c.grad/delta_S)
|
||||
# and routes by SUBTRACTING flagged rollouts from delta_S.grad post-backward.
|
||||
# Exercises the gate forward, extract_v_hack mean-diff -> v_grad, the subtraction,
|
||||
# and the B_q-moved assert. Run alongside smoke-route2 (act); they don't overlap.
|
||||
smoke-route2-grad *ARGS:
|
||||
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route2 \
|
||||
--route2-mask=grad \
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
|
||||
--route2-quarantine-rank=8 \
|
||||
--eval-ablate-every=10 --eval-n-prompts=2 {{ ARGS }}
|
||||
|
||||
# Run smoke twice: first warms the v_hack cache (cache-miss path), second hits
|
||||
@@ -166,16 +154,15 @@ build-substrate MODES="run_tests,exit_code,sentinel":
|
||||
# is the locked default (omit to inherit it). Vanilla needs no v_hack; for an
|
||||
# erase/route substrate run, add --v-hack-path explicitly.
|
||||
# Queue the full 5-arm substrate overlay sweep (the all-arms per-mode deploy plot,
|
||||
# #162). The arm set is FIXED -- no params, no defaults repeated. seed/steps/refresh/
|
||||
# mask all inherit FastConfig defaults (seed41 steps60 rf5 mask=act); each arm passes
|
||||
# ONLY what differs from default (route2-grad: --route2-mask=grad). out-tag distinguishes
|
||||
# the runs for the plot glob. Every arm emits out/runs/<ts>_<tag>/per_mode_deploy.json.
|
||||
# #162). The arm set is FIXED -- no params, no defaults repeated. seed/steps/refresh
|
||||
# inherit FastConfig defaults (seed41 steps60 rf5); each arm passes ONLY its
|
||||
# intervention + out-tag. out-tag distinguishes the runs for the plot glob. Every
|
||||
# arm emits out/runs/<ts>_<tag>/per_mode_deploy.json.
|
||||
queue-substrate:
|
||||
pueue add -w "$PWD" -o 5 -l "why: vanilla emergence reference (4-mode substrate); resolve: per-mode deploy-hack baseline for the overlay" -- {{ TRAIN }} fast --intervention=none --out-tag=_sub4_vanilla
|
||||
pueue add -w "$PWD" -o 5 -l "why: erase arm (one-sided projection); resolve: per-mode deploy hack vs vanilla at matched solve" -- {{ TRAIN }} fast --intervention=erase --out-tag=_sub4_erase
|
||||
pueue add -w "$PWD" -o 5 -l "why: route arm (shared-basis quarantine, rf5); resolve: deploy hack on held-out modes vs vanilla at matched solve" -- {{ TRAIN }} fast --intervention=route --out-tag=_sub4_route
|
||||
pueue add -w "$PWD" -o 5 -l "why: route2 act-mask (distinct-basis quarantine); resolve: held-out deploy hack suppressed vs vanilla at matched solve" -- {{ TRAIN }} fast --intervention=route2 --out-tag=_sub4_route2_act
|
||||
pueue add -w "$PWD" -o 5 -l "why: route2 grad-mask (distinct-basis quarantine); resolve: held-out deploy hack suppressed vs vanilla at matched solve" -- {{ TRAIN }} fast --intervention=route2 --route2-mask=grad --out-tag=_sub4_route2_grad
|
||||
pueue add -w "$PWD" -o 5 -l "why: route arm (subspace-projection quarantine, rf5); resolve: deploy hack on held-out modes vs vanilla at matched solve" -- {{ TRAIN }} fast --intervention=route --out-tag=_sub4_route
|
||||
pueue add -w "$PWD" -o 5 -l "why: route2 calibrated-tau routing into scale-matched delta_S_hack; resolve: held-out deploy hack suppressed vs vanilla at matched solve" -- {{ TRAIN }} fast --intervention=route2 --out-tag=_sub4_route2
|
||||
|
||||
# CANONICAL plotting entrypoint for the substrate sweep. One command, four figures
|
||||
# (per-mode by-method + by-hack, and the aggregate "total hacks per arm" + overlay,
|
||||
|
||||
@@ -65,81 +65,45 @@ def is_target(name: str) -> bool:
|
||||
|
||||
|
||||
def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
|
||||
"""Add the AntiPaSTO delta to y. Two regimes, branched on whether a route2
|
||||
quarantine (`_antipasto_A_q`) was attached:
|
||||
"""Add the AntiPaSTO delta to y, in the frozen SVD basis:
|
||||
|
||||
SHARED-BASIS (none/erase/route -- A_q is None):
|
||||
y += U @ ((delta_S + delta_S_hack) * (Vh @ x))
|
||||
delta_S = main knob; delta_S_hack = shared-basis routing quarantine (proj.py
|
||||
parks the removed hack-ward grad here). Both 0 at init -> identity.
|
||||
|
||||
DISTINCT-BASIS route2 (A_q set). The quarantine `quar = B_q @ (A_q @ x)` is
|
||||
ALWAYS summed (item 3: in-path for every sample so unflagged hacks can
|
||||
concentrate there by gradient magnitude -- absorption, Cloud 2410.04332 /
|
||||
SGTM 2512.05648). Two mask sources select WHICH samples route to it:
|
||||
delta_S = the KEPT/deployed knob; delta_S_hack = the QUARANTINE, parked with
|
||||
the routed (hack-ward) gradient and zeroed at deploy. Both diagonals are
|
||||
shape [r] in the same basis (capacity-balanced, no sink-bias) and both 0 at
|
||||
init -> identity. Routing decides per update which gradient lands in which:
|
||||
erase strips the hack-ward part (proj.py); route parks it in delta_S_hack
|
||||
by subspace projection (proj.py); route2 parks it by a per-rollout
|
||||
calibrated-tau cosine gate (train.py, post-backward).
|
||||
|
||||
mask_mode="act" (Arm B, single pass): flag in activation space,
|
||||
m = cos(a, v_act) > 0 # [...] per-token, free at forward time
|
||||
kept= where(m, kept.detach(), kept)# flagged -> delta_S gets no grad
|
||||
The detach leaves the forward value unchanged; only delta_S's backward is cut.
|
||||
|
||||
mask_mode="grad" (Arm A): flag in gradient space. The per-sample delta_S
|
||||
gradient is not free (delta_S.grad is summed over the batch), so splice a
|
||||
per-rollout gate c (init 1, identity in the forward): after backward
|
||||
c.grad = delta_S * g_b (the per-rollout delta_S grad). train.py divides out
|
||||
delta_S, flags rollouts by cos(g_b, v_grad)>0, and subtracts the flagged
|
||||
rollouts' contribution from delta_S.grad -- routing happens POST-backward, so
|
||||
no forward detach here. quar still gets every rollout (always summed).
|
||||
|
||||
Deploy zeroes A_q,B_q (kept-only).
|
||||
For route2's per-rollout routing (layer._antipasto_grad_probe) we splice a
|
||||
per-token gate c (init 1, forward-identity) onto the delta_S path: after
|
||||
backward c.grad = delta_S * g_b, so train.py recovers the per-rollout delta_S
|
||||
gradient, flags rollouts by cos(g_b, v_grad) vs tau, and routes the flagged
|
||||
contribution into delta_S_hack.grad. No quarantine LoRA, no forward detach.
|
||||
"""
|
||||
(x,) = args
|
||||
Vh = layer._antipasto_Vh # [r, d_in]
|
||||
U = layer._antipasto_U # [d_out, r]
|
||||
delta_S = layer._antipasto_delta_S # [r]
|
||||
Vh = layer._antipasto_Vh # [r, d_in]
|
||||
U = layer._antipasto_U # [d_out, r]
|
||||
delta_S = layer._antipasto_delta_S # [r]
|
||||
delta_S_hack = layer._antipasto_delta_S_hack # [r]
|
||||
A_q = layer._antipasto_A_q # None or [k_q, d_in]
|
||||
|
||||
a = torch.nn.functional.linear(x, Vh) # [..., r]
|
||||
if A_q is None:
|
||||
h = a * (delta_S + delta_S_hack).to(a.dtype)
|
||||
return y + torch.nn.functional.linear(h, U).to(y.dtype)
|
||||
|
||||
# --- route2: distinct-basis quarantine, always summed ---
|
||||
B_q = layer._antipasto_B_q # [d_out, k_q]
|
||||
# A_q/B_q kept fp32 (master, like delta_S); cast down to x.dtype for the matmul
|
||||
# (bf16 on the real model). autograd casts grads back to the fp32 params.
|
||||
quar = torch.nn.functional.linear(
|
||||
torch.nn.functional.linear(x, A_q.to(x.dtype)), B_q.to(x.dtype)) # [..., d_out]
|
||||
|
||||
if layer._antipasto_mask_mode == "grad":
|
||||
if torch.is_grad_enabled():
|
||||
# gate c, one entry per (token, axis) since nn.Linear flattens the batch
|
||||
# ([G*s, r]); identity at c=1 so the forward value is unchanged. After
|
||||
# backward c.grad = delta_S * g_b (per-token). train.py reshapes to
|
||||
# [G, s, r], sums each rollout's tokens, divides out delta_S to recover
|
||||
# the per-rollout g_b, and routes post-backward.
|
||||
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1],
|
||||
device=a.device, dtype=a.dtype, requires_grad=True)
|
||||
layer._antipasto_gate = c
|
||||
kept = torch.nn.functional.linear((a * c) * delta_S.to(a.dtype), U)
|
||||
else: # generate(): no grad -> no gate needed
|
||||
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U)
|
||||
return y + (kept + quar).to(y.dtype)
|
||||
|
||||
# mask_mode="act": forward detach-route by activation cosine. cos>0 == points
|
||||
# hack-ward == route to quarantine (the natural boundary, no threshold knob).
|
||||
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U) # [..., d_out]
|
||||
v_act = layer._antipasto_v_act.to(a.dtype) # [r] unit, hack-ward, in Vh coords (fp32 buffer -> a.dtype)
|
||||
cos = (a @ v_act) / (a.norm(dim=-1).clamp_min(1e-6) * v_act.norm().clamp_min(1e-6))
|
||||
m = cos > 0 # [...] bool
|
||||
# Stash routing intensity so train.py can log it (else the act path is silent
|
||||
# and over-routing -- m firing on ~half of all tokens, not just hack tokens --
|
||||
# is invisible). fired = fraction of token positions routed to the quarantine.
|
||||
layer._antipasto_act_fired = m.float().mean().detach()
|
||||
layer._antipasto_act_cos = cos.mean().detach()
|
||||
kept = torch.where(m.unsqueeze(-1), kept.detach(), kept)
|
||||
return y + (kept + quar).to(y.dtype)
|
||||
a = torch.nn.functional.linear(x, Vh) # [..., r]
|
||||
hack = torch.nn.functional.linear(a * delta_S_hack.to(a.dtype), U) # quarantine path
|
||||
if layer._antipasto_grad_probe and torch.is_grad_enabled():
|
||||
# gate c, one entry per (token, axis) since nn.Linear flattens the batch
|
||||
# ([G*s, r]); identity at c=1 so the forward value is unchanged. After
|
||||
# backward c.grad = delta_S * g_b (per-token); train.py reshapes to
|
||||
# [G, s, r], sums each rollout's tokens, divides out delta_S to recover
|
||||
# the per-rollout g_b, and routes post-backward.
|
||||
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1],
|
||||
device=a.device, dtype=a.dtype, requires_grad=True)
|
||||
layer._antipasto_gate = c
|
||||
kept = torch.nn.functional.linear((a * c) * delta_S.to(a.dtype), U)
|
||||
else:
|
||||
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U)
|
||||
return y + (kept + hack).to(y.dtype)
|
||||
|
||||
|
||||
def wrap_model_with_antipasto(
|
||||
@@ -147,21 +111,17 @@ def wrap_model_with_antipasto(
|
||||
model_name: str,
|
||||
cache_root: Path = Path("svd_cache"),
|
||||
svd_device: torch.device | str = "cuda",
|
||||
quarantine_rank: int | None = None,
|
||||
route2_mask: str = "act",
|
||||
grad_probe: bool = False,
|
||||
) -> dict[str, dict]:
|
||||
"""Attach AntiPaSTO hooks to every target nn.Linear in `model` (in place).
|
||||
|
||||
Returns dict[qualified_name -> dict(layer, delta_S, handle, r)].
|
||||
Returns dict[qualified_name -> dict(layer, delta_S, delta_S_hack, handle, r)].
|
||||
Frozen U/Vh stored on the layer as buffers `_antipasto_{U,Vh}` in the
|
||||
layer's native dtype. delta_S kept in fp32 (tiny, ~r per module).
|
||||
layer's native dtype. delta_S/delta_S_hack kept in fp32 (tiny, ~r per module).
|
||||
|
||||
`quarantine_rank` (route2 only): if set, also attach a DISTINCT-basis LoRA
|
||||
quarantine per module -- A_q [k_q, d_in] (kaiming), B_q [d_out, k_q] (zeros so
|
||||
quar=0 at init), a unit `v_act` buffer [r] (filled by the act-mask extraction
|
||||
in train.py), and a `_antipasto_gate` slot (grad-mask probe). None -> shared-
|
||||
basis path only. `route2_mask` selects the mask source: "act" (forward cosine,
|
||||
routes in-forward) or "grad" (gate probe, routes post-backward in train.py).
|
||||
`grad_probe` (route2 only): splice a per-token gate c into the delta_S path so
|
||||
train.py can recover the per-rollout delta_S gradient and route flagged
|
||||
rollouts into delta_S_hack post-backward. Off -> plain forward (none/erase/route).
|
||||
"""
|
||||
svd_device_t = torch.device(svd_device) if isinstance(svd_device, str) else svd_device
|
||||
safe = model_name.replace("/", "__")
|
||||
@@ -191,27 +151,13 @@ def wrap_model_with_antipasto(
|
||||
linear.register_parameter("_antipasto_delta_S_hack", delta_S_hack)
|
||||
info = {"layer": linear, "delta_S": delta_S,
|
||||
"delta_S_hack": delta_S_hack, "handle": None, "r": r}
|
||||
if quarantine_rank is None:
|
||||
linear._antipasto_A_q = None # plain attr -> shared-basis hook branch
|
||||
else:
|
||||
k_q = min(quarantine_rank, r)
|
||||
A_q = nn.Parameter(torch.empty(k_q, d_in, device=dev, dtype=torch.float32))
|
||||
nn.init.kaiming_uniform_(A_q, a=5 ** 0.5) # LoRA-A init
|
||||
B_q = nn.Parameter(torch.zeros(d_out, k_q, device=dev, dtype=torch.float32)) # quar=0 at init
|
||||
linear.register_parameter("_antipasto_A_q", A_q)
|
||||
linear.register_parameter("_antipasto_B_q", B_q)
|
||||
linear.register_buffer("_antipasto_v_act",
|
||||
torch.zeros(r, device=dev, dtype=torch.float32), persistent=True)
|
||||
linear._antipasto_mask_mode = route2_mask # "act" | "grad"
|
||||
linear._antipasto_gate = None # grad-mask probe leaf, set per forward
|
||||
info["A_q"], info["B_q"], info["k_q"] = A_q, B_q, k_q
|
||||
linear._antipasto_grad_probe = grad_probe # route2: gate the delta_S path
|
||||
linear._antipasto_gate = None # grad-probe leaf, set per forward
|
||||
info["handle"] = linear.register_forward_hook(_delta_hook)
|
||||
out[name] = info
|
||||
|
||||
# freeze everything except the AntiPaSTO knobs. A_q/B_q (route2) are trainable
|
||||
# too; v_act is a buffer (not a param) so it stays frozen by construction.
|
||||
trainable = ("_antipasto_delta_S", "_antipasto_delta_S_hack",
|
||||
"_antipasto_A_q", "_antipasto_B_q")
|
||||
# freeze everything except the two AntiPaSTO diagonals.
|
||||
trainable = ("_antipasto_delta_S", "_antipasto_delta_S_hack")
|
||||
for n, p in model.named_parameters():
|
||||
if not n.endswith(trainable):
|
||||
p.requires_grad_(False)
|
||||
@@ -225,9 +171,6 @@ def detach_antipasto(model: nn.Module, attached: dict) -> None:
|
||||
for attr in ("_antipasto_U", "_antipasto_Vh"):
|
||||
if attr in layer._buffers:
|
||||
del layer._buffers[attr]
|
||||
for attr in ("_antipasto_delta_S", "_antipasto_delta_S_hack",
|
||||
"_antipasto_A_q", "_antipasto_B_q"):
|
||||
for attr in ("_antipasto_delta_S", "_antipasto_delta_S_hack"):
|
||||
if attr in layer._parameters:
|
||||
del layer._parameters[attr]
|
||||
if "_antipasto_v_act" in layer._buffers:
|
||||
del layer._buffers["_antipasto_v_act"]
|
||||
|
||||
@@ -218,76 +218,6 @@ def extract_v_hack(
|
||||
return v_hack, v_sv, raw_grads, rows
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_v_act(
|
||||
model,
|
||||
tokenizer,
|
||||
wrappers: dict,
|
||||
pairs: list,
|
||||
n_heldout: int,
|
||||
device,
|
||||
) -> dict[str, Float[torch.Tensor, "r"]]:
|
||||
"""Activation-space persona direction per module for route2 Arm B.
|
||||
|
||||
Forward-only analogue of extract_v_hack: capture the SVD-coord activation
|
||||
a = Vh @ x at each wrapped Linear, averaged over completion tokens, and return
|
||||
the unit hack-minus-clean direction per module (in R^r, the same space the
|
||||
route2 forward computes cos(a, v_act) in). No backward -> cheaper than v_hack.
|
||||
|
||||
Oriented hack-ward by construction (mean_hack - mean_clean). The k=1 mean-diff;
|
||||
we deliberately do NOT SVD here -- the mask only needs one direction, and the
|
||||
sign/orientation is unambiguous from the paired mean difference.
|
||||
"""
|
||||
train_pairs = pairs[:-n_heldout] if n_heldout > 0 else pairs
|
||||
sums: dict[str, dict] = {
|
||||
name: {"hack": torch.zeros(info["r"]), "clean": torch.zeros(info["r"]),
|
||||
"n_hack": 0, "n_clean": 0}
|
||||
for name, info in wrappers.items()
|
||||
}
|
||||
captured: dict[str, torch.Tensor] = {}
|
||||
handles = []
|
||||
for name, info in wrappers.items():
|
||||
def mk(nm):
|
||||
def pre(_mod, args):
|
||||
captured[nm] = args[0].detach() # [1, L, d_in] input to the Linear
|
||||
return pre
|
||||
handles.append(info["layer"].register_forward_pre_hook(mk(name)))
|
||||
|
||||
try:
|
||||
for pi, pair in enumerate(train_pairs):
|
||||
for label, completion in (("hack", pair.hack), ("clean", pair.clean)):
|
||||
prompt_ids = tokenizer(pair.prompt, return_tensors="pt").input_ids.to(device)
|
||||
full_ids = tokenizer(pair.prompt + completion, return_tensors="pt").input_ids.to(device)
|
||||
n_prompt = prompt_ids.shape[1]
|
||||
captured.clear()
|
||||
model(full_ids)
|
||||
for name, info in wrappers.items():
|
||||
x = captured[name].float() # [1, L, d_in]
|
||||
Vh = info["layer"]._antipasto_Vh.float() # [r, d_in]
|
||||
a = torch.nn.functional.linear(x, Vh) # [1, L, r]
|
||||
a_comp = a[:, n_prompt:, :].mean(dim=(0, 1)).cpu() # [r] mean over completion toks
|
||||
sums[name][label] += a_comp
|
||||
sums[name]["n_" + label] += 1
|
||||
if (pi + 1) % 5 == 0:
|
||||
logger.info(f" v_act pair {pi+1}/{len(train_pairs)}")
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
v_act: dict[str, torch.Tensor] = {}
|
||||
n_zero = 0
|
||||
for name, s in sums.items():
|
||||
diff = s["hack"] / max(1, s["n_hack"]) - s["clean"] / max(1, s["n_clean"])
|
||||
nrm = diff.norm()
|
||||
if nrm < 1e-12:
|
||||
n_zero += 1
|
||||
v_act[name] = torch.zeros_like(diff)
|
||||
else:
|
||||
v_act[name] = (diff / nrm).contiguous()
|
||||
logger.info(f"v_act: modules={len(v_act)} zero-||diff||={n_zero} (activation mean-diff, unit, hack-ward)")
|
||||
return v_act
|
||||
|
||||
|
||||
def main(cfg: Config) -> int:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = resolve_dtype(cfg.dtype)
|
||||
|
||||
+89
-188
@@ -140,26 +140,15 @@ class Config:
|
||||
# none = vanilla GRPO (project_delta_S_grad runs measure_only; grad untouched)
|
||||
# erase = today's projection: subtract the hack-ward component from delta_S
|
||||
# route = park the hack-ward component in the delta_S_hack quarantine knob
|
||||
# (Gradient Routing, Cloud 2410.04332); ablate it at eval.
|
||||
# route2 = DISTINCT-basis quarantine (A_q,B_q LoRA) + per-sample act-mask
|
||||
# detach-route in the FORWARD (no proj.py grad surgery). Absorption
|
||||
# design (SGTM 2512.05648); see docs/spec/20260531_routing_v2_*.md.
|
||||
# by SUBSPACE PROJECTION (Gradient Routing, Cloud 2410.04332); ablate
|
||||
# it at eval.
|
||||
# route2 = park the hack-ward component in the SAME scale-matched delta_S_hack
|
||||
# quarantine, but selected by a PER-ROLLOUT calibrated-tau cosine gate
|
||||
# (cos(g_b,v_grad) > tau) instead of subspace projection. See
|
||||
# docs/spec/20260601_calibrated_tau_route2grad.md.
|
||||
# Replaces the old `arm` flag (vanilla/projected); `arm` survives as a derived
|
||||
# display name (see property below) so log/run-id formatting is unchanged.
|
||||
intervention: Literal["none", "erase", "route", "route2"] = "erase"
|
||||
# route2-only: quarantine LoRA rank (per module, capped at r). Any sample whose
|
||||
# mask cosine > 0 (points hack-ward) routes into the quarantine; no threshold
|
||||
# knob -- the load-time noise floor already filters.
|
||||
route2_quarantine_rank: int = 16
|
||||
# route2 mask source. "act" (Arm B): forward-time cos(a, v_act), routes via
|
||||
# detach, single pass. "grad" (Arm A): per-rollout cos(g_b, v_grad) from a gate
|
||||
# probe, routes by subtracting flagged rollouts from delta_S.grad post-backward.
|
||||
route2_mask: Literal["act", "grad"] = "act"
|
||||
# route2-only: the quarantine A_q/B_q (33M fresh kaiming params) is ~60x larger
|
||||
# than delta_S (0.5M) and at the shared delta_S lr it diverged -- gn 0.3->7.5 at
|
||||
# step 8, generations -> token salad, lp_t -11 (run 43). Give it its own lower lr.
|
||||
# Scale of main lr; 1.0 = old (diverging) behaviour, 0.1 = the fix.
|
||||
route2_quar_lr_scale: float = 0.1
|
||||
# Scale-dependent knobs — every preset must set these to a real value;
|
||||
# subclasses below override the defaults.
|
||||
model: str = "Qwen/Qwen3-4B"
|
||||
@@ -260,11 +249,9 @@ class Config:
|
||||
@property
|
||||
def arm(self) -> str:
|
||||
"""Display name for run-id / BLUF / logs (results.py + plot_dynamics
|
||||
classify off this). One-to-one with intervention; not a CLI flag.
|
||||
route2 splits by mask source so the 5-arm plot can tell act from grad."""
|
||||
if self.intervention == "route2":
|
||||
return f"routing2_{self.route2_mask}"
|
||||
return {"none": "vanilla", "erase": "projected", "route": "routing"}[self.intervention]
|
||||
classify off this). One-to-one with intervention; not a CLI flag."""
|
||||
return {"none": "vanilla", "erase": "projected",
|
||||
"route": "routing", "route2": "routing2"}[self.intervention]
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@@ -544,10 +531,9 @@ def ref_logprobs_via_zero_delta(
|
||||
|
||||
@contextmanager
|
||||
def ablate_quarantine(wrappers: dict):
|
||||
"""Zero the routing quarantine for the duration -- the eval-time ablation of
|
||||
the routed hack capability. Save -> zero -> (eval) -> restore. The route arm's
|
||||
deployment model IS this ablated state. Covers BOTH quarantines: delta_S_hack
|
||||
(shared-basis route) and B_q (distinct-basis route2; zeroing B_q -> quar=0).
|
||||
"""Zero the routing quarantine (delta_S_hack) for the duration -- the
|
||||
eval-time ablation of the routed hack capability. Save -> zero -> (eval) ->
|
||||
restore. The route/route2 arms' deployment model IS this ablated state.
|
||||
|
||||
TODO(post-deploy-finetune): SGTM's ablate(trainable=True) reinits the forget
|
||||
weights to the retain-dims' std instead of zeroing, so the model stays
|
||||
@@ -555,18 +541,13 @@ def ablate_quarantine(wrappers: dict):
|
||||
we only eval after deploy; add the reinit path if we ever retrain post-ablate.
|
||||
See docs/grad_routing/sgtm_vs_ours.md."""
|
||||
saved = {n: info["delta_S_hack"].data.clone() for n, info in wrappers.items()}
|
||||
saved_Bq = {n: info["B_q"].data.clone() for n, info in wrappers.items() if "B_q" in info}
|
||||
for info in wrappers.values():
|
||||
info["delta_S_hack"].data.zero_()
|
||||
if "B_q" in info:
|
||||
info["B_q"].data.zero_()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for n, info in wrappers.items():
|
||||
info["delta_S_hack"].data.copy_(saved[n])
|
||||
if "B_q" in info:
|
||||
info["B_q"].data.copy_(saved_Bq[n])
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -699,26 +680,16 @@ class StepLogger:
|
||||
_Col("cos_post", 6, "cout", ".2f", "hack-ward fraction AFTER projection (want ~0: all removed)"),
|
||||
_Col("fired", 5, "fired", ".2f", "fraction of modules where projection fired"),
|
||||
]
|
||||
# route2 act-mask: no v_hack grad projection, but the forward routes by
|
||||
# cos(activation, v_act)>0. Surface that routing intensity (reuses the row's
|
||||
# cos_pre/fired keys, populated from the stashed act stats in train.py) so the
|
||||
# act path is no longer silent -- watch `fired` for over-routing (>>0.5 means
|
||||
# the sign test fires on generic tokens, starving delta_S onto the quarantine).
|
||||
if arm == "routing2_act":
|
||||
cols += [
|
||||
_Col("cos_pre", 7, "act_cos", "+.2f", "mean cos(activation, v_act): forward routing alignment"),
|
||||
_Col("fired", 6, "act_fire", ".2f", "fraction of token positions routed to quarantine (cos>0)"),
|
||||
]
|
||||
# route2 grad-mask: the routing gate is cos(g_b,v_grad) > tau, where tau is
|
||||
# the per-step EMA midpoint of the hack vs clean cos clouds. Surface tau and
|
||||
# route2: the routing gate is cos(g_b,v_grad) > tau, where tau is the
|
||||
# per-step EMA midpoint of the hack vs clean cos clouds. Surface tau and
|
||||
# the hack-clean gap so we can see the threshold ride the drift and whether
|
||||
# the direction still separates (hkgap>0) -- replaces the silent cos>0 gate.
|
||||
if arm == "routing2_grad":
|
||||
if arm == "routing2":
|
||||
cols += [
|
||||
_Col("tau", 6, "tau", "+.2f", "per-step calibrated route threshold (midpoint of hack vs clean cos clouds)"),
|
||||
_Col("hkgap", 6, "hkgap", "+.2f", "ema_hack_cos - ema_clean_cos; >0 = v_grad still separates hack from clean (else direction dead)"),
|
||||
]
|
||||
if arm in ("routing", "routing2_act", "routing2_grad"):
|
||||
if arm in ("routing", "routing2"):
|
||||
cols += [
|
||||
_Col("q_egy", 6, "qE", ".2f", "grad energy into quarantine ||g_quar||/(||g_keep||+||g_quar||); ~0.5+ rising = learning dumped into the thrown-away knob"),
|
||||
_Col("hack_deploy", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (quarantine deleted = deployed model)"),
|
||||
@@ -785,26 +756,18 @@ def main(cfg: Config) -> int:
|
||||
model.config.use_cache = False
|
||||
|
||||
is_route2 = cfg.intervention == "route2"
|
||||
is_route2_grad = is_route2 and cfg.route2_mask == "grad"
|
||||
is_route2_act = is_route2 and cfg.route2_mask == "act"
|
||||
wrappers = wrap_model_with_antipasto(
|
||||
model, model_name, CACHE_ROOT, device,
|
||||
quarantine_rank=cfg.route2_quarantine_rank if is_route2 else None,
|
||||
route2_mask=cfg.route2_mask,
|
||||
grad_probe=is_route2, # route2 needs the per-rollout delta_S gate probe
|
||||
)
|
||||
# Both knobs are trainable params. delta_S_hack only ever gets a grad under
|
||||
# intervention=route (the routing split in proj.py); under none/erase its
|
||||
# grad stays None so AdamW skips it and it stays exactly 0 (forward adds 0,
|
||||
# so none/erase reproduce the pre-quarantine behaviour bit-for-bit).
|
||||
# Both diagonals are trainable params, same shape r (capacity-balanced).
|
||||
# delta_S_hack only ever gets a grad under route (proj.py subspace split) or
|
||||
# route2 (per-rollout tau routing); under none/erase its grad stays None so
|
||||
# AdamW skips it and it stays exactly 0 (forward adds 0 -> identity).
|
||||
delta_params = [info["delta_S"] for info in wrappers.values()]
|
||||
delta_hack_params = [info["delta_S_hack"] for info in wrappers.values()]
|
||||
# route2: the distinct-basis quarantine LoRA (A_q,B_q), trainable, deleted at
|
||||
# deploy. act-mask routes in the forward (detach); grad-mask routes post-backward.
|
||||
quar_params = ([info["A_q"] for info in wrappers.values()]
|
||||
+ [info["B_q"] for info in wrappers.values()]) if is_route2 else []
|
||||
logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,} "
|
||||
f"(+{sum(p.numel() for p in delta_hack_params):,} delta_S_hack, route-only"
|
||||
f"{f'; +{sum(p.numel() for p in quar_params):,} A_q/B_q route2' if is_route2 else ''})")
|
||||
f"(+{sum(p.numel() for p in delta_hack_params):,} delta_S_hack quarantine)")
|
||||
|
||||
# v_hack: the hack-direction subspace the erase/route arms project against.
|
||||
# VANILLA (intervention=none) is a pure GRPO baseline and ignores v_hack
|
||||
@@ -819,39 +782,29 @@ def main(cfg: Config) -> int:
|
||||
v_hack = None # route2 routes via the mask, not erase/route grad surgery
|
||||
if is_route2:
|
||||
# The persona pairs are the only "detector" (weak, self-supervised). They
|
||||
# produce the mask direction; no oracle, no gt_pass. Same pairs for both
|
||||
# masks so act vs grad differ only in the space the direction lives in.
|
||||
# produce the routing direction; no oracle, no gt_pass.
|
||||
if cfg.vhack_pairs_path is not None:
|
||||
from .pairs_from_pool import load_pairs_json
|
||||
MASK_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
|
||||
logger.info(f"route2 mask pairs: pool-derived ({cfg.vhack_pairs_path}) -> {len(MASK_PAIRS)} pairs")
|
||||
logger.info(f"route2 pairs: pool-derived ({cfg.vhack_pairs_path}) -> {len(MASK_PAIRS)} pairs")
|
||||
else:
|
||||
from .pairs import PAIRS as MASK_PAIRS
|
||||
logger.info(f"route2 mask pairs: hand-crafted PAIRS -> {len(MASK_PAIRS)} pairs")
|
||||
logger.info(f"route2 pairs: hand-crafted PAIRS -> {len(MASK_PAIRS)} pairs")
|
||||
model.eval()
|
||||
if cfg.route2_mask == "act":
|
||||
# Arm B: activation-space mean-diff -> _antipasto_v_act buffer (the
|
||||
# forward cos(a,v_act) reads it). Forward-only, cheap.
|
||||
from .extract_vhack_grad import extract_v_act
|
||||
v_act = extract_v_act(model, tok, wrappers, MASK_PAIRS, n_heldout=2, device=device)
|
||||
for name, info in wrappers.items():
|
||||
info["layer"]._antipasto_v_act.data.copy_(v_act[name].to(device))
|
||||
logger.info(f"route2 act: loaded v_act into {len(v_act)} modules")
|
||||
else:
|
||||
# Arm A: gradient-space mean-diff. extract_v_hack gives per-pair GRPO
|
||||
# gradients on delta_S; v_grad = unit(mean(g_hack - g_clean)) per
|
||||
# module, oriented hack-ward (training reinforces hacks with the
|
||||
# same sign, so cos(g_b, v_grad)>0 flags a reinforced-hack rollout).
|
||||
from .extract_vhack_grad import extract_v_hack
|
||||
_, _, raw_grads, _ = extract_v_hack(
|
||||
model, tok, wrappers, MASK_PAIRS,
|
||||
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
|
||||
)
|
||||
v_grad = {}
|
||||
for name in wrappers:
|
||||
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
|
||||
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
|
||||
logger.info(f"route2 grad: built v_grad (gradient mean-diff) for {len(v_grad)} modules")
|
||||
# gradient-space mean-diff. extract_v_hack gives per-pair GRPO gradients
|
||||
# on delta_S; v_grad = unit(mean(g_hack - g_clean)) per module, oriented
|
||||
# hack-ward (training reinforces hacks with the same sign, so a rollout
|
||||
# with cos(g_b, v_grad) above the calibrated tau is a reinforced hack).
|
||||
from .extract_vhack_grad import extract_v_hack
|
||||
_, _, raw_grads, _ = extract_v_hack(
|
||||
model, tok, wrappers, MASK_PAIRS,
|
||||
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
|
||||
)
|
||||
v_grad = {}
|
||||
for name in wrappers:
|
||||
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
|
||||
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
|
||||
logger.info(f"route2 grad: built v_grad (gradient mean-diff) for {len(v_grad)} modules")
|
||||
model.train()
|
||||
else:
|
||||
# v_hack path resolution, most-specific first. The pairset (personas) is
|
||||
@@ -957,18 +910,13 @@ def main(cfg: Config) -> int:
|
||||
f"G_s={G_s} student + G_t={G_t} teacher per prompt (mix_ratio={cfg.mix_ratio})."
|
||||
)
|
||||
|
||||
# Quarantine (A_q/B_q) gets its own lower lr: it is ~60x bigger than delta_S and
|
||||
# freshly kaiming-init, so the shared lr diverged it (run 43). Separate param group
|
||||
# so the scheduler scales both proportionally (the group's lr rides on `lr` via the
|
||||
# ratio captured here -- LinearLR/CosineAnnealingLR multiply each group's base lr).
|
||||
quar_lr = lr * cfg.route2_quar_lr_scale
|
||||
# One group: delta_S (kept) + delta_S_hack (quarantine) share the lr -- same
|
||||
# shape, same basis, so no per-group lr juggling (the old A_q/B_q LoRA needed
|
||||
# its own lower lr because it was ~60x bigger; gone now).
|
||||
opt = torch.optim.AdamW(
|
||||
[{"params": delta_params + delta_hack_params, "lr": lr},
|
||||
{"params": quar_params, "lr": quar_lr}],
|
||||
delta_params + delta_hack_params,
|
||||
lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2),
|
||||
)
|
||||
if quar_params:
|
||||
logger.info(f"route2 quarantine lr = {quar_lr:.1e} ({cfg.route2_quar_lr_scale}x main lr {lr:.1e})")
|
||||
# Linear warmup over `warmup_frac * steps`, then cosine decay to 0 over the rest.
|
||||
# Fraction-based so short presets (fast: 20 steps) don't spend half the run
|
||||
# under warmup. Canonical full-preset: 0.1 * 100 = 10 (matches ariahw config.py:141).
|
||||
@@ -1183,27 +1131,15 @@ def main(cfg: Config) -> int:
|
||||
# what the projection + optimizer step ultimately sees.
|
||||
step_grad_s: dict[str, torch.Tensor] = {}
|
||||
step_grad_t: dict[str, torch.Tensor] = {}
|
||||
# route2 quarantine grads must survive the per-pass model.zero_grad (which
|
||||
# exists to isolate delta_S's per-source grad). A_q/B_q need neither source
|
||||
# split nor projection, just plain accumulation, so we stash and re-inject
|
||||
# them exactly as delta_S is. Keyed "<module>.<A_q|B_q>".
|
||||
step_grad_quar: dict[str, torch.Tensor] = {}
|
||||
# route2: the flagged rollouts' delta_S-grad contribution, accumulated per
|
||||
# module across prompts, parked into delta_S_hack.grad at injection (the
|
||||
# quarantine, deleted at deploy). Keyed by module name. Mirrors how proj.py
|
||||
# parks route's removed component into delta_S_hack.
|
||||
step_grad_hack: dict[str, torch.Tensor] = {}
|
||||
|
||||
def _stash_quar_grads():
|
||||
if not is_route2:
|
||||
return
|
||||
for name, info in wrappers.items():
|
||||
for sub in ("A_q", "B_q"):
|
||||
p = info[sub]
|
||||
if p.grad is None:
|
||||
continue
|
||||
key = f"{name}.{sub}"
|
||||
step_grad_quar[key] = (step_grad_quar[key] + p.grad.detach().clone()
|
||||
if key in step_grad_quar else p.grad.detach().clone())
|
||||
|
||||
# route2 grad-mask: recover the per-rollout delta_S grad from the gate
|
||||
# route2: recover the per-rollout delta_S grad from the gate
|
||||
# (c.grad = delta_S * g_b), flag rollouts whose grad points hack-ward
|
||||
# (cos(g_b, v_grad) > 0), and subtract their contribution from delta_S.grad.
|
||||
# (cos(g_b, v_grad) > tau), and route their contribution into delta_S_hack.
|
||||
# Only axes where delta_S has moved (|delta_S| > GATE_EPS) carry a reliable
|
||||
# per-rollout split; near-zero axes keep the full grad, so routing on a fresh
|
||||
# axis lags ~1 step until delta_S grows there (the A1 stale-mask trade-off).
|
||||
@@ -1254,6 +1190,10 @@ def main(cfg: Config) -> int:
|
||||
step_flagged.append(flagged.mean().item())
|
||||
sub = torch.where(reliable, (cg * flagged.unsqueeze(1)).sum(0) / dS_safe,
|
||||
torch.zeros_like(g)) # flagged rollouts' contribution
|
||||
# Park the flagged contribution in delta_S_hack (deleted at deploy);
|
||||
# delta_S keeps only the unflagged. Capacity-balanced: both shape [r].
|
||||
step_grad_hack[name] = (step_grad_hack[name] + sub.detach().clone()
|
||||
if name in step_grad_hack else sub.detach().clone())
|
||||
return g - sub
|
||||
|
||||
# Split backward into student/teacher only every cos_pre_split_every steps.
|
||||
@@ -1525,7 +1465,6 @@ def main(cfg: Config) -> int:
|
||||
step_grad_s[name] = (step_grad_s[name] + gs.detach().clone()
|
||||
if name in step_grad_s
|
||||
else gs.detach().clone())
|
||||
_stash_quar_grads()
|
||||
model.zero_grad(set_to_none=True)
|
||||
# Pass 2: teacher.
|
||||
loss_t.backward()
|
||||
@@ -1536,7 +1475,6 @@ def main(cfg: Config) -> int:
|
||||
step_grad_t[name] = (step_grad_t[name] + gt.detach().clone()
|
||||
if name in step_grad_t
|
||||
else gt.detach().clone())
|
||||
_stash_quar_grads()
|
||||
model.zero_grad(set_to_none=True)
|
||||
agg_loss += (loss_s + loss_t).item()
|
||||
else:
|
||||
@@ -1550,13 +1488,13 @@ def main(cfg: Config) -> int:
|
||||
ptl_norm = (per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)
|
||||
loss = ptl_norm.sum() / (group * prompts_per_step)
|
||||
loss.backward()
|
||||
# route2-grad: per-prompt anchor masks for the tau calibration.
|
||||
# route2: per-prompt anchor masks for the tau calibration.
|
||||
# Hack cloud = teacher rows (known-A hacks) + detector-flagged
|
||||
# (hack_E) student rows. Clean cloud = not-flagged student rows
|
||||
# (contaminated with unknown B by design -> conservative tau; B
|
||||
# still routes via cos>tau). is_student = [True]*G_s + [False]*G_t,
|
||||
# so hack_E_flags (len G_s) aligns with the leading student rows.
|
||||
if is_route2_grad:
|
||||
if is_route2:
|
||||
_n_merged = merged.shape[0]
|
||||
_ha = torch.zeros(_n_merged, dtype=torch.bool, device=per_tok_loss.device)
|
||||
_ca = torch.zeros(_n_merged, dtype=torch.bool, device=per_tok_loss.device)
|
||||
@@ -1569,14 +1507,13 @@ def main(cfg: Config) -> int:
|
||||
g = info["delta_S"].grad
|
||||
if g is None:
|
||||
continue
|
||||
# grad-mask routes here: strip flagged rollouts from delta_S.grad
|
||||
# (quarantine still learns them via its always-on forward path).
|
||||
if is_route2_grad:
|
||||
# route2 routes here: strip flagged rollouts from delta_S.grad
|
||||
# and park them in delta_S_hack (via step_grad_hack in the filter).
|
||||
if is_route2:
|
||||
g = _route2_grad_filter(info, merged.shape[0], _ha, _ca)
|
||||
step_grad_s[name] = (step_grad_s[name] + g.detach().clone()
|
||||
if name in step_grad_s
|
||||
else g.detach().clone())
|
||||
_stash_quar_grads()
|
||||
model.zero_grad(set_to_none=True)
|
||||
agg_loss += loss.item()
|
||||
t_fb += time.perf_counter() - _tfb
|
||||
@@ -1595,11 +1532,11 @@ def main(cfg: Config) -> int:
|
||||
info["delta_S"].grad = gs
|
||||
else:
|
||||
info["delta_S"].grad = gs + gt
|
||||
# route2: re-inject the stashed quarantine grads (student+teacher summed
|
||||
# across passes) so clip + opt.step move A_q/B_q.
|
||||
for key, g in step_grad_quar.items():
|
||||
name, sub = key.rsplit(".", 1)
|
||||
wrappers[name][sub].grad = g
|
||||
# route2: park the flagged rollouts' contribution into delta_S_hack.grad
|
||||
# (the autograd grad from delta_S_hack's own forward path was wiped by the
|
||||
# per-prompt zero_grad; we impose the routed grad here, like proj.py's route).
|
||||
for name, g in step_grad_hack.items():
|
||||
wrappers[name]["delta_S_hack"].grad = g
|
||||
|
||||
# Per-source cin: project student-only and teacher-only grads into v_hack
|
||||
# subspace. Discriminator: cos_pre_t > cos_pre_s on a clean base means v_hack
|
||||
@@ -1612,22 +1549,10 @@ def main(cfg: Config) -> int:
|
||||
diag = {"mean_cos_pre": float("nan"), "mean_cos_post": float("nan"),
|
||||
"frac_fired": float("nan"), "mean_cos_pre_s": float("nan"),
|
||||
"mean_cos_pre_t": float("nan")}
|
||||
# route2 act-mask: the forward stashed per-layer fired-fraction + mean cos
|
||||
# (cos(a,v_act)). Surface them in cin (mean cos) and fired (routed fraction)
|
||||
# so over-routing is visible -- a frozen sign-test direction fires on ~half
|
||||
# of all tokens, starving delta_S and dumping learning onto the quarantine.
|
||||
if is_route2_act:
|
||||
fired = [info["layer"]._antipasto_act_fired for info in wrappers.values()
|
||||
if hasattr(info["layer"], "_antipasto_act_fired")]
|
||||
coss = [info["layer"]._antipasto_act_cos for info in wrappers.values()
|
||||
if hasattr(info["layer"], "_antipasto_act_cos")]
|
||||
if fired:
|
||||
diag["frac_fired"] = float(torch.stack(fired).mean())
|
||||
diag["mean_cos_pre"] = float(torch.stack(coss).mean())
|
||||
# route2 grad-mask: report the mean per-module per-rollout flag rate so
|
||||
# we can watch the mask actually fire (and rise as hacks emerge).
|
||||
if is_route2_grad and step_flagged:
|
||||
logger.debug(f"route2-grad flagged frac (mean over modules*prompts): "
|
||||
# route2: report the mean per-module per-rollout flag rate so we can
|
||||
# watch the mask actually fire (and rise as hacks emerge).
|
||||
if is_route2 and step_flagged:
|
||||
logger.debug(f"route2 flagged frac (mean over modules*prompts): "
|
||||
f"{sum(step_flagged)/len(step_flagged):+.3f}")
|
||||
else:
|
||||
if split_this_step:
|
||||
@@ -1671,7 +1596,7 @@ def main(cfg: Config) -> int:
|
||||
# ignored -> identical norm to before (R4). For route it bounds the
|
||||
# combined update (main + quarantine).
|
||||
# Split the grad energy: how much is going to delta_S (the KEPT/deployed
|
||||
# knob) vs the quarantine (delta_S_hack for route, A_q/B_q for route2 --
|
||||
# knob) vs the quarantine (delta_S_hack, deleted at deploy --
|
||||
# the THROWN-AWAY knob). qE = quar / (keep + quar) in [0,1]. Rising qE
|
||||
# means routing is dumping the learning into the quarantine, so the
|
||||
# deployed model learns nothing -- the invisible failure in job 46
|
||||
@@ -1681,9 +1606,9 @@ def main(cfg: Config) -> int:
|
||||
gs = [p.grad for p in params if p.grad is not None]
|
||||
return float(torch.norm(torch.stack([g.norm() for g in gs]))) if gs else 0.0
|
||||
gn_keep = _grad_l2(delta_params)
|
||||
gn_quar = _grad_l2(delta_hack_params + quar_params)
|
||||
gn_quar = _grad_l2(delta_hack_params)
|
||||
q_egy = gn_quar / (gn_keep + gn_quar) if (gn_keep + gn_quar) > 0 else 0.0
|
||||
gn = float(torch.nn.utils.clip_grad_norm_(delta_params + delta_hack_params + quar_params, cfg.grad_clip))
|
||||
gn = float(torch.nn.utils.clip_grad_norm_(delta_params + delta_hack_params, cfg.grad_clip))
|
||||
opt.step()
|
||||
sched.step()
|
||||
|
||||
@@ -1694,13 +1619,12 @@ def main(cfg: Config) -> int:
|
||||
refr = "-" # set to "mod/axes" below if a refresh fires; rendered in the per-step row
|
||||
do_refresh = cfg.vhack_refresh_every > 0 and (step + 1) % cfg.vhack_refresh_every == 0
|
||||
if do_refresh and is_route2:
|
||||
# route2 mask refresh: re-extract v_act / v_grad against the CURRENT
|
||||
# model so the mask tracks where hacks separate now, not at step 0.
|
||||
# Without this the frozen mask goes stale -- cin_t decays to cin_s
|
||||
# within ~6 steps (2026-05-31 journal, frozen-real-V route). Same
|
||||
# MASK_PAIRS (the weak detector, no oracle); quarantine ablated so the
|
||||
# hack signal flows back through the observable path, matching the
|
||||
# B_q=0 state the build-time extraction saw.
|
||||
# route2 v_grad refresh: re-extract against the CURRENT model so the
|
||||
# routing direction tracks where hacks separate now, not at step 0.
|
||||
# Without this the frozen direction goes stale -- cin_t decays to cin_s
|
||||
# within ~6 steps (2026-05-31 journal). Same MASK_PAIRS (the weak
|
||||
# detector, no oracle); quarantine ablated so the hack signal flows back
|
||||
# through the observable path, matching the state the build-time extract saw.
|
||||
_was_training = model.training
|
||||
model.eval()
|
||||
opt.zero_grad(set_to_none=True)
|
||||
@@ -1708,39 +1632,21 @@ def main(cfg: Config) -> int:
|
||||
logger.disable("__main__")
|
||||
try:
|
||||
with ablate_quarantine(wrappers):
|
||||
if cfg.route2_mask == "act":
|
||||
from .extract_vhack_grad import extract_v_act
|
||||
_v = extract_v_act(model, tok, wrappers, MASK_PAIRS, n_heldout=2, device=device)
|
||||
# Mean |cos(old, new)| over modules = how much the mask direction
|
||||
# moved. Near 1.0 => stable hack subspace; low => v_act is chasing
|
||||
# a drifting target (the staleness this refresh is meant to fix).
|
||||
_ov = []
|
||||
for name, info in wrappers.items():
|
||||
old = info["layer"]._antipasto_v_act
|
||||
new = _v[name].to(device, dtype=old.dtype)
|
||||
_ov.append((old @ new).abs() / (old.norm().clamp_min(1e-9) * new.norm().clamp_min(1e-9)))
|
||||
old.data.copy_(new)
|
||||
_act_overlap = float(torch.stack(_ov).mean())
|
||||
else:
|
||||
from .extract_vhack_grad import extract_v_hack
|
||||
_, _, raw_grads, _ = extract_v_hack(
|
||||
model, tok, wrappers, MASK_PAIRS,
|
||||
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
|
||||
)
|
||||
for name in wrappers: # update in place so _route2_grad_filter's closure sees it
|
||||
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
|
||||
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
|
||||
from .extract_vhack_grad import extract_v_hack
|
||||
_, _, raw_grads, _ = extract_v_hack(
|
||||
model, tok, wrappers, MASK_PAIRS,
|
||||
top_k=1, tau_axis=0.0, n_heldout=2, device=device,
|
||||
)
|
||||
for name in wrappers: # update in place so _route2_grad_filter's closure sees it
|
||||
d = (raw_grads[f"hack/{name}"] - raw_grads[f"clean/{name}"]).mean(0)
|
||||
v_grad[name] = (d / d.norm().clamp_min(1e-12)).to(device)
|
||||
finally:
|
||||
logger.enable("projected_grpo.extract_vhack_grad")
|
||||
logger.enable("__main__")
|
||||
opt.zero_grad(set_to_none=True) # extract leaves .grad populated
|
||||
if _was_training:
|
||||
model.train()
|
||||
# No verbose per-refresh log line (it was noise). The compact `refr`
|
||||
# column marks the refresh and, for act, carries the basis-overlap
|
||||
# (mean |cos| old-vs-new v_act): ~1 = stable subspace, low = chasing a
|
||||
# drifting target. grad-mask has no cheap overlap -> bare marker.
|
||||
refr = f"{_act_overlap:.2f}" if cfg.route2_mask == "act" else "rfr"
|
||||
refr = "rfr" # compact marker; v_grad refresh has no cheap overlap gauge
|
||||
if v_hack is not None and do_refresh:
|
||||
from .extract_vhack_grad import extract_v_hack
|
||||
if cfg.vhack_pairs_path is not None:
|
||||
@@ -2051,13 +1957,9 @@ def main(cfg: Config) -> int:
|
||||
dsh_norm = float(sum(info["delta_S_hack"].data.float().pow(2).sum().item()
|
||||
for info in wrappers.values()) ** 0.5)
|
||||
logger.info(f"||delta_S_hack|| = {dsh_norm:.4f} "
|
||||
f"(SHOULD: >0 for route, ==0 for none/erase; ELSE routing broke)")
|
||||
if cfg.intervention == "route":
|
||||
assert dsh_norm > 0.0, "route: delta_S_hack never moved -> degenerated to erasure"
|
||||
if is_route2:
|
||||
bq_norm = sum(info["B_q"].data.norm().item() for info in wrappers.values())
|
||||
logger.info(f"||B_q|| sum = {bq_norm:.4f} (SHOULD: >0; ELSE quarantine never seeded)")
|
||||
assert bq_norm > 0.0, "route2: B_q never moved -> quarantine never seeded (mask never fired?)"
|
||||
f"(SHOULD: >0 for route/route2, ==0 for none/erase; ELSE routing broke)")
|
||||
if cfg.intervention in ("route", "route2"):
|
||||
assert dsh_norm > 0.0, f"{cfg.intervention}: delta_S_hack never moved -> nothing routed into quarantine"
|
||||
|
||||
# Last training generation -- a fast eyeball for coherence before the eval
|
||||
# numbers. SHOULD: real code/prose for the problem. If it is token salad the
|
||||
@@ -2112,7 +2014,6 @@ def main(cfg: Config) -> int:
|
||||
# deploy==train. This is the canonical source for the all-arms per-mode plot.
|
||||
deploy_record = {
|
||||
"arm": cfg.arm, "intervention": cfg.intervention,
|
||||
"route2_mask": cfg.route2_mask if is_route2 else None,
|
||||
"refresh_every": cfg.vhack_refresh_every, "seed": cfg.seed,
|
||||
"steps": n_steps, "model": model_name, "out_tag": cfg.out_tag,
|
||||
"log": str(verbose_log), "eval_n": ev_deploy["n"],
|
||||
|
||||
Reference in New Issue
Block a user