diff --git a/src/lora_lite/variants/antipasto.py b/src/lora_lite/variants/antipasto.py index 3425cbf..41814c5 100644 --- a/src/lora_lite/variants/antipasto.py +++ b/src/lora_lite/variants/antipasto.py @@ -22,7 +22,7 @@ Refs: """ import math from dataclasses import dataclass -from typing import Literal +from typing import Iterable, Literal import torch from einops import einsum, rearrange @@ -32,6 +32,9 @@ from torch import nn, Tensor as T from ..variant import register, ParamSpec from ..config import AdapterConfig, register_config +CalibrationBatch = dict | tuple | list | T +CalibrationData = Iterable[CalibrationBatch] + @register_config @dataclass @@ -110,8 +113,83 @@ class AntiPaSTO: layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) layer.weight.data.copy_(W_res) + # group_init() refines this to input-aligned directions if calibration_data is given. - # FIXME antipasto needs an init from data too + @staticmethod + def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: + """EVA-style data-driven refinement: replace weight-SVD basis with input-PCA basis. + + Collects pre-hook activations, runs SVD on the pooled inputs per layer, then + re-decomposes W_orig through those input-aligned directions so the low-rank + subspace captures the actual input distribution rather than W's spectral structure. + + If calibration_data is None the weight-SVD init from init() is kept unchanged. + """ + if calibration_data is None: + return + + layers = {name: layer for name, layer, _ in targets} + captured: dict[str, list[T]] = {n: [] for n in layers} + + def make_hook(name): + def _h(module, args, kwargs): + x = args[0].detach() + captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu()) + return _h + + handles = [ + layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) + for n in layers + ] + try: + was_training = model.training + model.eval() + with torch.no_grad(): + for batch in calibration_data: + if isinstance(batch, dict): + model(**batch) + elif isinstance(batch, (list, tuple)): + model(*batch) + else: + model(batch) + if was_training: + model.train() + finally: + for h in handles: + h.remove() + + r = cfg.r + for name, layer in layers.items(): + X = torch.cat(captured[name], dim=0) # (N, d_in) + if X.shape[0] < r: + raise RuntimeError( + f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}" + ) + + # Top-r right singular vectors of input distribution (same as EVA lora_A init) + _, _, Vh_data = torch.linalg.svd(X, full_matrices=False) + Vhr_new = Vh_data[:r] # (r, d_in) + + # Recover W_orig: init() already wrote W_res into layer.weight + W_res = layer.weight.data.float() + U_old = layer.lora_U.float() # (d_out, r) + S_old = layer.lora_S.float() # (r,) + Vh_old = layer.lora_Vh.float() # (r, d_in) + W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old + + # Project W_orig onto the input subspace, then SVD for proper U/S + A = W_orig @ Vhr_new.T # (d_out, r) + U_A, S_A, Vh_A = torch.linalg.svd(A, full_matrices=False) + + # Rotate Vhr_new by Vh_A so rows remain orthonormal and span is preserved + Vhr_final = Vh_A @ Vhr_new # (r, d_in) + W_res_new = (W_orig - (U_A * S_A.unsqueeze(0)) @ Vhr_final).to(layer.weight.dtype) + + with torch.no_grad(): + layer.lora_U.copy_(U_A.to(layer.lora_U)) + layer.lora_S.copy_(S_A.to(layer.lora_S)) + layer.lora_Vh.copy_(Vhr_final.to(layer.lora_Vh)) + layer.weight.data.copy_(W_res_new) @staticmethod def forward( diff --git a/src/lora_lite/variants/eva.py b/src/lora_lite/variants/eva.py index d5c17bc..0b8bdbf 100644 --- a/src/lora_lite/variants/eva.py +++ b/src/lora_lite/variants/eva.py @@ -101,7 +101,7 @@ class EVA: ) # full_matrices=False -> Vh shape (min(N,d_in), d_in); take top-r rows _, _, Vh = torch.linalg.svd(X, full_matrices=False) - A = Vh[: cfg.r, :].to(layer.lora_A.dtype).to(layer.lora_A.device) + A = Vh[: cfg.r, :].to(layer.lora_A) with torch.no_grad(): layer.lora_A.copy_(A)