@@ -85,43 +85,33 @@ class AntiPaSTOAblate:
layer . lora_Vh . copy_ ( Vhr . to ( layer . lora_Vh . dtype ) )
W_res = ( W - ( Ur * Sr ) @ Vhr ) . to ( layer . weight . dtype )
layer . weight . data . copy_ ( W_res )
# lora_c starts random here; group_init warm-starts it from the S-space output
# variance when calibration_data is given (see group_init), else it trains from noise.
# lora_c is random-init. Tried (job 94) seeding it from the top-k S-space
# output-VARIANCE PC: equal-or-worse on GSM8K (55.6/64.0 vs random 56.0/68.0,
# single seed) and +31s calib init -- the optimal ablation dir is loss-defined,
# not variance-defined, so a variance seed buys nothing on SFT. Reverted.
# FIXME the contrastive dS seed (mean(h|pos)-mean(h|neg), cf. sspace.py) is the
# one that should land on the behavior dir, but it needs pos/neg pairs this SFT
# benchmark lacks -- only worth it for steering with labelled contrastive data.
@staticmethod
def group_init ( model : nn . Module , targets , cfg , calibration_data : CalibrationData | None ) - > None :
""" Warm-start each lora_c from calibration activations (and, if cov_orient,
re-orient the frozen SVD by input covariance C=E[x xᵀ] first, CorDA-style).
lora_c is seeded to the top-k principal axes of the S-space OUTPUT coords
h = diag(S) Vh x over the calibration set: the highest-energy output directions,
where the loss-gradient on the ablation strength is largest, so lora_c starts in a
high-gradient region instead of a near-orthogonal random one. NOTE this is the data
VARIANCE direction, not a contrastive behavior direction -- this benchmark is SFT
with no pos/neg split. For steering with contrastive pairs, seed lora_c from
mean(h|pos) - mean(h|neg) instead (cf. steering-lite sspace extract).
Σ xxᵀ (d_in², heavy for down_proj) is only accumulated to orient; the warm-start
alone (cov_orient=False) needs just the cheap r× r second moment Σ hhᵀ. """
if calibration_data is None :
""" If cov_orient, re-orient each target ' s SVD by input covariance C=E[x x^T]
(CorDA) so the data-relevant output directions land in the top-r and the
behavior direction is fully ablatable at low r. No-op otherwise (keeps the
plain-SVD basis from init()). C is accumulated on CPU; for down_proj ' s large
d_in this is heavy -- exclude it or use plain ablation there. """
if not getattr ( cfg , " cov_orient " , False ) or calibration_data is None :
return
orient = bool ( getattr ( cfg , " cov_orient " , False ) )
layers = { name : layer for name , layer , _ in targets }
gram : dict [ str , T ] = { } # Σ xxᵀ (d_in²), only when orienting
mom : dict [ str , T ] = { } # Σ hhᵀ (r²), when not orienting (basis is fixed at init)
cov : dict [ str , T ] = { }
cnt : dict [ str , int ] = { n : 0 for n in layers }
def make_hook ( name ) :
layer = layers [ name ]
def _h ( module , args , kwargs ) :
x = rearrange ( args [ 0 ] . detach ( ) , " ... d -> (...) d " ) . to ( torch . float32 ) . cpu ( )
if orient :
g = x . T @ x
gram [ name ] = g if name not in gram else gram [ name ] + g
else :
h = ( x @ layer . lora_Vh . float ( ) . cpu ( ) . T ) * layer . lora_S . float ( ) . cpu ( )
m = h . T @ h
mom [ name ] = m if name not in mom else mom [ name ] + m
g = x . T @ x
cov [ name ] = g if name not in cov else cov [ name ] + g
cnt [ name ] + = x . shape [ 0 ]
return _h
@@ -143,41 +133,32 @@ class AntiPaSTOAblate:
for h in handles :
h . remove ( )
r , k , eps = cfg . r , cfg . k , float ( cfg . cov_eps )
r , eps = cfg . r , float ( cfg . cov_eps )
for name , layer in layers . items ( ) :
if cnt [ name ] < r :
raise RuntimeError ( f " AntiPaSTOAblate at { name } : { cnt [ name ] } tokens, need >= r= { r } " )
if orient :
W_res = layer . weight . data . float ( ) . cpu ( )
U_old , S_old , Vh_old = ( layer . lora_U . float ( ) . cpu ( ) ,
layer . lora_S . float ( ) . cpu ( ) ,
layer . lora_Vh . float ( ) . cpu ( ) )
W_orig = W_res + ( U_old * S_old ) @ Vh_old
W_res = layer . weight . data . float ( ) . cpu ( )
U_old , S_old , Vh_old = ( layer . lora_U . float ( ) . cpu ( ) ,
layer . lora_S . float ( ) . cpu ( ) ,
layer . lora_Vh . float ( ) . cpu ( ) )
W_orig = W_res + ( U_old * S_old ) @ Vh_old
C = gram [ name ] / cnt [ name ]
lam , Q = torch . linalg . eigh ( C )
lam = lam . clamp_min ( 0 ) + eps
Chalf = ( Q * lam . sqrt ( ) ) @ Q . T
Cinvhalf = ( Q * lam . rsqrt ( ) ) @ Q . T
Ut , St , Vht = torch . linalg . svd ( W_orig @ Chalf , full_matrices = False )
Ur = Ut [ : , : r ] # orthonormal output basis (ablation acts here)
Sr = St [ : r ]
Pr = Vht [ : r ] @ Cinvhalf # oblique input projector (input-side only)
W_res_new = W_orig - ( Ur * Sr ) @ Pr
with torch . no_grad ( ) :
layer . lora_U . copy_ ( Ur . to ( layer . lora_U ) )
layer . lora_S . copy_ ( Sr . to ( layer . lora_S ) )
layer . lora_Vh . copy_ ( Pr . to ( layer . lora_Vh ) ) # store P in the Vh slot
layer . weight . data . copy_ ( W_res_new . to ( layer . weight ) )
# output S-space second moment in the (now oriented) basis: diag(S) P Σxxᵀ Pᵀ diag(S)
SP = Sr [ : , None ] * Pr
M = SP @ gram [ name ] @ SP . T
else :
M = mom [ name ] # (r, r) Σ hhᵀ in the init basis
C = cov [ name ] / cnt [ name ]
lam , Q = torch . linalg . eigh ( C )
lam = lam . clamp_min ( 0 ) + eps
Chalf = ( Q * lam . sqrt ( ) ) @ Q . T
Cinvhalf = ( Q * lam . rsqrt ( ) ) @ Q . T
Ut , St , Vht = torch . linalg . svd ( W_orig @ Chalf , full_matrices = False )
Ur = Ut [ : , : r ] # orthonormal output basis (ablation acts here)
Sr = St [ : r ]
Pr = Vht [ : r ] @ Cinvhalf # oblique input projector (input-side only)
W_res_new = W_orig - ( Ur * Sr ) @ Pr
c0 = torch . linalg . eigh ( M ) . eigenvectors [ : , - k : ] # top-k principal dirs (orthonormal)
with torch . no_grad ( ) :
layer . lora_c . copy_ ( c0 . to ( layer . lora_c ) )
layer . lora_U . copy_ ( Ur . to ( layer . lora_U ) )
layer . lora_S . copy_ ( Sr . to ( layer . lora_S ) )
layer . lora_Vh . copy_ ( Pr . to ( layer . lora_Vh ) ) # store P in the Vh slot
layer . weight . data . copy_ ( W_res_new . to ( layer . weight ) )
@staticmethod
def _orthonormal ( c : T ) - > T :