antipasto: replace EVA-style group_init with Wanda-style dimension selection

Score each singular dimension by S[i] * mean|X @ Vh[i]| (weight magnitude
times activation magnitude), then pick top-r by joint score instead of top-r
by S alone. Keeps the weight-SVD basis; only changes which r dimensions are
retained, based on real input activations.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-01 21:24:52 +08:00
parent f91c7b23f2
commit 19888fbb82
+18 -19
View File
@@ -117,13 +117,13 @@ class AntiPaSTO:
@staticmethod
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
"""EVA-style data-driven refinement: replace weight-SVD basis with input-PCA basis.
"""Wanda-style data-driven dimension selection within the weight SVD.
Collects pre-hook activations, runs SVD on the pooled inputs per layer, then
re-decomposes W_orig through those input-aligned directions so the low-rank
subspace captures the actual input distribution rather than W's spectral structure.
init() picks the top-r singular dimensions by S alone (PiSSA-style).
group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions
that are both large in W AND active given real inputs.
If calibration_data is None the weight-SVD init from init() is kept unchanged.
If calibration_data is None the weight-SVD init from init() is kept.
"""
if calibration_data is None:
return
@@ -166,29 +166,28 @@ class AntiPaSTO:
f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
)
# Top-r right singular vectors of input distribution (same as EVA lora_A init)
_, _, Vh_data = torch.linalg.svd(X, full_matrices=False)
Vhr_new = Vh_data[:r] # (r, d_in)
# Recover W_orig: init() already wrote W_res into layer.weight
# Recover W_orig: init() wrote W_res into layer.weight and stored top-r components
W_res = layer.weight.data.float()
U_old = layer.lora_U.float() # (d_out, r)
S_old = layer.lora_S.float() # (r,)
Vh_old = layer.lora_Vh.float() # (r, d_in)
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
# Project W_orig onto the input subspace, then SVD for proper U/S
A = W_orig @ Vhr_new.T # (d_out, r)
U_A, S_A, Vh_A = torch.linalg.svd(A, full_matrices=False)
# Full SVD to score all dimensions
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
# score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude)
act_mag = (X @ Vh_full.T).abs().mean(dim=0) # (k,)
scores = S_full * act_mag
idx = scores.argsort(descending=True)[:r] # top-r by joint importance
idx = idx.sort().values # stable ordering
# Rotate Vhr_new by Vh_A so rows remain orthonormal and span is preserved
Vhr_final = Vh_A @ Vhr_new # (r, d_in)
W_res_new = (W_orig - (U_A * S_A.unsqueeze(0)) @ Vhr_final).to(layer.weight.dtype)
Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
W_res_new = (W_orig - (Ur * Sr.unsqueeze(0)) @ Vhr).to(layer.weight.dtype)
with torch.no_grad():
layer.lora_U.copy_(U_A.to(layer.lora_U))
layer.lora_S.copy_(S_A.to(layer.lora_S))
layer.lora_Vh.copy_(Vhr_final.to(layer.lora_Vh))
layer.lora_U.copy_(Ur.to(layer.lora_U))
layer.lora_S.copy_(Sr.to(layer.lora_S))
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh))
layer.weight.data.copy_(W_res_new)
@staticmethod