antipasto: add EVA-style data-driven group_init

Weight-SVD init (PiSSA-style) kept as fallback; when calibration_data is
provided, group_init() collects pre-hook activations, SVDs the pooled inputs
per layer, and re-decomposes W_orig through the top-r input-PCA directions.
Vhr_final = Vh_A @ Vhr_new keeps rows orthonormal while preserving the
input-aligned span.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-01 20:55:56 +08:00
parent b698331cfa
commit f91c7b23f2
2 changed files with 81 additions and 3 deletions
+80 -2
View File
@@ -22,7 +22,7 @@ Refs:
"""
import math
from dataclasses import dataclass
from typing import Literal
from typing import Iterable, Literal
import torch
from einops import einsum, rearrange
@@ -32,6 +32,9 @@ from torch import nn, Tensor as T
from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
# One calibration batch. group_init() forwards a dict as model(**batch),
# a tuple/list as model(*batch), and anything else as model(batch).
CalibrationBatch = dict | tuple | list | T
# Any iterable of calibration batches; group_init() iterates it once.
CalibrationData = Iterable[CalibrationBatch]
@register_config
@dataclass
@@ -110,8 +113,83 @@ class AntiPaSTO:
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
layer.weight.data.copy_(W_res)
# group_init() refines this to input-aligned directions if calibration_data is given.
# FIXME antipasto needs an init from data too
@staticmethod
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
"""EVA-style data-driven refinement: replace weight-SVD basis with input-PCA basis.
Collects pre-hook activations, runs SVD on the pooled inputs per layer, then
re-decomposes W_orig through those input-aligned directions so the low-rank
subspace captures the actual input distribution rather than W's spectral structure.
If calibration_data is None the weight-SVD init from init() is kept unchanged.
"""
if calibration_data is None:
return
layers = {name: layer for name, layer, _ in targets}
captured: dict[str, list[T]] = {n: [] for n in layers}
def make_hook(name):
def _h(module, args, kwargs):
x = args[0].detach()
captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu())
return _h
handles = [
layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True)
for n in layers
]
try:
was_training = model.training
model.eval()
with torch.no_grad():
for batch in calibration_data:
if isinstance(batch, dict):
model(**batch)
elif isinstance(batch, (list, tuple)):
model(*batch)
else:
model(batch)
if was_training:
model.train()
finally:
for h in handles:
h.remove()
r = cfg.r
for name, layer in layers.items():
X = torch.cat(captured[name], dim=0) # (N, d_in)
if X.shape[0] < r:
raise RuntimeError(
f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
)
# Top-r right singular vectors of input distribution (same as EVA lora_A init)
_, _, Vh_data = torch.linalg.svd(X, full_matrices=False)
Vhr_new = Vh_data[:r] # (r, d_in)
# Recover W_orig: init() already wrote W_res into layer.weight
W_res = layer.weight.data.float()
U_old = layer.lora_U.float() # (d_out, r)
S_old = layer.lora_S.float() # (r,)
Vh_old = layer.lora_Vh.float() # (r, d_in)
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
# Project W_orig onto the input subspace, then SVD for proper U/S
A = W_orig @ Vhr_new.T # (d_out, r)
U_A, S_A, Vh_A = torch.linalg.svd(A, full_matrices=False)
# Rotate Vhr_new by Vh_A so rows remain orthonormal and span is preserved
Vhr_final = Vh_A @ Vhr_new # (r, d_in)
W_res_new = (W_orig - (U_A * S_A.unsqueeze(0)) @ Vhr_final).to(layer.weight.dtype)
with torch.no_grad():
layer.lora_U.copy_(U_A.to(layer.lora_U))
layer.lora_S.copy_(S_A.to(layer.lora_S))
layer.lora_Vh.copy_(Vhr_final.to(layer.lora_Vh))
layer.weight.data.copy_(W_res_new)
@staticmethod
def forward(
+1 -1
View File
@@ -101,7 +101,7 @@ class EVA:
)
# full_matrices=False -> Vh shape (min(N,d_in), d_in); take top-r rows
_, _, Vh = torch.linalg.svd(X, full_matrices=False)
A = Vh[: cfg.r, :].to(layer.lora_A.dtype).to(layer.lora_A.device)
A = Vh[: cfg.r, :].to(layer.lora_A)
with torch.no_grad():
layer.lora_A.copy_(A)