mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 16:58:47 +08:00
Refactor to return U, S, Vh from SVD and perform cropping outside the function
This commit is contained in:
+9
-9
@@ -39,13 +39,13 @@ NUM_EXAMPLES = 3
|
||||
|
||||
|
||||
# %%
|
||||
def get_s_space_projector(model, top_k=1024):
|
||||
def get_s_space_svd(model):
|
||||
"""
|
||||
Gathers all weight matrices that write to the residual stream
|
||||
(o_proj from attention and down_proj from MLP) across all layers,
|
||||
and concatenates them to form a collective "write" transformation.
|
||||
Then computes the SVD to extract the top_k modes.
|
||||
Returns: U (hidden_size, top_k), S (top_k,)
|
||||
Then computes and returns the full SVD.
|
||||
Returns: U, S, Vh
|
||||
"""
|
||||
Ws = []
|
||||
for layer in model.model.layers:
|
||||
@@ -56,12 +56,9 @@ def get_s_space_projector(model, top_k=1024):
|
||||
W = torch.cat(Ws, dim=1).to(model.device)
|
||||
|
||||
# SVD on the collective weight matrix
|
||||
U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)
|
||||
U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
|
||||
|
||||
# Crop to top_k modes
|
||||
U = U[:, :top_k]
|
||||
S = S[:top_k]
|
||||
return U, S
|
||||
return U, S, Vh
|
||||
|
||||
def project_to_s_space(hidden_states, U, S):
|
||||
"""
|
||||
@@ -216,7 +213,10 @@ plt.figure(figsize=(10, 6))
|
||||
results = {}
|
||||
|
||||
# Project using the collective residual stream writers transformation
|
||||
s_space_U, s_space_S = get_s_space_projector(model)
|
||||
U, S, Vh = get_s_space_svd(model)
|
||||
top_k = 1024
|
||||
s_space_U = U[:, :top_k]
|
||||
s_space_S = S[:top_k]
|
||||
|
||||
for p_key, p_prefix in PERSONAS.items():
|
||||
print(f"\n--- Running: {p_key} ---")
|
||||
|
||||
Reference in New Issue
Block a user