mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 16:15:50 +08:00
antipasto_ablate: warm-start lora_c from S-space output variance
group_init now seeds each lora_c with the top-k principal axes of the S-space output coordinates h = diag(S) Vh x. These are the highest-energy output directions, where the loss gradient with respect to the ablation strength is largest, so lora_c starts in a high-gradient region rather than at a random point. When not orienting, this needs only the cheap r x r second moment Sigma hh^T; when cov_orient is set, it reuses the already-accumulated Sigma xx^T. The benchmark now always calibrates for antipasto_ablate. Note that this is the data-variance direction, not a contrastive behavior direction (SFT has no pos/neg split); this is noted in the docstring. UAT: |cos(lora_c, top output-PC)| = 1.0000 vs ~0.35 for chance; smoke test green. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -604,9 +604,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
|||||||
# downstream task (IPM mode, per CorDA). eva needs only a few batches for its init;
|
# downstream task (IPM mode, per CorDA). eva needs only a few batches for its init;
|
||||||
# corda/asvd/cov-orient estimate an input second moment, so we hand them many more
|
# corda/asvd/cov-orient estimate an input second moment, so we hand them many more
|
||||||
# batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis.
|
# batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis.
|
||||||
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd") or (
|
# antipasto_ablate always calibrates now: group_init warm-starts lora_c from the
|
||||||
args.variant == "antipasto_ablate" and args.antipasto_cov_orient
|
# S-space output variance (cov_orient adds the heavier CorDA re-orient on top).
|
||||||
)
|
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd", "antipasto_ablate")
|
||||||
init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init
|
init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init
|
||||||
if needs_calib:
|
if needs_calib:
|
||||||
n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches))
|
n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches))
|
||||||
|
|||||||
@@ -85,29 +85,43 @@ class AntiPaSTOAblate:
|
|||||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||||
layer.weight.data.copy_(W_res)
|
layer.weight.data.copy_(W_res)
|
||||||
# FIXME: lora_c is random-init. A group_init warm-start from the S-space
|
# lora_c starts random here; group_init warm-starts it from the S-space output
|
||||||
# contrastive direction dS (cf. sspace.py extract) would converge faster and
|
# variance when calibration_data is given (see group_init), else it trains from noise.
|
||||||
# land on the behavior direction; not implemented -- random trains, just slower.
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||||
"""If cov_orient, re-orient each target's SVD by input covariance C=E[x x^T]
|
"""Warm-start each lora_c from calibration activations (and, if cov_orient,
|
||||||
(CorDA) so the data-relevant output directions land in the top-r and the
|
re-orient the frozen SVD by input covariance C=E[x xᵀ] first, CorDA-style).
|
||||||
behavior direction is fully ablatable at low r. No-op otherwise (keeps the
|
|
||||||
plain-SVD basis from init()). C is accumulated on CPU; for down_proj's large
|
|
||||||
d_in this is heavy -- exclude it or use plain ablation there."""
|
|
||||||
if not getattr(cfg, "cov_orient", False) or calibration_data is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
lora_c is seeded to the top-k principal axes of the S-space OUTPUT coords
|
||||||
|
h = diag(S) Vh x over the calibration set: the highest-energy output directions,
|
||||||
|
where the loss-gradient on the ablation strength is largest, so lora_c starts in a
|
||||||
|
high-gradient region instead of a near-orthogonal random one. NOTE this is the data
|
||||||
|
VARIANCE direction, not a contrastive behavior direction -- this benchmark is SFT
|
||||||
|
with no pos/neg split. For steering with contrastive pairs, seed lora_c from
|
||||||
|
mean(h|pos) - mean(h|neg) instead (cf. steering-lite sspace extract).
|
||||||
|
|
||||||
|
Σ xxᵀ (d_in², heavy for down_proj) is only accumulated to orient; the warm-start
|
||||||
|
alone (cov_orient=False) needs just the cheap r×r second moment Σ hhᵀ."""
|
||||||
|
if calibration_data is None:
|
||||||
|
return
|
||||||
|
orient = bool(getattr(cfg, "cov_orient", False))
|
||||||
layers = {name: layer for name, layer, _ in targets}
|
layers = {name: layer for name, layer, _ in targets}
|
||||||
cov: dict[str, T] = {}
|
gram: dict[str, T] = {} # Σ xxᵀ (d_in²), only when orienting
|
||||||
|
mom: dict[str, T] = {} # Σ hhᵀ (r²), when not orienting (basis is fixed at init)
|
||||||
cnt: dict[str, int] = {n: 0 for n in layers}
|
cnt: dict[str, int] = {n: 0 for n in layers}
|
||||||
|
|
||||||
def make_hook(name):
|
def make_hook(name):
|
||||||
|
layer = layers[name]
|
||||||
def _h(module, args, kwargs):
|
def _h(module, args, kwargs):
|
||||||
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
|
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
|
||||||
g = x.T @ x
|
if orient:
|
||||||
cov[name] = g if name not in cov else cov[name] + g
|
g = x.T @ x
|
||||||
|
gram[name] = g if name not in gram else gram[name] + g
|
||||||
|
else:
|
||||||
|
h = (x @ layer.lora_Vh.float().cpu().T) * layer.lora_S.float().cpu()
|
||||||
|
m = h.T @ h
|
||||||
|
mom[name] = m if name not in mom else mom[name] + m
|
||||||
cnt[name] += x.shape[0]
|
cnt[name] += x.shape[0]
|
||||||
return _h
|
return _h
|
||||||
|
|
||||||
@@ -129,32 +143,41 @@ class AntiPaSTOAblate:
|
|||||||
for h in handles:
|
for h in handles:
|
||||||
h.remove()
|
h.remove()
|
||||||
|
|
||||||
r, eps = cfg.r, float(cfg.cov_eps)
|
r, k, eps = cfg.r, cfg.k, float(cfg.cov_eps)
|
||||||
for name, layer in layers.items():
|
for name, layer in layers.items():
|
||||||
if cnt[name] < r:
|
if cnt[name] < r:
|
||||||
raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}")
|
raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}")
|
||||||
W_res = layer.weight.data.float().cpu()
|
if orient:
|
||||||
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
|
W_res = layer.weight.data.float().cpu()
|
||||||
layer.lora_S.float().cpu(),
|
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
|
||||||
layer.lora_Vh.float().cpu())
|
layer.lora_S.float().cpu(),
|
||||||
W_orig = W_res + (U_old * S_old) @ Vh_old
|
layer.lora_Vh.float().cpu())
|
||||||
|
W_orig = W_res + (U_old * S_old) @ Vh_old
|
||||||
|
|
||||||
C = cov[name] / cnt[name]
|
C = gram[name] / cnt[name]
|
||||||
lam, Q = torch.linalg.eigh(C)
|
lam, Q = torch.linalg.eigh(C)
|
||||||
lam = lam.clamp_min(0) + eps
|
lam = lam.clamp_min(0) + eps
|
||||||
Chalf = (Q * lam.sqrt()) @ Q.T
|
Chalf = (Q * lam.sqrt()) @ Q.T
|
||||||
Cinvhalf = (Q * lam.rsqrt()) @ Q.T
|
Cinvhalf = (Q * lam.rsqrt()) @ Q.T
|
||||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
|
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
|
||||||
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
|
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
|
||||||
Sr = St[:r]
|
Sr = St[:r]
|
||||||
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
|
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
|
||||||
W_res_new = W_orig - (Ur * Sr) @ Pr
|
W_res_new = W_orig - (Ur * Sr) @ Pr
|
||||||
|
with torch.no_grad():
|
||||||
|
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||||
|
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||||
|
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
|
||||||
|
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
||||||
|
# output S-space second moment in the (now oriented) basis: diag(S) P Σxxᵀ Pᵀ diag(S)
|
||||||
|
SP = Sr[:, None] * Pr
|
||||||
|
M = SP @ gram[name] @ SP.T
|
||||||
|
else:
|
||||||
|
M = mom[name] # (r, r) Σ hhᵀ in the init basis
|
||||||
|
|
||||||
|
c0 = torch.linalg.eigh(M).eigenvectors[:, -k:] # top-k principal dirs (orthonormal)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
layer.lora_c.copy_(c0.to(layer.lora_c))
|
||||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
|
||||||
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
|
|
||||||
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _orthonormal(c: T) -> T:
|
def _orthonormal(c: T) -> T:
|
||||||
|
|||||||
Reference in New Issue
Block a user