mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 14:45:22 +08:00
antipasto_ablate: warm-start lora_c from S-space output variance
group_init now seeds each lora_c to the top-k principal axes of the S-space output coords h=diag(S)Vh x (highest-energy output dirs => largest loss-grad on the ablation strength), so lora_c starts in a high-gradient region not random. Cheap r x r second moment when not orienting; reuses Sigma xx^T when cov_orient. Benchmark always calibrates ablate now. This is the data-variance direction, not a contrastive behavior dir (SFT has no pos/neg split) -- noted in the docstring. UAT: |cos(lora_c, top output-PC)| = 1.0000 vs ~0.35 chance; smoke green. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -604,9 +604,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
# downstream task (IPM mode, per CorDA). eva needs only a few batches for its init;
|
||||
# corda/asvd/cov-orient estimate an input second moment, so we hand them many more
|
||||
# batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis.
|
||||
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd") or (
|
||||
args.variant == "antipasto_ablate" and args.antipasto_cov_orient
|
||||
)
|
||||
# antipasto_ablate always calibrates now: group_init warm-starts lora_c from the
|
||||
# S-space output variance (cov_orient adds the heavier CorDA re-orient on top).
|
||||
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd", "antipasto_ablate")
|
||||
init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init
|
||||
if needs_calib:
|
||||
n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches))
|
||||
|
||||
@@ -85,29 +85,43 @@ class AntiPaSTOAblate:
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
# FIXME: lora_c is random-init. A group_init warm-start from the S-space
|
||||
# contrastive direction dS (cf. sspace.py extract) would converge faster and
|
||||
# land on the behavior direction; not implemented -- random trains, just slower.
|
||||
# lora_c starts random here; group_init warm-starts it from the S-space output
|
||||
# variance when calibration_data is given (see group_init), else it trains from noise.
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""If cov_orient, re-orient each target's SVD by input covariance C=E[x x^T]
|
||||
(CorDA) so the data-relevant output directions land in the top-r and the
|
||||
behavior direction is fully ablatable at low r. No-op otherwise (keeps the
|
||||
plain-SVD basis from init()). C is accumulated on CPU; for down_proj's large
|
||||
d_in this is heavy -- exclude it or use plain ablation there."""
|
||||
if not getattr(cfg, "cov_orient", False) or calibration_data is None:
|
||||
return
|
||||
"""Warm-start each lora_c from calibration activations (and, if cov_orient,
|
||||
re-orient the frozen SVD by input covariance C=E[x xᵀ] first, CorDA-style).
|
||||
|
||||
lora_c is seeded to the top-k principal axes of the S-space OUTPUT coords
|
||||
h = diag(S) Vh x over the calibration set: the highest-energy output directions,
|
||||
where the loss-gradient on the ablation strength is largest, so lora_c starts in a
|
||||
high-gradient region instead of a near-orthogonal random one. NOTE this is the data
|
||||
VARIANCE direction, not a contrastive behavior direction -- this benchmark is SFT
|
||||
with no pos/neg split. For steering with contrastive pairs, seed lora_c from
|
||||
mean(h|pos) - mean(h|neg) instead (cf. steering-lite sspace extract).
|
||||
|
||||
Σ xxᵀ (d_in², heavy for down_proj) is only accumulated to orient; the warm-start
|
||||
alone (cov_orient=False) needs just the cheap r×r second moment Σ hhᵀ."""
|
||||
if calibration_data is None:
|
||||
return
|
||||
orient = bool(getattr(cfg, "cov_orient", False))
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
cov: dict[str, T] = {}
|
||||
gram: dict[str, T] = {} # Σ xxᵀ (d_in²), only when orienting
|
||||
mom: dict[str, T] = {} # Σ hhᵀ (r²), when not orienting (basis is fixed at init)
|
||||
cnt: dict[str, int] = {n: 0 for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
layer = layers[name]
|
||||
def _h(module, args, kwargs):
|
||||
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
|
||||
g = x.T @ x
|
||||
cov[name] = g if name not in cov else cov[name] + g
|
||||
if orient:
|
||||
g = x.T @ x
|
||||
gram[name] = g if name not in gram else gram[name] + g
|
||||
else:
|
||||
h = (x @ layer.lora_Vh.float().cpu().T) * layer.lora_S.float().cpu()
|
||||
m = h.T @ h
|
||||
mom[name] = m if name not in mom else mom[name] + m
|
||||
cnt[name] += x.shape[0]
|
||||
return _h
|
||||
|
||||
@@ -129,32 +143,41 @@ class AntiPaSTOAblate:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r, eps = cfg.r, float(cfg.cov_eps)
|
||||
r, k, eps = cfg.r, cfg.k, float(cfg.cov_eps)
|
||||
for name, layer in layers.items():
|
||||
if cnt[name] < r:
|
||||
raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}")
|
||||
W_res = layer.weight.data.float().cpu()
|
||||
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
|
||||
layer.lora_S.float().cpu(),
|
||||
layer.lora_Vh.float().cpu())
|
||||
W_orig = W_res + (U_old * S_old) @ Vh_old
|
||||
if orient:
|
||||
W_res = layer.weight.data.float().cpu()
|
||||
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
|
||||
layer.lora_S.float().cpu(),
|
||||
layer.lora_Vh.float().cpu())
|
||||
W_orig = W_res + (U_old * S_old) @ Vh_old
|
||||
|
||||
C = cov[name] / cnt[name]
|
||||
lam, Q = torch.linalg.eigh(C)
|
||||
lam = lam.clamp_min(0) + eps
|
||||
Chalf = (Q * lam.sqrt()) @ Q.T
|
||||
Cinvhalf = (Q * lam.rsqrt()) @ Q.T
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
|
||||
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
|
||||
Sr = St[:r]
|
||||
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
|
||||
W_res_new = W_orig - (Ur * Sr) @ Pr
|
||||
C = gram[name] / cnt[name]
|
||||
lam, Q = torch.linalg.eigh(C)
|
||||
lam = lam.clamp_min(0) + eps
|
||||
Chalf = (Q * lam.sqrt()) @ Q.T
|
||||
Cinvhalf = (Q * lam.rsqrt()) @ Q.T
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
|
||||
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
|
||||
Sr = St[:r]
|
||||
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
|
||||
W_res_new = W_orig - (Ur * Sr) @ Pr
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
|
||||
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
||||
# output S-space second moment in the (now oriented) basis: diag(S) P Σxxᵀ Pᵀ diag(S)
|
||||
SP = Sr[:, None] * Pr
|
||||
M = SP @ gram[name] @ SP.T
|
||||
else:
|
||||
M = mom[name] # (r, r) Σ hhᵀ in the init basis
|
||||
|
||||
c0 = torch.linalg.eigh(M).eigenvectors[:, -k:] # top-k principal dirs (orthonormal)
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
|
||||
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
||||
layer.lora_c.copy_(c0.to(layer.lora_c))
|
||||
|
||||
@staticmethod
|
||||
def _orthonormal(c: T) -> T:
|
||||
|
||||
Reference in New Issue
Block a user