mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 15:15:55 +08:00
Revert ablate lora_c warm-start: variance-PC seed didn't help on SFT
Job 94 result (Qwen3.5-0.8B, GSM8K, 2500 steps, single seed): warm-start (top-k S-space output-variance PC): test 55.6 / valid 64.0, init 33.2s random-init (prior default): test 56.0 / valid 68.0, init 2.2s Equal-or-worse accuracy (within single-seed noise) for +31s of calibration init. The optimal ablation direction is loss-defined, not variance-defined, so seeding lora_c from the data-variance PC buys nothing here. Reverts fe562c2; ablate is back to the cheap random-init default. cov_orient (CorDA re-orient) path kept. The FIXME's actual proposal -- a *contrastive* dS seed -- stays open but needs pos/neg pairs this SFT benchmark lacks (only relevant for labelled steering). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -604,9 +604,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
# downstream task (IPM mode, per CorDA). eva needs only a few batches for its init;
|
||||
# corda/asvd/cov-orient estimate an input second moment, so we hand them many more
|
||||
# batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis.
|
||||
# antipasto_ablate always calibrates now: group_init warm-starts lora_c from the
|
||||
# S-space output variance (cov_orient adds the heavier CorDA re-orient on top).
|
||||
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd", "antipasto_ablate")
|
||||
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd") or (
|
||||
args.variant == "antipasto_ablate" and args.antipasto_cov_orient
|
||||
)
|
||||
init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init
|
||||
if needs_calib:
|
||||
n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches))
|
||||
|
||||
@@ -85,43 +85,33 @@ class AntiPaSTOAblate:
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
# lora_c starts random here; group_init warm-starts it from the S-space output
|
||||
# variance when calibration_data is given (see group_init), else it trains from noise.
|
||||
# lora_c is random-init. Tried (job 94) seeding it from the top-k S-space
|
||||
# output-VARIANCE PC: equal-or-worse on GSM8K (55.6/64.0 vs random 56.0/68.0,
|
||||
# single seed) and +31s calib init -- the optimal ablation dir is loss-defined,
|
||||
# not variance-defined, so a variance seed buys nothing on SFT. Reverted.
|
||||
# FIXME the contrastive dS seed (mean(h|pos)-mean(h|neg), cf. sspace.py) is the
|
||||
# one that should land on the behavior dir, but it needs pos/neg pairs this SFT
|
||||
# benchmark lacks -- only worth it for steering with labelled contrastive data.
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Warm-start each lora_c from calibration activations (and, if cov_orient,
|
||||
re-orient the frozen SVD by input covariance C=E[x xᵀ] first, CorDA-style).
|
||||
|
||||
lora_c is seeded to the top-k principal axes of the S-space OUTPUT coords
|
||||
h = diag(S) Vh x over the calibration set: the highest-energy output directions,
|
||||
where the loss-gradient on the ablation strength is largest, so lora_c starts in a
|
||||
high-gradient region instead of a near-orthogonal random one. NOTE this is the data
|
||||
VARIANCE direction, not a contrastive behavior direction -- this benchmark is SFT
|
||||
with no pos/neg split. For steering with contrastive pairs, seed lora_c from
|
||||
mean(h|pos) - mean(h|neg) instead (cf. steering-lite sspace extract).
|
||||
|
||||
Σ xxᵀ (d_in², heavy for down_proj) is only accumulated to orient; the warm-start
|
||||
alone (cov_orient=False) needs just the cheap r×r second moment Σ hhᵀ."""
|
||||
if calibration_data is None:
|
||||
"""If cov_orient, re-orient each target's SVD by input covariance C=E[x x^T]
|
||||
(CorDA) so the data-relevant output directions land in the top-r and the
|
||||
behavior direction is fully ablatable at low r. No-op otherwise (keeps the
|
||||
plain-SVD basis from init()). C is accumulated on CPU; for down_proj's large
|
||||
d_in this is heavy -- exclude it or use plain ablation there."""
|
||||
if not getattr(cfg, "cov_orient", False) or calibration_data is None:
|
||||
return
|
||||
orient = bool(getattr(cfg, "cov_orient", False))
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
gram: dict[str, T] = {} # Σ xxᵀ (d_in²), only when orienting
|
||||
mom: dict[str, T] = {} # Σ hhᵀ (r²), when not orienting (basis is fixed at init)
|
||||
cov: dict[str, T] = {}
|
||||
cnt: dict[str, int] = {n: 0 for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
layer = layers[name]
|
||||
def _h(module, args, kwargs):
|
||||
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
|
||||
if orient:
|
||||
g = x.T @ x
|
||||
gram[name] = g if name not in gram else gram[name] + g
|
||||
else:
|
||||
h = (x @ layer.lora_Vh.float().cpu().T) * layer.lora_S.float().cpu()
|
||||
m = h.T @ h
|
||||
mom[name] = m if name not in mom else mom[name] + m
|
||||
g = x.T @ x
|
||||
cov[name] = g if name not in cov else cov[name] + g
|
||||
cnt[name] += x.shape[0]
|
||||
return _h
|
||||
|
||||
@@ -143,41 +133,32 @@ class AntiPaSTOAblate:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r, k, eps = cfg.r, cfg.k, float(cfg.cov_eps)
|
||||
r, eps = cfg.r, float(cfg.cov_eps)
|
||||
for name, layer in layers.items():
|
||||
if cnt[name] < r:
|
||||
raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}")
|
||||
if orient:
|
||||
W_res = layer.weight.data.float().cpu()
|
||||
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
|
||||
layer.lora_S.float().cpu(),
|
||||
layer.lora_Vh.float().cpu())
|
||||
W_orig = W_res + (U_old * S_old) @ Vh_old
|
||||
W_res = layer.weight.data.float().cpu()
|
||||
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
|
||||
layer.lora_S.float().cpu(),
|
||||
layer.lora_Vh.float().cpu())
|
||||
W_orig = W_res + (U_old * S_old) @ Vh_old
|
||||
|
||||
C = gram[name] / cnt[name]
|
||||
lam, Q = torch.linalg.eigh(C)
|
||||
lam = lam.clamp_min(0) + eps
|
||||
Chalf = (Q * lam.sqrt()) @ Q.T
|
||||
Cinvhalf = (Q * lam.rsqrt()) @ Q.T
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
|
||||
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
|
||||
Sr = St[:r]
|
||||
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
|
||||
W_res_new = W_orig - (Ur * Sr) @ Pr
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
|
||||
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
||||
# output S-space second moment in the (now oriented) basis: diag(S) P Σxxᵀ Pᵀ diag(S)
|
||||
SP = Sr[:, None] * Pr
|
||||
M = SP @ gram[name] @ SP.T
|
||||
else:
|
||||
M = mom[name] # (r, r) Σ hhᵀ in the init basis
|
||||
C = cov[name] / cnt[name]
|
||||
lam, Q = torch.linalg.eigh(C)
|
||||
lam = lam.clamp_min(0) + eps
|
||||
Chalf = (Q * lam.sqrt()) @ Q.T
|
||||
Cinvhalf = (Q * lam.rsqrt()) @ Q.T
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
|
||||
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
|
||||
Sr = St[:r]
|
||||
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
|
||||
W_res_new = W_orig - (Ur * Sr) @ Pr
|
||||
|
||||
c0 = torch.linalg.eigh(M).eigenvectors[:, -k:] # top-k principal dirs (orthonormal)
|
||||
with torch.no_grad():
|
||||
layer.lora_c.copy_(c0.to(layer.lora_c))
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
|
||||
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
||||
|
||||
@staticmethod
|
||||
def _orthonormal(c: T) -> T:
|
||||
|
||||
Reference in New Issue
Block a user