mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 14:13:15 +08:00
Refactor to return U, S, Vh from SVD and perform cropping outside the function
This commit is contained in:
+14
-14
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "62eec772",
|
||||
"id": "e0ccb9f3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n",
|
||||
@@ -20,7 +20,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "47c7efe2",
|
||||
"id": "f61786b3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -44,17 +44,17 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8c590e6c",
|
||||
"id": "b130793b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_s_space_projector(model, top_k=1024):\n",
|
||||
"def get_s_space_svd(model):\n",
|
||||
" \"\"\"\n",
|
||||
" Gathers all weight matrices that write to the residual stream\n",
|
||||
" (o_proj from attention and down_proj from MLP) across all layers,\n",
|
||||
" and concatenates them to form a collective \"write\" transformation.\n",
|
||||
" Then computes the SVD to extract the top_k modes.\n",
|
||||
" Returns: U (hidden_size, top_k), S (top_k,)\n",
|
||||
" Then computes and returns the full SVD.\n",
|
||||
" Returns: U, S, Vh\n",
|
||||
" \"\"\"\n",
|
||||
" Ws = []\n",
|
||||
" for layer in model.model.layers:\n",
|
||||
@@ -65,12 +65,9 @@
|
||||
" W = torch.cat(Ws, dim=1).to(model.device)\n",
|
||||
" \n",
|
||||
" # SVD on the collective weight matrix\n",
|
||||
" U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n",
|
||||
" U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)\n",
|
||||
" \n",
|
||||
" # Crop to top_k modes\n",
|
||||
" U = U[:, :top_k]\n",
|
||||
" S = S[:top_k]\n",
|
||||
" return U, S\n",
|
||||
" return U, S, Vh\n",
|
||||
"\n",
|
||||
"def project_to_s_space(hidden_states, U, S):\n",
|
||||
" \"\"\"\n",
|
||||
@@ -123,7 +120,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6ae905b2",
|
||||
"id": "812eed4d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -205,7 +202,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "30e7fb4e",
|
||||
"id": "57d967a9",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
@@ -239,7 +236,10 @@
|
||||
"results = {}\n",
|
||||
"\n",
|
||||
"# Project using the collective residual stream writers transformation\n",
|
||||
"s_space_U, s_space_S = get_s_space_projector(model)\n",
|
||||
"U, S, Vh = get_s_space_svd(model)\n",
|
||||
"top_k = 1024\n",
|
||||
"s_space_U = U[:, :top_k]\n",
|
||||
"s_space_S = S[:top_k]\n",
|
||||
"\n",
|
||||
"for p_key, p_prefix in PERSONAS.items():\n",
|
||||
" print(f\"\\n--- Running: {p_key} ---\")\n",
|
||||
|
||||
+9
-9
@@ -39,13 +39,13 @@ NUM_EXAMPLES = 3
|
||||
|
||||
|
||||
# %%
|
||||
def get_s_space_projector(model, top_k=1024):
|
||||
def get_s_space_svd(model):
|
||||
"""
|
||||
Gathers all weight matrices that write to the residual stream
|
||||
(o_proj from attention and down_proj from MLP) across all layers,
|
||||
and concatenates them to form a collective "write" transformation.
|
||||
Then computes the SVD to extract the top_k modes.
|
||||
Returns: U (hidden_size, top_k), S (top_k,)
|
||||
Then computes and returns the full SVD.
|
||||
Returns: U, S, Vh
|
||||
"""
|
||||
Ws = []
|
||||
for layer in model.model.layers:
|
||||
@@ -56,12 +56,9 @@ def get_s_space_projector(model, top_k=1024):
|
||||
W = torch.cat(Ws, dim=1).to(model.device)
|
||||
|
||||
# SVD on the collective weight matrix
|
||||
U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)
|
||||
U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
|
||||
|
||||
# Crop to top_k modes
|
||||
U = U[:, :top_k]
|
||||
S = S[:top_k]
|
||||
return U, S
|
||||
return U, S, Vh
|
||||
|
||||
def project_to_s_space(hidden_states, U, S):
|
||||
"""
|
||||
@@ -216,7 +213,10 @@ plt.figure(figsize=(10, 6))
|
||||
results = {}
|
||||
|
||||
# Project using the collective residual stream writers transformation
|
||||
s_space_U, s_space_S = get_s_space_projector(model)
|
||||
U, S, Vh = get_s_space_svd(model)
|
||||
top_k = 1024
|
||||
s_space_U = U[:, :top_k]
|
||||
s_space_S = S[:top_k]
|
||||
|
||||
for p_key, p_prefix in PERSONAS.items():
|
||||
print(f"\n--- Running: {p_key} ---")
|
||||
|
||||
Reference in New Issue
Block a user