Precompute SVD for S-space projection for efficiency

This commit is contained in:
wassname
2026-04-10 09:46:42 +08:00
parent a1a8648865
commit 382ffc4315
2 changed files with 41 additions and 45 deletions
+18 -20
View File
@@ -39,24 +39,31 @@ NUM_EXAMPLES = 3
# %%
def get_collective_s_space_weight(model):
def get_s_space_projector(model, top_k=256):
"""
Gathers all weight matrices that write to the residual stream
(o_proj from attention and down_proj from MLP) across all layers,
and concatenates them to form a collective "write" transformation.
Then computes the SVD to extract the top_k modes.
Returns: U (hidden_size, top_k), S (top_k,)
"""
Ws = []
for layer in model.model.layers:
# In Qwen2, o_proj and down_proj weights are shape [hidden_size, in_features]
# We want a combined matrix of shape [hidden_size, sum(in_features)]
# so that SVD gives U of shape [hidden_size, top_k].
# o_proj.weight is [hidden_size, num_heads * head_dim]
# down_proj.weight is [hidden_size, intermediate_size]
Ws.append(layer.self_attn.o_proj.weight.detach().cpu())
Ws.append(layer.mlp.down_proj.weight.detach().cpu())
return torch.cat(Ws, dim=1).to(model.device)
W = torch.cat(Ws, dim=1).to(model.device)
# SVD on the collective weight matrix
U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)
# Crop to top_k modes
U = U[:, :top_k]
S = S[:top_k]
return U, S
def project_to_s_space(hidden_states, W, top_k=256):
def project_to_s_space(hidden_states, U, S):
"""
Projects the residual stream into the 'super' S-space of all residual writers.
@@ -66,16 +73,7 @@ def project_to_s_space(hidden_states, W, top_k=256):
hidden states from the residual stream, and the U from all residual writers,
we can project the residual stream into S-space, which can be thought of as
something like the coordinate space of learned modes of behaviors.
W: The concatenated weight matrix of all residual writers, shape (sum(in_features), hidden_size)
"""
# SVD on the collective weight matrix
U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)
# Crop to top_k modes (there will be a lot of overlap/redundancy)
U = U[:, :top_k]
S = S[:top_k]
# Project: x_S = (x @ U)
x_S = hidden_states.to(torch.float32) @ U
@@ -114,7 +112,7 @@ def compute_curvature(hidden_states):
# %%
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_weight=None):
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_U=None, s_space_S=None):
messages = [{"role": "user", "content": prompt_text}]
inputs = tokenizer.apply_chat_template(
@@ -175,8 +173,8 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac
start_idx = prompt_ids.shape[1]
cot_hiddens = final_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]
if s_space_weight is not None:
trajectory = project_to_s_space(cot_hiddens, s_space_weight)
if s_space_U is not None and s_space_S is not None:
trajectory = project_to_s_space(cot_hiddens, s_space_U, s_space_S)
else:
trajectory = cot_hiddens
@@ -218,11 +216,11 @@ plt.figure(figsize=(10, 6))
results = {}
# Project using the collective residual stream writers transformation
s_space_W = get_collective_s_space_weight(model)
s_space_U, s_space_S = get_s_space_projector(model)
for p_key, p_prefix in PERSONAS.items():
print(f"\n--- Running: {p_key} ---")
res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_weight=s_space_W)
res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_U=s_space_U, s_space_S=s_space_S)
results[p_key] = res
print(f"Logratio (Yes/No): {res['logratio']:.3f}")
print(f"Prompt:\n```md\n{res['prompt']}```")