mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 16:58:47 +08:00
Implement super S-space projection across all residual writers
This commit is contained in:
+65
-4
@@ -39,6 +39,59 @@ NUM_EXAMPLES = 3
|
||||
|
||||
|
||||
# %%
|
||||
def get_collective_s_space_weight(model):
    """
    Build one wide "write" matrix from every module that writes to the residual stream.

    For each entry of ``model.model.layers`` the attention output projection
    (``self_attn.o_proj``) and the MLP down projection (``mlp.down_proj``)
    weights are collected, in that order, and joined along dim 1.

    Per the original notes for Qwen2, both weights are laid out as
    [hidden_size, in_features] (o_proj: [hidden_size, num_heads * head_dim],
    down_proj: [hidden_size, intermediate_size]), so the result is
    [hidden_size, sum(in_features)] and an SVD of it yields a U of shape
    [hidden_size, top_k].

    Returns:
        The concatenated weight matrix, moved to ``model.device``.
    """
    # Each weight is detached and copied to CPU first; the concatenation
    # happens there and only the final matrix is moved to the model's device.
    writer_weights = [
        weight.detach().cpu()
        for block in model.model.layers
        for weight in (block.self_attn.o_proj.weight, block.mlp.down_proj.weight)
    ]
    collective = torch.cat(writer_weights, dim=1)
    return collective.to(model.device)
|
||||
|
||||
def project_to_s_space(hidden_states, W, top_k=256):
    """
    Projects the residual stream into the 'super' S-space of all residual writers.

    Explanation: The residual stream doesn't change much, but gets suppressed in the
    last 3-10% of layers. Since the residual stream interacts with all modules,
    we get the 'super' S-space of all residual stream writers. By getting the
    hidden states from the residual stream, and the U from all residual writers,
    we can project the residual stream into S-space, which can be thought of as
    something like the coordinate space of learned modes of behaviors.

    Args:
        hidden_states: Residual-stream trajectory, indexed as
            (positions, hidden_size). Dim 0 is the axis reduced over when
            aligning signs.
        W: The concatenated weight matrix of all residual writers, shape
            (hidden_size, sum(in_features)), i.e. the layout produced by
            get_collective_s_space_weight.
        top_k: Number of leading singular modes to keep.

    Returns:
        A float32 tensor of shape (positions, k) with
        k = min(top_k, number of singular values of W), located on the same
        device as ``hidden_states``.
    """
    # SVD on the collective weight matrix.
    # NOTE: this is recomputed on every call; callers that project many
    # trajectories with the same W pay for the decomposition each time.
    U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)

    # Crop to top_k modes (there will be a lot of overlap/redundancy) and make
    # sure the basis lives on the same device as the trajectory, so a
    # CPU-captured trajectory can be projected with a GPU-resident W.
    U = U[:, :top_k].to(hidden_states.device)
    S = S[:top_k].to(hidden_states.device)

    # Project: x_S = (x @ U)
    x_S = hidden_states.to(torch.float32) @ U

    # An empty trajectory has nothing to sign-align, and max/min over an empty
    # dim would raise, so return the (empty) scaled projection directly.
    if x_S.shape[0] == 0:
        return x_S * S

    # Align signs: SVD only determines each mode up to a sign, so flip every
    # coordinate whose largest-magnitude extreme is negative. max + min is
    # negative exactly when |min| > |max|.
    signs = torch.sign(x_S.max(dim=0).values + x_S.min(dim=0).values)
    signs[signs == 0] = 1.0  # prevent 0 multiplication

    x_S = x_S * signs

    # Scale by singular values
    x_S = x_S * S

    return x_S
|
||||
|
||||
def compute_curvature(hidden_states):
|
||||
'''
|
||||
Computes Frenet-Serret extrinsic curvature (kappa).
|
||||
@@ -61,7 +114,7 @@ def compute_curvature(hidden_states):
|
||||
|
||||
|
||||
# %%
|
||||
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
|
||||
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_weight=None):
|
||||
messages = [{"role": "user", "content": prompt_text}]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
@@ -122,9 +175,14 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"):
|
||||
start_idx = prompt_ids.shape[1]
|
||||
cot_hiddens = final_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]
|
||||
|
||||
if s_space_weight is not None:
|
||||
trajectory = project_to_s_space(cot_hiddens, s_space_weight)
|
||||
else:
|
||||
trajectory = cot_hiddens
|
||||
|
||||
return {
|
||||
"logratio": (p_yes - p_no).item(),
|
||||
"kappa_trajectory": compute_curvature(cot_hiddens).cpu().numpy(),
|
||||
"kappa_trajectory": compute_curvature(trajectory).cpu().numpy(),
|
||||
"prompt": tokenizer.decode(prompt_ids[0], skip_special_tokens=False),
|
||||
"generated_text": tokenizer.decode(generated_ids, skip_special_tokens=False)
|
||||
}
|
||||
@@ -159,9 +217,12 @@ print(f"\n--- Dilemma ---\n{item['dilemma_situation']}\nAction: {item['action']}
|
||||
plt.figure(figsize=(10, 6))
|
||||
results = {}
|
||||
|
||||
# Project using the collective residual stream writers transformation
|
||||
s_space_W = get_collective_s_space_weight(model)
|
||||
|
||||
for p_key, p_prefix in PERSONAS.items():
|
||||
print(f"\n--- Running: {p_key} ---")
|
||||
res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE)
|
||||
res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_weight=s_space_W)
|
||||
results[p_key] = res
|
||||
print(f"Logratio (Yes/No): {res['logratio']:.3f}")
|
||||
print(f"Prompt:\n```md\n{res['prompt']}```")
|
||||
@@ -169,7 +230,7 @@ for p_key, p_prefix in PERSONAS.items():
|
||||
|
||||
plt.plot(res['kappa_trajectory'], label=f"{p_key} (logratio: {res['logratio']:.2f})")
|
||||
|
||||
plt.title(r"Extrinsic Curvature ($\kappa$) of Hidden States during CoT")
|
||||
plt.title(r"Extrinsic Curvature ($\kappa$) of S-Space Trajectories during CoT")
|
||||
plt.xlabel("Token Position in CoT")
|
||||
plt.ylabel(r"$\kappa(t)$")
|
||||
plt.legend()
|
||||
|
||||
Reference in New Issue
Block a user