mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 17:13:50 +08:00
Precompute SVD for S-space projection for efficiency
This commit is contained in:
+23
-25
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "76397472",
|
||||
"id": "b8a288c6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n",
|
||||
@@ -20,7 +20,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "703c8e38",
|
||||
"id": "e8c081e6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -44,28 +44,35 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e6f56583",
|
||||
"id": "7b5b34e4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_collective_s_space_weight(model):\n",
|
||||
"def get_s_space_projector(model, top_k=256):\n",
|
||||
" \"\"\"\n",
|
||||
" Gathers all weight matrices that write to the residual stream\n",
|
||||
" (o_proj from attention and down_proj from MLP) across all layers,\n",
|
||||
" and concatenates them to form a collective \"write\" transformation.\n",
|
||||
" Then computes the SVD to extract the top_k modes.\n",
|
||||
" Returns: U (hidden_size, top_k), S (top_k,)\n",
|
||||
" \"\"\"\n",
|
||||
" Ws = []\n",
|
||||
" for layer in model.model.layers:\n",
|
||||
" # In Qwen2, o_proj and down_proj weights are shape [hidden_size, in_features]\n",
|
||||
" # We want a combined matrix of shape [hidden_size, sum(in_features)]\n",
|
||||
" # so that SVD gives U of shape [hidden_size, top_k].\n",
|
||||
" # o_proj.weight is [hidden_size, num_heads * head_dim]\n",
|
||||
" # down_proj.weight is [hidden_size, intermediate_size]\n",
|
||||
" Ws.append(layer.self_attn.o_proj.weight.detach().cpu())\n",
|
||||
" Ws.append(layer.mlp.down_proj.weight.detach().cpu())\n",
|
||||
" return torch.cat(Ws, dim=1).to(model.device)\n",
|
||||
" W = torch.cat(Ws, dim=1).to(model.device)\n",
|
||||
" \n",
|
||||
" # SVD on the collective weight matrix\n",
|
||||
" U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n",
|
||||
" \n",
|
||||
" # Crop to top_k modes\n",
|
||||
" U = U[:, :top_k]\n",
|
||||
" S = S[:top_k]\n",
|
||||
" return U, S\n",
|
||||
"\n",
|
||||
"def project_to_s_space(hidden_states, W, top_k=256):\n",
|
||||
"def project_to_s_space(hidden_states, U, S):\n",
|
||||
" \"\"\"\n",
|
||||
" Projects the residual stream into the 'super' S-space of all residual writers.\n",
|
||||
" \n",
|
||||
@@ -75,16 +82,7 @@
|
||||
" hidden states from the residual stream, and the U from all residual writers, \n",
|
||||
" we can project the residual stream into S-space, which can be thought of as \n",
|
||||
" something like the coordinate space of learned modes of behaviors.\n",
|
||||
" \n",
|
||||
" W: The concatenated weight matrix of all residual writers, shape (sum(in_features), hidden_size)\n",
|
||||
" \"\"\"\n",
|
||||
" # SVD on the collective weight matrix\n",
|
||||
" U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n",
|
||||
" \n",
|
||||
" # Crop to top_k modes (there will be a lot of overlap/redundancy)\n",
|
||||
" U = U[:, :top_k]\n",
|
||||
" S = S[:top_k]\n",
|
||||
" \n",
|
||||
" # Project: x_S = (x @ U)\n",
|
||||
" x_S = hidden_states.to(torch.float32) @ U\n",
|
||||
" \n",
|
||||
@@ -125,11 +123,11 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7cbb2a21",
|
||||
"id": "ab5130a9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\", s_space_weight=None):\n",
|
||||
"def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\", s_space_U=None, s_space_S=None):\n",
|
||||
" messages = [{\"role\": \"user\", \"content\": prompt_text}]\n",
|
||||
" \n",
|
||||
" inputs = tokenizer.apply_chat_template(\n",
|
||||
@@ -190,8 +188,8 @@
|
||||
" start_idx = prompt_ids.shape[1]\n",
|
||||
" cot_hiddens = final_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]\n",
|
||||
" \n",
|
||||
" if s_space_weight is not None:\n",
|
||||
" trajectory = project_to_s_space(cot_hiddens, s_space_weight)\n",
|
||||
" if s_space_U is not None and s_space_S is not None:\n",
|
||||
" trajectory = project_to_s_space(cot_hiddens, s_space_U, s_space_S)\n",
|
||||
" else:\n",
|
||||
" trajectory = cot_hiddens\n",
|
||||
" \n",
|
||||
@@ -207,7 +205,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "812de2b3",
|
||||
"id": "96c84bd3",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
@@ -241,11 +239,11 @@
|
||||
"results = {}\n",
|
||||
"\n",
|
||||
"# Project using the collective residual stream writers transformation\n",
|
||||
"s_space_W = get_collective_s_space_weight(model) \n",
|
||||
"s_space_U, s_space_S = get_s_space_projector(model)\n",
|
||||
"\n",
|
||||
"for p_key, p_prefix in PERSONAS.items():\n",
|
||||
" print(f\"\\n--- Running: {p_key} ---\")\n",
|
||||
" res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_weight=s_space_W)\n",
|
||||
" res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_U=s_space_U, s_space_S=s_space_S)\n",
|
||||
" results[p_key] = res\n",
|
||||
" print(f\"Logratio (Yes/No): {res['logratio']:.3f}\")\n",
|
||||
" print(f\"Prompt:\\n```md\\n{res['prompt']}```\")\n",
|
||||
|
||||
+18
-20
@@ -39,24 +39,31 @@ NUM_EXAMPLES = 3
|
||||
|
||||
|
||||
# %%
|
||||
def get_collective_s_space_weight(model):
|
||||
def get_s_space_projector(model, top_k=256):
|
||||
"""
|
||||
Gathers all weight matrices that write to the residual stream
|
||||
(o_proj from attention and down_proj from MLP) across all layers,
|
||||
and concatenates them to form a collective "write" transformation.
|
||||
Then computes the SVD to extract the top_k modes.
|
||||
Returns: U (hidden_size, top_k), S (top_k,)
|
||||
"""
|
||||
Ws = []
|
||||
for layer in model.model.layers:
|
||||
# In Qwen2, o_proj and down_proj weights are shape [hidden_size, in_features]
|
||||
# We want a combined matrix of shape [hidden_size, sum(in_features)]
|
||||
# so that SVD gives U of shape [hidden_size, top_k].
|
||||
# o_proj.weight is [hidden_size, num_heads * head_dim]
|
||||
# down_proj.weight is [hidden_size, intermediate_size]
|
||||
Ws.append(layer.self_attn.o_proj.weight.detach().cpu())
|
||||
Ws.append(layer.mlp.down_proj.weight.detach().cpu())
|
||||
return torch.cat(Ws, dim=1).to(model.device)
|
||||
W = torch.cat(Ws, dim=1).to(model.device)
|
||||
|
||||
# SVD on the collective weight matrix
|
||||
U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)
|
||||
|
||||
# Crop to top_k modes
|
||||
U = U[:, :top_k]
|
||||
S = S[:top_k]
|
||||
return U, S
|
||||
|
||||
def project_to_s_space(hidden_states, W, top_k=256):
|
||||
def project_to_s_space(hidden_states, U, S):
|
||||
"""
|
||||
Projects the residual stream into the 'super' S-space of all residual writers.
|
||||
|
||||
@@ -66,16 +73,7 @@ def project_to_s_space(hidden_states, W, top_k=256):
|
||||
hidden states from the residual stream, and the U from all residual writers,
|
||||
we can project the residual stream into S-space, which can be thought of as
|
||||
something like the coordinate space of learned modes of behaviors.
|
||||
|
||||
W: The concatenated weight matrix of all residual writers, shape (sum(in_features), hidden_size)
|
||||
"""
|
||||
# SVD on the collective weight matrix
|
||||
U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)
|
||||
|
||||
# Crop to top_k modes (there will be a lot of overlap/redundancy)
|
||||
U = U[:, :top_k]
|
||||
S = S[:top_k]
|
||||
|
||||
# Project: x_S = (x @ U)
|
||||
x_S = hidden_states.to(torch.float32) @ U
|
||||
|
||||
@@ -114,7 +112,7 @@ def compute_curvature(hidden_states):
|
||||
|
||||
|
||||
# %%
|
||||
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_weight=None):
|
||||
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_U=None, s_space_S=None):
|
||||
messages = [{"role": "user", "content": prompt_text}]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
@@ -175,8 +173,8 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac
|
||||
start_idx = prompt_ids.shape[1]
|
||||
cot_hiddens = final_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]
|
||||
|
||||
if s_space_weight is not None:
|
||||
trajectory = project_to_s_space(cot_hiddens, s_space_weight)
|
||||
if s_space_U is not None and s_space_S is not None:
|
||||
trajectory = project_to_s_space(cot_hiddens, s_space_U, s_space_S)
|
||||
else:
|
||||
trajectory = cot_hiddens
|
||||
|
||||
@@ -218,11 +216,11 @@ plt.figure(figsize=(10, 6))
|
||||
results = {}
|
||||
|
||||
# Project using the collective residual stream writers transformation
|
||||
s_space_W = get_collective_s_space_weight(model)
|
||||
s_space_U, s_space_S = get_s_space_projector(model)
|
||||
|
||||
for p_key, p_prefix in PERSONAS.items():
|
||||
print(f"\n--- Running: {p_key} ---")
|
||||
res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_weight=s_space_W)
|
||||
res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_U=s_space_U, s_space_S=s_space_S)
|
||||
results[p_key] = res
|
||||
print(f"Logratio (Yes/No): {res['logratio']:.3f}")
|
||||
print(f"Prompt:\n```md\n{res['prompt']}```")
|
||||
|
||||
Reference in New Issue
Block a user