diff --git a/src/lora_lite/variants/antipasto.py b/src/lora_lite/variants/antipasto.py index 41814c5..654a23d 100644 --- a/src/lora_lite/variants/antipasto.py +++ b/src/lora_lite/variants/antipasto.py @@ -117,13 +117,13 @@ class AntiPaSTO: @staticmethod def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: - """EVA-style data-driven refinement: replace weight-SVD basis with input-PCA basis. + """Wanda-style data-driven dimension selection within the weight SVD. - Collects pre-hook activations, runs SVD on the pooled inputs per layer, then - re-decomposes W_orig through those input-aligned directions so the low-rank - subspace captures the actual input distribution rather than W's spectral structure. + init() picks the top-r singular dimensions by S alone (PiSSA-style). + group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions + that are both large in W AND active given real inputs. - If calibration_data is None the weight-SVD init from init() is kept unchanged. + If calibration_data is None the weight-SVD init from init() is kept. """ if calibration_data is None: return @@ -166,29 +166,28 @@ class AntiPaSTO: f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}" ) - # Top-r right singular vectors of input distribution (same as EVA lora_A init) - _, _, Vh_data = torch.linalg.svd(X, full_matrices=False) - Vhr_new = Vh_data[:r] # (r, d_in) - - # Recover W_orig: init() already wrote W_res into layer.weight + # Recover W_orig: init() wrote W_res into layer.weight and stored top-r components W_res = layer.weight.data.float() U_old = layer.lora_U.float() # (d_out, r) S_old = layer.lora_S.float() # (r,) Vh_old = layer.lora_Vh.float() # (r, d_in) W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old - # Project W_orig onto the input subspace, then SVD for proper U/S - A = W_orig @ Vhr_new.T # (d_out, r) - U_A, S_A, Vh_A = torch.linalg.svd(A, full_matrices=False) + # Full SVD to score all dimensions + U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False) + # score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude) + act_mag = (X @ Vh_full.T).abs().mean(dim=0) # (k,) + scores = S_full * act_mag + idx = scores.argsort(descending=True)[:r] # top-r by joint importance + idx = idx.sort().values # stable ordering - # Rotate Vhr_new by Vh_A so rows remain orthonormal and span is preserved - Vhr_final = Vh_A @ Vhr_new # (r, d_in) - W_res_new = (W_orig - (U_A * S_A.unsqueeze(0)) @ Vhr_final).to(layer.weight.dtype) + Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx] + W_res_new = (W_orig - (Ur * Sr.unsqueeze(0)) @ Vhr).to(layer.weight.dtype) with torch.no_grad(): - layer.lora_U.copy_(U_A.to(layer.lora_U)) - layer.lora_S.copy_(S_A.to(layer.lora_S)) - layer.lora_Vh.copy_(Vhr_final.to(layer.lora_Vh)) + layer.lora_U.copy_(Ur.to(layer.lora_U)) + layer.lora_S.copy_(Sr.to(layer.lora_S)) + layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh)) layer.weight.data.copy_(W_res_new) @staticmethod