mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 15:15:55 +08:00
antipasto: replace EVA-style group_init with Wanda-style dimension selection
Score each singular dimension by S[i] * mean|X @ Vh[i]| (weight magnitude times activation magnitude), then pick top-r by joint score instead of top-r by S alone. Keeps the weight-SVD basis; only reorders which r dimensions are retained based on real input activations. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -117,13 +117,13 @@ class AntiPaSTO:
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""EVA-style data-driven refinement: replace weight-SVD basis with input-PCA basis.
|
||||
"""Wanda-style data-driven dimension selection within the weight SVD.
|
||||
|
||||
Collects pre-hook activations, runs SVD on the pooled inputs per layer, then
|
||||
re-decomposes W_orig through those input-aligned directions so the low-rank
|
||||
subspace captures the actual input distribution rather than W's spectral structure.
|
||||
init() picks the top-r singular dimensions by S alone (PiSSA-style).
|
||||
group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions
|
||||
that are both large in W AND active given real inputs.
|
||||
|
||||
If calibration_data is None the weight-SVD init from init() is kept unchanged.
|
||||
If calibration_data is None the weight-SVD init from init() is kept.
|
||||
"""
|
||||
if calibration_data is None:
|
||||
return
|
||||
@@ -166,29 +166,28 @@ class AntiPaSTO:
|
||||
f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
|
||||
)
|
||||
|
||||
# Top-r right singular vectors of input distribution (same as EVA lora_A init)
|
||||
_, _, Vh_data = torch.linalg.svd(X, full_matrices=False)
|
||||
Vhr_new = Vh_data[:r] # (r, d_in)
|
||||
|
||||
# Recover W_orig: init() already wrote W_res into layer.weight
|
||||
# Recover W_orig: init() wrote W_res into layer.weight and stored top-r components
|
||||
W_res = layer.weight.data.float()
|
||||
U_old = layer.lora_U.float() # (d_out, r)
|
||||
S_old = layer.lora_S.float() # (r,)
|
||||
Vh_old = layer.lora_Vh.float() # (r, d_in)
|
||||
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
|
||||
|
||||
# Project W_orig onto the input subspace, then SVD for proper U/S
|
||||
A = W_orig @ Vhr_new.T # (d_out, r)
|
||||
U_A, S_A, Vh_A = torch.linalg.svd(A, full_matrices=False)
|
||||
# Full SVD to score all dimensions
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
|
||||
# score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude)
|
||||
act_mag = (X @ Vh_full.T).abs().mean(dim=0) # (k,)
|
||||
scores = S_full * act_mag
|
||||
idx = scores.argsort(descending=True)[:r] # top-r by joint importance
|
||||
idx = idx.sort().values # stable ordering
|
||||
|
||||
# Rotate Vhr_new by Vh_A so rows remain orthonormal and span is preserved
|
||||
Vhr_final = Vh_A @ Vhr_new # (r, d_in)
|
||||
W_res_new = (W_orig - (U_A * S_A.unsqueeze(0)) @ Vhr_final).to(layer.weight.dtype)
|
||||
Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
|
||||
W_res_new = (W_orig - (Ur * Sr.unsqueeze(0)) @ Vhr).to(layer.weight.dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(U_A.to(layer.lora_U))
|
||||
layer.lora_S.copy_(S_A.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Vhr_final.to(layer.lora_Vh))
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh))
|
||||
layer.weight.data.copy_(W_res_new)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user