Refactor to return U, S, Vh from SVD and perform cropping outside the function

This commit is contained in:
wassname
2026-04-10 10:00:28 +08:00
parent e44dc0e74e
commit c8a59851ed
2 changed files with 23 additions and 23 deletions
+9 -9
View File
@@ -39,13 +39,13 @@ NUM_EXAMPLES = 3
# %%
def get_s_space_projector(model, top_k=1024):
def get_s_space_svd(model):
"""
Gathers all weight matrices that write to the residual stream
(o_proj from attention and down_proj from MLP) across all layers,
and concatenates them to form a collective "write" transformation.
Then computes the SVD to extract the top_k modes.
Returns: U (hidden_size, top_k), S (top_k,)
Then computes and returns the full SVD.
Returns: U, S, Vh
"""
Ws = []
for layer in model.model.layers:
@@ -56,12 +56,9 @@ def get_s_space_projector(model, top_k=1024):
W = torch.cat(Ws, dim=1).to(model.device)
# SVD on the collective weight matrix
U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)
U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
# Crop to top_k modes
U = U[:, :top_k]
S = S[:top_k]
return U, S
return U, S, Vh
def project_to_s_space(hidden_states, U, S):
"""
@@ -216,7 +213,10 @@ plt.figure(figsize=(10, 6))
results = {}
# Project using the collective residual stream writers transformation
s_space_U, s_space_S = get_s_space_projector(model)
U, S, Vh = get_s_space_svd(model)
top_k = 1024
s_space_U = U[:, :top_k]
s_space_S = S[:top_k]
for p_key, p_prefix in PERSONAS.items():
print(f"\n--- Running: {p_key} ---")